Skip to content

Commit

Permalink
Fix test error(#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
pressureless committed Jun 9, 2021
1 parent a84a206 commit 13e5657
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions iheartla/la_parser/type_walker.py
Expand Up @@ -535,14 +535,32 @@ def walk_WhereCondition(self, node, **kwargs):
ir_node.type = type_node
return ir_node

def get_single_factor(self, ir_node):
single = None
if ir_node.is_node(IRNodeType.Integer):
single = ir_node
elif ir_node.value.is_node(IRNodeType.Factor):
if ir_node.value.id:
single = ir_node.value.id
elif ir_node.value.num:
single = ir_node.value.num
return single

def walk_MatrixType(self, node, **kwargs):
ir_node = MatrixTypeNode(parse_info=node.parseinfo)
id1_info = self.walk(node.id1, **kwargs)
ir_node.id1 = id1_info.ir
single_node = self.get_single_factor(id1_info.ir)
if single_node is not None:
ir_node.id1 = single_node
else:
ir_node.id1 = id1_info.ir
id1 = id1_info.content
id2_info = self.walk(node.id2, **kwargs)
ir_node.id2 = id2_info.ir
single_node = self.get_single_factor(id2_info.ir)
if single_node is not None:
ir_node.id2 = single_node
else:
ir_node.id2 = id1_info.ir
id2 = id2_info.content
element_type = ''
if node.type:
Expand All @@ -568,7 +586,11 @@ def walk_MatrixType(self, node, **kwargs):
def walk_VectorType(self, node, **kwargs):
ir_node = VectorTypeNode(parse_info=node.parseinfo)
id1_info = self.walk(node.id1, **kwargs)
ir_node.id1 = id1_info.ir
single_node = self.get_single_factor(id1_info.ir)
if single_node is not None:
ir_node.id1 = single_node
else:
ir_node.id1 = id1_info.ir
id1 = id1_info.content
element_type = ''
if node.type:
Expand Down Expand Up @@ -2412,7 +2434,8 @@ def walk_ArithSubexpression(self, node, **kwargs):
def walk_ArithAdd(self, node, **kwargs):
left_info = self.walk(node.left, **kwargs)
right_info = self.walk(node.right, **kwargs)
ret_type, need_cast = self.type_inference(TypeInferenceEnum.INF_ADD, left_info, right_info)
# ret_type, need_cast = self.type_inference(TypeInferenceEnum.INF_ADD, left_info, right_info)
ret_type = ScalarType(is_int=True)
ret_info = NodeInfo(ret_type, symbols=left_info.symbols.union(right_info.symbols))
ir_node = AddNode(left_info.ir, right_info.ir, parse_info=node.parseinfo)
ir_node.la_type = ret_type
Expand Down

0 comments on commit 13e5657

Please sign in to comment.