From 13e5657049970cb3910738ec498dc0bb39134055 Mon Sep 17 00:00:00 2001 From: pressureless <190361902@qq.com> Date: Tue, 8 Jun 2021 21:26:51 -0400 Subject: [PATCH] Fix test error(#103) --- iheartla/la_parser/type_walker.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/iheartla/la_parser/type_walker.py b/iheartla/la_parser/type_walker.py index 8e2267e9..260889c0 100644 --- a/iheartla/la_parser/type_walker.py +++ b/iheartla/la_parser/type_walker.py @@ -535,14 +535,32 @@ def walk_WhereCondition(self, node, **kwargs): ir_node.type = type_node return ir_node + def get_single_factor(self, ir_node): + single = None + if ir_node.is_node(IRNodeType.Integer): + single = ir_node + elif ir_node.value.is_node(IRNodeType.Factor): + if ir_node.value.id: + single = ir_node.value.id + elif ir_node.value.num: + single = ir_node.value.num + return single def walk_MatrixType(self, node, **kwargs): ir_node = MatrixTypeNode(parse_info=node.parseinfo) id1_info = self.walk(node.id1, **kwargs) - ir_node.id1 = id1_info.ir + single_node = self.get_single_factor(id1_info.ir) + if single_node is not None: + ir_node.id1 = single_node + else: + ir_node.id1 = id1_info.ir id1 = id1_info.content id2_info = self.walk(node.id2, **kwargs) - ir_node.id2 = id2_info.ir + single_node = self.get_single_factor(id2_info.ir) + if single_node is not None: + ir_node.id2 = single_node + else: + ir_node.id2 = id1_info.ir id2 = id2_info.content element_type = '' if node.type: @@ -568,7 +586,11 @@ def walk_MatrixType(self, node, **kwargs): def walk_VectorType(self, node, **kwargs): ir_node = VectorTypeNode(parse_info=node.parseinfo) id1_info = self.walk(node.id1, **kwargs) - ir_node.id1 = id1_info.ir + single_node = self.get_single_factor(id1_info.ir) + if single_node is not None: + ir_node.id1 = single_node + else: + ir_node.id1 = id1_info.ir id1 = id1_info.content element_type = '' if node.type: @@ -2412,7 +2434,8 @@ def walk_ArithSubexpression(self, node, **kwargs): def walk_ArithAdd(self, node, **kwargs): left_info = self.walk(node.left, **kwargs) right_info = self.walk(node.right, **kwargs) - ret_type, need_cast = self.type_inference(TypeInferenceEnum.INF_ADD, left_info, right_info) + # ret_type, need_cast = self.type_inference(TypeInferenceEnum.INF_ADD, left_info, right_info) + ret_type = ScalarType(is_int=True) ret_info = NodeInfo(ret_type, symbols=left_info.symbols.union(right_info.symbols)) ir_node = AddNode(left_info.ir, right_info.ir, parse_info=node.parseinfo) ir_node.la_type = ret_type