diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 75f3a8dafb..0aba3081b6 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -180,6 +180,13 @@ def __init__( self.this_module = opset self.default_opset_ = default_opset + # States initialized by `init_function_translation` + self._outer = [] + self._current_fn = None + self._nextvar = 0 + self._used_vars = set() + self._locals: List[Dict[Any, Any]] = [{}] + @property def default_opset(self): if self.default_opset_ is None: @@ -222,14 +229,14 @@ def find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: def init_function_translation(self): """Initialize self for translating a new (top-level) function.""" - self.outer = [] - self.current_fn = None - self.nextvar = 0 - self.used_vars = set() - self.locals: List[Dict[Any, Any]] = [{}] + self._outer = [] + self._current_fn = None + self._nextvar = 0 + self._used_vars = set() + self._locals: List[Dict[Any, Any]] = [{}] def source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: - return sourceinfo.SourceInfo(node, self.source, self.current_fn.name) + return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) def message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error message containing source information about an ast node.""" @@ -252,28 +259,28 @@ def enter_scope(self, name, parent_node): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ - self.outer.insert(0, self.current_fn) - self.current_fn = self.ir_builder.new_function(name) - self.locals.insert(0, {}) - logger.debug("Converter:enter_scope:%d:node:%s", len(self.locals), type(parent_node)) + self._outer.insert(0, self._current_fn) + self._current_fn = self.ir_builder.new_function(name) + self._locals.insert(0, {}) + logger.debug("Converter:enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def exit_scope(self): """Exit from a control-flow block (a loop body or if-then-else branch).""" - logger.debug("Converter:exit_scope:%d", len(self.locals)) - graph = self.current_fn - self.current_fn = self.outer.pop(0) - self.locals.pop(0) + logger.debug("Converter:exit_scope:%d", len(self._locals)) + graph = self._current_fn + self._current_fn = self._outer.pop(0) + self._locals.pop(0) return graph def current_scope(self): - return self.locals[0] + return self._locals[0] def bind(self, name, val): logger.debug("Converter:bind:%s", name) - self.locals[0][name] = val + self._locals[0][name] = val def lookup(self, name, info, raise_exception=True): - for scope in self.locals: + for scope in self._locals: if name in scope: return scope[name] if name in self.globals: @@ -285,10 +292,10 @@ def lookup(self, name, info, raise_exception=True): def generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate - while r in self.used_vars: - r = f"{candidate}_{self.nextvar}" - self.nextvar = self.nextvar + 1 - self.used_vars.add(r) + while r in self._used_vars: + r = f"{candidate}_{self._nextvar}" + self._nextvar = self._nextvar + 1 + self._used_vars.add(r) return r def to_onnx_attr_ref(self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]): @@ -326,11 +333,11 @@ def py_var_to_onnx_var(self, py_var, info: sourceinfo.SourceInfo): return self.to_onnx_var(self.lookup(py_var, info), target=py_var, info=info) def emit_docstring(self, docstring): - self.ir_builder.add_docstring(self.current_fn, docstring) + self.ir_builder.add_docstring(self._current_fn, docstring) def emit(self, outputs, callee, inputs, attrs, sub_functions=None): self.ir_builder.add_stmt( - self.current_fn, + self._current_fn, outputs, callee, inputs, @@ -898,7 +905,7 @@ def translate_callee_expr(self, node) -> values.Op: # pylint: disable=R1710 function_name = node.id found = self.lookup(function_name, self.source_of(node), raise_exception=False) if isinstance(found, onnxscript.OnnxFunction): - self.current_fn.add_called_function(found) + self._current_fn.add_called_function(found) return found if isinstance(found, values.Op): return found @@ -1016,7 +1023,7 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output(self.current_fn, return_var, t, self.source_of(stmt)) + self.ir_builder.add_output(self._current_fn, return_var, t, self.source_of(stmt)) return return_var val = stmt.value @@ -1123,7 +1130,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): self.enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) self.ir_builder.add_input( - self.current_fn, + self._current_fn, o_loop_var, onnx_types.INT64, self.source_of(loop_stmt), @@ -1134,7 +1141,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): ) self.ir_builder.add_input( - self.current_fn, + self._current_fn, i_cond_var, onnx_types.BOOL, self.source_of(loop_stmt), @@ -1145,7 +1152,9 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self.eval_constant_expr(pv.annotation) typeinfo = None - self.ir_builder.add_input(self.current_fn, ov, typeinfo, self.source_of(loop_stmt)) + self.ir_builder.add_input( + self._current_fn, ov, typeinfo, self.source_of(loop_stmt) + ) self.bind( pv, values.Dynamic(ov, values.DynamicKind.Loop, self.source_of(loop_stmt)), @@ -1201,14 +1210,14 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): ) self.ir_builder.add_output( - self.current_fn, + self._current_fn, o_cond_out, onnx_types.BOOL, self.source_of(loop_stmt), ) for pv in loop_state_vars: ov = self.py_var_to_onnx_var(pv, self.source_of(loop_stmt)) - if ov not in self.current_fn.assigned_names: + if ov not in self._current_fn.assigned_names: # When converting the loop-body into a graph, we need to handle # identity assignments of the form "x = y" inside the loop body # specially if y represents a value computed outside the loop body. @@ -1218,7 +1227,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): # TODO: retrieve variable type for the annotation if any. typeinfo = None self.ir_builder.add_output( - self.current_fn, ov, typeinfo, self.source_of(loop_stmt) + self._current_fn, ov, typeinfo, self.source_of(loop_stmt) ) body = self.exit_scope() inputs = [o_loop_bound, o_true] + [ @@ -1245,19 +1254,19 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): if pvar in self.current_scope(): pv_val = self.current_scope()[pvar] output = self.to_onnx_var(pv_val, pvar) - if output not in self.current_fn.assigned_names: + if output not in self._current_fn.assigned_names: # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self.emit_copy(output, pvar) self.ir_builder.add_output( - self.current_fn, + self._current_fn, output, pv_val.typeinfo, self.source_of(info_stmt), ) else: pv_val = None - for scope in self.locals: # TODO: skip current_scope + for scope in self._locals: # TODO: skip current_scope if pvar in scope: pv_val = scope[pvar] break @@ -1265,7 +1274,7 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): self.fail( stmts[0], f"Variable {pvar} is not assigned a value along a conditional " - f"branch, known variables: {list(self.locals)}.", + f"branch, known variables: {list(self._locals)}.", ) # introduce a copy ovar = self.generate_unique_name(pvar) @@ -1278,7 +1287,7 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): # TODO: retrieve the annotation if any. typeinfo = None self.ir_builder.add_output( - self.current_fn, ovar, typeinfo, self.source_of(info_stmt) + self._current_fn, ovar, typeinfo, self.source_of(info_stmt) ) graph = self.exit_scope() return graph.to_graph_and_functions() @@ -1294,15 +1303,15 @@ def translate_nested_function_def(self, fn: ast.FunctionDef): ] self.bind(fn.name, function_ir) # TODO: Does not yet handle nested functions within nested functions. - self.current_fn.add_nested_function(function_ir) + self._current_fn.add_nested_function(function_ir) - def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: - logger.debug("Converter:translate_function_def:%s", fn.name) + def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + """Translate a function signature.""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: warn(f"{fn.name}: Unsupported feature in function signature.") domain = self.this_module.domain - self.current_fn = self.ir_builder.new_function(fn.name, domain, True) + self._current_fn = self.ir_builder.new_function(fn.name, domain, True) for i, x in enumerate(args.args): arg_with_default_start_index = len(args.args) - len(args.defaults) if args.defaults and i >= arg_with_default_start_index: @@ -1324,15 +1333,15 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): self.ir_builder.add_attr_parameter( - self.current_fn, + self._current_fn, x.arg, ta.pytype_to_attrtype(typeinfo), default_value, ) self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x))) else: - self.ir_builder.add_input(self.current_fn, x.arg, typeinfo, self.source_of(x)) - self.used_vars.add(x.arg) + self.ir_builder.add_input(self._current_fn, x.arg, typeinfo, self.source_of(x)) + self._used_vars.add(x.arg) self.bind( x.arg, values.Dynamic(x.arg, values.DynamicKind.Input, self.source_of(x)), @@ -1352,9 +1361,16 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: self.returntype = None else: self.returntype = None + + return self._current_fn + + def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + """Translate a function definition, including the signature and its body.""" + logger.debug("Converter:translate_function_def:%s", fn.name) + _ = self.translate_function_signature(fn) for i, s in enumerate(fn.body): self.translate_stmt(s, index_of_stmt=i) - return self.current_fn + return self._current_fn def top_level_stmt(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if isinstance(stmt, ast.FunctionDef):