From d8e438d58b9feb38594768b3d49757a6ed5def8c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 28 Feb 2022 10:15:14 -0800 Subject: [PATCH 1/6] Initial version Signed-off-by: Ganesan Ramalingam --- onnxscript/__init__.py | 0 onnxscript/analysis.py | 99 +++++ onnxscript/convert.py | 6 + onnxscript/converter.py | 563 ++++++++++++++++++++++++ onnxscript/irbuilder.py | 146 ++++++ onnxscript/onnxscript/__init__.py | 35 ++ onnxscript/onnxscript/types/__init__.py | 69 +++ onnxscript/test/__init__.py | 0 onnxscript/test/attrref.py | 9 + onnxscript/test/converter_test.py | 46 ++ onnxscript/test/dropout.py | 10 + onnxscript/test/eg1.py | 27 ++ onnxscript/test/if.py | 31 ++ onnxscript/test/loop.py | 19 + onnxscript/test/multi.py | 7 + onnxscript/test/onnxfns.py | 39 ++ onnxscript/test/onnxmodels.py | 71 +++ onnxscript/test/renaming.py | 14 + onnxscript/type_annotation.py | 20 + onnxscript/utils.py | 23 + onnxscript/values.py | 59 +++ 21 files changed, 1293 insertions(+) create mode 100644 onnxscript/__init__.py create mode 100644 onnxscript/analysis.py create mode 100644 onnxscript/convert.py create mode 100644 onnxscript/converter.py create mode 100644 onnxscript/irbuilder.py create mode 100644 onnxscript/onnxscript/__init__.py create mode 100644 onnxscript/onnxscript/types/__init__.py create mode 100644 onnxscript/test/__init__.py create mode 100644 onnxscript/test/attrref.py create mode 100644 onnxscript/test/converter_test.py create mode 100644 onnxscript/test/dropout.py create mode 100644 onnxscript/test/eg1.py create mode 100644 onnxscript/test/if.py create mode 100644 onnxscript/test/loop.py create mode 100644 onnxscript/test/multi.py create mode 100644 onnxscript/test/onnxfns.py create mode 100644 onnxscript/test/onnxmodels.py create mode 100644 onnxscript/test/renaming.py create mode 100644 onnxscript/type_annotation.py create mode 100644 onnxscript/utils.py create mode 100644 onnxscript/values.py diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py new file mode 100644 index 0000000000..854eb00d5a --- /dev/null +++ b/onnxscript/analysis.py @@ -0,0 +1,99 @@ +import ast + +from numpy import isin + +def used_vars(expr): + ''' Return set of all variables used with an expression.''' + if (isinstance(expr, ast.Name)): + return set([expr.id]) + result = set() + for c in ast.iter_child_nodes(expr): + result = result | used_vars(c) + return result + +def local_defs(lhs): + '''Utility function to return set of assigned/defined variables in the lhs of an assignment statement.''' + def get_id(e): + assert isinstance(e, ast.Name), "Only simple assignments supported." + return e.id + if (isinstance(lhs, ast.Tuple)): + return set([get_id(x) for x in lhs.elts]) + else: + return set([get_id(lhs)]) + +def defs(stmt): + ''' + Return the set of all variables that may be defined (assigned to) in an + execution of input stmt. + ''' + def block_defs(block): + result = set() + for s in block: + result = result | defs(s) + return result + if (isinstance(stmt, ast.Assign)): + return local_defs(stmt.targets[0]) + elif (isinstance(stmt, ast.Return)): + return set() + elif (isinstance(stmt, ast.If)): + return block_defs(stmt.body) | block_defs(stmt.orelse) + elif isinstance(stmt, list): + return block_defs(stmt) + else: + raise ValueError("Unsupported statement type: " + type(stmt).__name__) + + +def do_liveness_analysis(fun): + ''' + Perform liveness analysis of the given function-ast. The results of the + analysis are stored directly with each statement-ast `s` as attributes `s.live_in` + and `s.live_out`. + ''' + def visit (stmt, live_out): + def visitBlock(block, live_out): + for s in reversed(block): + live_out = visit(s, live_out) + return live_out + if (isinstance(stmt, ast.Assign)): + return (live_out.difference(local_defs(stmt.targets[0]))) | used_vars(stmt.value) + elif (isinstance(stmt, ast.Return)): + return used_vars(stmt.value) + elif (isinstance(stmt, ast.If)): + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return (live1 | live2) | used_vars(stmt.test) + elif isinstance(stmt, ast.For): + return live_out # TODO + else: + raise ValueError("Unsupported statement type: " + type(stmt).__name__) + assert type(fun) == ast.FunctionDef + live = set() + for s in reversed(fun.body): + s.live_out = live + live = visit(s, live) + s.live_in = live + # print(ast.dump(s)) + # print("Live-In = ", live) + +def exposed_uses(stmts): + ''' + Return the set of variables that are used before being defined by given block. + ''' + def visitBlock(block, live_out): + for stmt in reversed(block): + live_out = visit(stmt, live_out) + return live_out + + def visit (stmt, live_out): + if (isinstance(stmt, ast.Assign)): + return (live_out.difference(local_defs(stmt.targets[0]))) | used_vars(stmt.value) + elif (isinstance(stmt, ast.Return)): + return used_vars(stmt.value) + elif (isinstance(stmt, ast.If)): + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return (live1 | live2) | used_vars(stmt.test) + else: + raise ValueError("Unsupported statement type: " + type(stmt).__name__) + + return visitBlock(stmts, set()) \ No newline at end of file diff --git a/onnxscript/convert.py b/onnxscript/convert.py new file mode 100644 index 0000000000..bd9067c19e --- /dev/null +++ b/onnxscript/convert.py @@ -0,0 +1,6 @@ +import sys +from converter import convert + +if __name__ == '__main__': + for i in range(1, len(sys.argv)): + convert(sys.argv[i]) \ No newline at end of file diff --git a/onnxscript/converter.py b/onnxscript/converter.py new file mode 100644 index 0000000000..62f72ee845 --- /dev/null +++ b/onnxscript/converter.py @@ -0,0 +1,563 @@ +from numpy import isin +import onnxscript +import onnxscript.types as types + +import os +import inspect +import ast +from ast import * + +import irbuilder +from irbuilder import IRBuilder +import analysis +import type_annotation as ta +import values +from values import Value, ConstValue, AttrRef, Dynamic, Op + +import onnx +import onnx.helper as helper + +print_flag = True + +# Python-to-IR converter: + +def not_allowed(construct): + return construct + "not supported." + +class TranslationError(Exception): + def __init__(self, *args: object) -> None: + super().__init__(*args) + +def warn(msg): + print ("Warning: " + msg) + +def fail(msg): + raise TranslationError(msg) + +def fail_if(cond, msg): + if cond: + raise TranslationError(msg) + +def ignore (cond, msg): + if cond: + warn (msg) + +def pyvalue_to_tensor(tensor_name : str, pyvalue): + if isinstance(pyvalue, bool): + return helper.make_tensor(tensor_name, onnx.TensorProto.BOOL, [], [int(pyvalue)]) + elif isinstance(pyvalue, int): + return helper.make_tensor(tensor_name, onnx.TensorProto.INT64, [], [pyvalue]) + elif isinstance(pyvalue, float): + return helper.make_tensor(tensor_name, onnx.TensorProto.FLOAT, [], [pyvalue]) + else: + fail("Unimplemented") + +# map from python operators to ONNX ops +primop_map = { + ast.Add: "Add", + ast.Sub: "Sub", + ast.Mult: "Mul", + ast.Div: "Div", + ast.USub: "Neg", + ast.Lt: "Less", + ast.Gt: "Greater", + ast.LtE: "LessOrEqual", + ast.GtE : "GreaterOrEqual", + ast.MatMult : "MatMul", + ast.Mod : "Mod", + ast.Pow : "Pow" +} + +class Converter: + def __init__(self, ir_builder = IRBuilder()): + self.ir_builder = ir_builder + self.known_modules = { 'onnxscript' : onnxscript, 'onnxscript.types' : types, 'onnx.opset15' : values.opset15 } + self.globals = { "int" : int, "float" : float, "str" : str, "onnx" : values.opset15, "Onnx" : values.opset15 } # 'os' : onnxscript + self.pure_modules = ["onnxscript"] + self.default_type = types.FLOAT[...] + + def initFunctionTranslation(self): + self.outer = [] + self.current_fn = None + self.nextvar = 0 + self.used_vars = set() + self.locals = [{}] + + def enterScope(self, name): + self.outer.insert(0,self.current_fn) + self.current_fn = self.ir_builder.newFunction(name) + self.locals.insert(0, {}) + + def exitScope(self): + graph = self.current_fn + self.current_fn = self.outer[0] + self.outer.pop(0) + self.locals.pop(0) + return graph + + def currentScope(self): + return self.locals[0] + + def bind(self, name, val): + self.locals[0][name] = val + + def lookup(self, name): + for scope in self.locals: + if (name in scope): return scope[name] + if (name in self.globals): return self.globals[name] + raise ValueError("Unbound name: " + name) + + def generateUniqueName(self, candidate = "tmp"): + r = candidate + while (r in self.used_vars): + r = candidate + "_" + str(self.nextvar) + self.nextvar = self.nextvar + 1 + self.used_vars.add(r) + return r + + def to_onnx_attr_ref(self, val : AttrRef): + pytype = val.type + attrname = "value_float" if (pytype is float) else ("value_int" if (pytype is int) else "value_string") + return self.ir_builder.attr_ref(attrname, val.value, pytype) + + def to_onnx_var(self, val, target = None): + if (isinstance(val, AttrRef)): + # promote attribute to value + result = self.generateUniqueName(target if target else "tmp") + attr = self.to_onnx_attr_ref(val) + self.emit ([result], Op("", "Constant"), [], [attr]) + return result + elif (isinstance(val, ConstValue) and isinstance(val.value, float)): # TODO + result = self.generateUniqueName(target if target else "tmp") + return self.emitConst(val.value, result) + elif isinstance(val, Dynamic): + return val.value + else: + fail(f"Cannot convert to onnx variable") + + def py_var_to_onnx_var(self, py_var): return self.to_onnx_var(self.lookup(py_var)) + + def emit(self, outputs, callee, inputs, attrs): + self.ir_builder.addStmt(self.current_fn, outputs, callee.opset, callee.opname, inputs, attrs) + + def emit2(self, outputs, callee, inputs, attrs): + def rename(x): + r = self.generateUniqueName(x) + self.bind(x, Dynamic(r)) + return r + onnx_inputs = inputs # [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ] + onnx_outputs = [ rename(x) for x in outputs ] + self.emit(onnx_outputs, Op("", callee), onnx_inputs, attrs) + + def emitConst (self, pyvalue, suggested_name): + ovar = self.generateUniqueName(suggested_name) + tensor = pyvalue_to_tensor (ovar, pyvalue) + attr = self.ir_builder.attr("value", tensor) + self.emit([ovar], Op("", "Constant"), [], [attr]) + return ovar + + def isPureModule(self, m): + return (m in self.pure_modules) + + def isConstantExpr (self, node): + if (isinstance(node, ast.Name)): + val = self.lookup(node.id) + return isinstance(val, ConstValue) and self.isPureModule(val.value) + if isinstance(node, (ast.Call, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Num, ast.Str, ast.Attribute)): + return all([self.isConstantExpr(c) for c in ast.iter_child_nodes(node)]) + return False + + def evalConstantExpr(self, node): + # TODO: assert (self.isConstantExpr(node)) + locals = {} # TODO + return eval(compile(ast.Expression(node), filename="", mode="eval"), self.globals, locals) + + def eval_attr(self, node): + if (isinstance(node, ast.Num)): + return node.n + elif (isinstance(node, ast.Str)): + return node.s + elif (isinstance(node, ast.NameConstant)): + if (node.value == True): + return 1 + elif (node.value == False): + return 0 + else: + raise ValueError("Unsupported NameConstant attribute : " + str(node.value)) + elif (isinstance(node, ast.List)): + return [self.eval_attr(x) for x in node.elts] + elif (isinstance(node, (ast.Call, ast.Attribute))): + return self.evalConstantExpr(node) + else: + raise ValueError("Unsupported attribute type: " + type(node).__name__) + + def translateAttr(self, attr_name, node): + if (isinstance(node, ast.Name)): + val = self.lookup(node.id) + if (isinstance(val, AttrRef)): + return self.to_onnx_attr_ref(val) + else: + # TODO: lookup value; if func.def., compile it to Graph; if a constant; etc. + fail("Unimplemented attribute construct") + return self.ir_builder.attr(attr_name, self.eval_attr(node)) + + # Expression-translation generates "IR statements/nodes" that compute the value of + # the expression into a target-variable, and returns the variable that is assigned this value. + def translateExpr(self, node, target="tmp"): + if (isinstance(node, ast.Call)): + r = self.translateCall(node) + elif (isinstance(node, ast.BinOp)): + r = self.translateBinOp (node) + elif (isinstance(node, ast.UnaryOp)): + r = self.translateUnaryOp (node) + elif (isinstance(node, ast.Compare)): + r = self.translateCompare (node) + elif (isinstance(node, ast.Name)): + r = self.translateName (node) + elif (isinstance(node, ast.Num)): + r = self.emitConst(node.n, target) + elif isinstance(node, ast.NameConstant): + r = self.emitConst(node.value, target) + # elif (isinstance(node, ast.Attribute)): + # r = self.translateAttribute (node) + else: + raise ValueError("Unsupported expression type: " + type(node).__name__) + if (isinstance(r, tuple)): + if isinstance(target, str): + result = self.generateUniqueName(target) + callee, args, attrs = r + self.emit([result], callee, args, attrs) + return result + else: + assert isinstance(target, list) + results = [self.generateUniqueName(x) for x in target] + callee, args, attrs = r + self.emit(results, callee, args, attrs) + return results + return r + + def translateCall(self, node): + # TODO: for now, we map named arguments to attributes, and positional arguments to inputs. + callee = self.translateCallee(node.func) + args = [self.translateExpr(x) for x in node.args] + attrs = [self.translateAttr(x.arg, x.value) for x in node.keywords] + return (callee, args, attrs) + + def translateBinOp(self, node): + op = type(node.op) + assert (op in primop_map) + opname = primop_map[op] + left = self.translateExpr(node.left) + right = self.translateExpr(node.right) + return (Op("", opname), [left, right], []) + + def translateUnaryOp(self, node): + op = type(node.op) + assert (op in primop_map) + opname = primop_map[op] + operand = self.translateExpr(node.operand) + return (Op("", opname), [operand], []) + + def translateCompare(self, node): + assert (len(node.ops) == 1) # TODO: handle multiple comparisons in one expression + assert (len(node.comparators) == 1) + op = type(node.ops[0]) + assert (op in primop_map) + opname = primop_map[op] + left = self.translateExpr(node.left) + right = self.translateExpr(node.comparators[0]) + return (Op("", opname), [left, right], []) + + # TODO: returns??? + def translateName(self, node): + return self.py_var_to_onnx_var(node.id) + + def translateNum(self, node): + return self.emitConst(node.n, "Const") + + def translateModule(self, node): + if isinstance(node, ast.Name): + try: + val = self.lookup(node.id) + if isinstance(val, ConstValue): #TODO + val = val.value + if isinstance(val, values.Opset): + return val + fail(f"{node.id} has value of type {type(node.id)} and used as module") + except: + warn(f"Unknown module name {node.id}.") + return values.Opset(node.id, 1) + elif isinstance (node, ast.Attribute): + fail("Nested module unimplemented") + else: + fail("Invalid module.") + + def translateCallee(self, node): + if isinstance(node, ast.Attribute): + module = self.translateModule(node.value) + opname = node.attr + if (opname not in module): warn (f"{opname} is not a known op in {str(module)}") + return Op(module, node.attr) + elif isinstance(node, ast.Name): + try: + val = self.lookup(node.id) + except: + default_opset = values.opset15 + if (node.id not in default_opset): + warn(f"Unknown function name {node.id}.") + return Op(default_opset, node.id) + else: + fail ("Invalid callee") + + # Statement translation: A single Python statement is mapped into a sequence of IR statements. + + def translateStmt(self, node): + if (isinstance(node, ast.Assign)): + self.translateAssign(node) + elif (isinstance(node, ast.Return)): + self.translateReturn(node) + elif (isinstance(node, ast.If)): + self.translateIf(node) + elif isinstance(node, ast.For): + self.translateFor(node) + else: + raise ValueError("Unsupported statement type: " + type(node).__name__) + + def translateAssign(self, node: Assign): + def assign(lhs, rhs): + if (isinstance(lhs, ast.Name)): + lhs = lhs.id + if (self.isConstantExpr(rhs)): + self.bind(lhs, ConstValue(self.evalConstantExpr(rhs))) + else: + t = self.translateExpr(rhs, lhs) + self.bind(lhs, Dynamic(t)) + elif isinstance(lhs, ast.Tuple): + def id(x): + assert isinstance(x, ast.Name) + return x.id + ids = [id(x) for x in lhs.elts] + onnxids = self.translateExpr(rhs, ids) + for x, y in zip(ids, onnxids): + self.bind(x, Dynamic(y)) + else: + fail("Unsupported construct in LHS of assignment.") + assert (len(node.targets) == 1), "Multi-assignment not supported." + lhs = node.targets[0] + rhs = node.value + if (isinstance(rhs, ast.Tuple)): + assert isinstance(lhs, ast.Tuple) + assert len(lhs.elts) == len(rhs.elts), "Expected same number of elements on lhs and rhs of assignments." + for l,r in zip(lhs.elts, rhs.elts): + assign(l, r) + else: + assign(lhs, rhs) + + def translateReturn(self, node): + def ret(exp, suffix=""): + ovar = self.translateExpr(exp, "return_val" + suffix) + # if hasattr(self, returntype) and self.num_outputs < len(self.returntype): + try: + t = self.returntype[self.num_outputs] + except Exception as e: + t = self.default_type + self.ir_builder.addOutput(self.current_fn, ovar, t) + self.num_outputs += 1 + return ovar + + val = node.value + assert (val != None), "Return statement without return-value not supported." + if (isinstance(val, ast.Tuple)): + return [ret(exp,str(i)) for i,exp in enumerate(val.elts)] + else: + return ret(val) + + def translateIf(self, node): + live_defs = list(node.live_out.intersection(analysis.defs(node))) + # print(live_defs) + test = self.translateExpr(node.test, "cond") + thenGraph = self.translateBlock(node.body, "thenGraph", live_defs) + thenAttr = self.ir_builder.attr("then_branch", thenGraph) + elseGraph = self.translateBlock(node.orelse, "elseGraph", live_defs) + elseAttr = self.ir_builder.attr("else_branch", elseGraph) + def rename(x): + r = self.generateUniqueName(x) + self.bind(x, Dynamic(r)) + return r + renamed = [ rename(x) for x in live_defs ] + self.emit(renamed, Op("", "If"), [test], [thenAttr, elseAttr]) + + def translateFor(self, for_stmt: ast.For): + # loop-variable + assert isinstance(for_stmt.target, ast.Name), "For loop target must be a single variable." + p_loop_var = for_stmt.target.id + # iter + iter = for_stmt.iter + assert isinstance(iter, ast.Call), "Loop bound not a call." + assert isinstance(iter.func, ast.Name), "Unsupported loop bound." + assert iter.func.id == "range", "Unsupported loop bound." + assert iter.args and len(iter.args) == 1, "Unsupported loop bound." + assert not iter.keywords, "Unsupported loop bound." + o_loop_bound = self.translateExpr(iter.args[0], "loop_bound") + # analyze loop body + exposed_uses = analysis.exposed_uses(for_stmt.body) + vars_def_in_loop = analysis.defs(for_stmt.body) + loop_state_vars = vars_def_in_loop.intersection(exposed_uses | for_stmt.live_out) + scan_outputs = set() # TODO + outputs = list(loop_state_vars | scan_outputs) + + # loop-condition: + o_true = self.emitConst(True, "true") + # o_loop_bound = self.emitConst(3, "loop_bound") + + # build loop_body + self.enterScope("loop_body") + o_loop_var = self.generateUniqueName(p_loop_var) + self.ir_builder.addInput(self.current_fn, o_loop_var, types.INT64) + self.bind(p_loop_var, Dynamic(o_loop_var)) + o_cond_var = self.generateUniqueName("cond_in") + self.ir_builder.addInput(self.current_fn, o_cond_var, types.BOOL) + for pv in loop_state_vars: + ov = self.generateUniqueName(pv) + self.ir_builder.addInput(self.current_fn, ov, self.default_type) + self.bind(pv, Dynamic(ov)) + for s in for_stmt.body: + self.translateStmt(s) + o_cond_out = self.generateUniqueName("cond_out") + self.emit([o_cond_out], Op("", "Identity"), [o_cond_var], []) + self.ir_builder.addOutput(self.current_fn, o_cond_out, types.BOOL) + for pv in loop_state_vars: + ov = self.py_var_to_onnx_var(pv) + self.ir_builder.addOutput(self.current_fn, ov, self.default_type) # TODO: type + body = self.exitScope() + # if (print_flag): + # print("Generated loop body:") + # body.print() + + inputs = [o_loop_bound, o_true] + [self.py_var_to_onnx_var(pv) for pv in loop_state_vars] + attrs = [self.ir_builder.attr("body", body.toGraph())] + self.emit2(outputs, "Loop", inputs, attrs) + + # Translation of a statement-block to GraphProto attribute + def translateBlock(self, stmts, name, live_defs): + self.enterScope(name) + for s in stmts: + self.translateStmt(s) + for pvar in live_defs: + if (pvar in self.currentScope()): + pv_val = self.currentScope()[pvar] + output = self.to_onnx_var(pv_val, pvar) + self.ir_builder.addOutput(self.current_fn, output, self.default_type) # TODO: need type! + else: + pv_val = None + for scope in self.locals: # TODO: skip currentScope + if (pvar in scope): + pv_val = scope[pvar] + break + if (pv_val is None): + fail (f"Variable {pvar} is not assigned a value along a conditional branch.") + # introduce a copy + ovar = self.generateUniqueName(pvar) + self.emit([ovar], Op("", "Identity"), [self.to_onnx_var(pv_val, pvar)], []) + self.ir_builder.addOutput(self.current_fn, ovar, self.default_type) # TODO: need type! + graph = self.exitScope() + # if print_flag: + # print ("Generated block") + # graph.print() + return graph.toGraph() + + def convert_FunctionDef(self, node): + args = node.args + if (args.defaults): + warn (f"{node.name}: Default values not yet implemented.") + if (args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg): + warn (f"{node.name}: Unsupported feature in function signature.") + self.current_fn = self.ir_builder.newFunction(node.name) + for x in args.args: + if x.annotation: + typeinfo = self.evalConstantExpr(x.annotation) + else: + typeinfo = self.default_type + assert ta.is_valid(typeinfo) + if (ta.is_attr(typeinfo)): + self.ir_builder.addAttr(self.current_fn, x.arg, typeinfo) + self.bind(x.arg, AttrRef(x.arg, typeinfo)) + else: + self.ir_builder.addInput(self.current_fn, x.arg, typeinfo) + self.bind(x.arg, Dynamic(x.arg)) + if node.returns: + returntype = self.evalConstantExpr(node.returns) + if isinstance(returntype, tuple): + assert all([ta.is_valid(t) for t in returntype]) + self.returntype = returntype + else: + assert ta.is_valid(returntype) + self.returntype = (returntype,) + else: + self.returntype = None + self.num_outputs = 0 + for s in node.body: + self.translateStmt(s) + if self.returntype is not None: + assert (self.num_outputs == len(self.returntype)), "Mismatch in number of return values and types" + return self.current_fn + + def do_import(self, alias): + if print_flag: + print (f"Importing {alias.name} as {alias.asname}") + fail_if (alias.name not in self.known_modules, f"Import: unsupported module {alias.name}") + asname = alias.asname if alias.asname else alias.name + self.globals[asname] = self.known_modules[alias.name] + + def top_level_stmt (self, stmt): + if isinstance(stmt, ast.FunctionDef): + self.initFunctionTranslation() + analysis.do_liveness_analysis(stmt) + fn_ir = self.convert_FunctionDef(stmt) + if print_flag: + print("============== OUTPUT =============") + fn_ir.print() + print() + return fn_ir + elif isinstance(stmt, ast.Import): + for alias in stmt.names: + self.do_import(alias) + elif isinstance(stmt, ast.ImportFrom): + fail_if(stmt.module is None, "Import: module unspecified.") + fail_if (stmt.module not in self.known_modules, f"Import: unsupported module {stmt.module}") + module = self.known_modules[stmt.module] + for alias in stmt.names: + asname = alias.asname if alias.asname else alias.name + self.globals[asname] = getattr(module, alias.name) + else: + raise ValueError("Unsupported top-level statement type: " + type(stmt).__name__) + + def convert_source(self, src): + if print_flag: + print("============== INPUT =============") + print(src) + print() + module = ast.parse(src) + assert type(module) == ast.Module + converted = [self.top_level_stmt(d) for d in module.body] + return [x for x in converted if x is not None] + + def convert_file(self, filename): + src = open(filename).read() + return self.convert_source(src) + + def convert(self, f): + if (isinstance(f, str)): + if (os.path.exists(f)): + return self.convert_file(f) + else: + return self.convert_source(f) + elif (inspect.isfunction(f)): + src = inspect.getsource(f) + return self.convert_source(src) + else: + fail("Unknown type of input to converter.") + +def convert (script): + converter = Converter() + converter.convert(script) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py new file mode 100644 index 0000000000..60b88ecaa0 --- /dev/null +++ b/onnxscript/irbuilder.py @@ -0,0 +1,146 @@ +import onnx +import onnx.helper as helper +import type_annotation as ta + +# A simple IR (Function, Stmt, Attr, Var): + +def format(list, prefix, sep, suffix): + return prefix + sep.join([str(x) for x in list]) + suffix + +class Type: + def __init__(self) -> None: + # TODO + tp = onnx.TypeProto() + tp.tensor_type.elem_type = onnx.TensorProto.FLOAT + self.onnx_type = tp + # helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [10]) + + def toTypeProto(self): + return self.onnx_type + + def __str__(self) -> str: + return "SomeType" + +class Var: + def __init__(self, varname, type = None) -> None: + self.name = varname + self.type = type + + def __str__(self): + return self.name + + def typedstr(self): + return self.name + " : " + str(self.type) + + def toValueInfo(self): + tp = self.type.toTypeProto() + # if (not tp.tensor_type.HasField('shape')): + # # TODO: temporary patch to export a function as a graph + # tp = helper.make_tensor_type_proto(tp.tensor_type.elem_type, [10]) + return helper.make_value_info(self.name, tp) + +class Attr: + def __init__(self, attrproto) -> None: + self.attr_proto = attrproto + + def __str__(self): + if (self.attr_proto.HasField("ref_attr_name")): + return self.attr_proto.name + " = @" + self.attr_proto.ref_attr_name + return helper.printable_attribute(self.attr_proto) # self.name + " = " + self.value + +class Stmt: + def __init__(self, result, module, opname, args, attrs) -> None: + self.result = result + self.module = module + self.opname = opname + self.args = args + self.attrs = attrs + + def __str__(self): + if (isinstance(self.result, str)): + print('Ooops') + lhs = ", ".join(self.result) + attrs = "" + if (self.attrs): + attrs = format(self.attrs, "<", ", ", ">") + + args = format (self.args, "(", ", ", ")") + module = str(self.module) + callee = module + "." + self.opname if (module != '') else self.opname + return (lhs + " = " + self.opname + " " + attrs + args) + + def print(self): + print (str(self)) + + def toNode(self): + n = helper.make_node(self.opname, + [str(x) for x in self.args], + [str(x) for x in self.result]) + for a in self.attrs: + n.attribute.append(a.attr_proto) + return n + +class Function: + def __init__(self, name) -> None: + self.name = name + self.inputs = [] + self.outputs = [] + self.stmts = [] + self.attrs = [] + + def __str__(self): + attrs = format (self.attrs, "<", ", ", ">") if self.attrs else "" + inputs = format ([x.typedstr() for x in self.inputs], "(", ", ", ")") + outputs = format ([x.typedstr() for x in self.outputs], "(", ", ", ")") + stmts = format (self.stmts, "\n{\n ", "\n ", "\n}\n") + return (self.name + " " + attrs + inputs + " => " + outputs + stmts) + + def print(self): + print (str(self)) + for s in self.stmts: + for attr in s.attrs: + if attr.attr_proto.HasField("g"): + print(helper.printable_graph(attr.attr_proto.g)) + + + def toGraph(self): + return helper.make_graph([s.toNode() for s in self.stmts], + self.name, + [x.toValueInfo() for x in self.inputs], + [y.toValueInfo() for y in self.outputs] + ) + +# IRBuilder: abstracts out details of the IR in the python-to-IR converter + +class IRBuilder: + def newFunction(self, name): + return Function(name) + + def addStmt(self, fn, results, module, opname, args, attrs): + s = Stmt(results, module, opname, args, attrs) + fn.stmts.append(s) + + def addInput(self, fn, varname, type): + v = Var(varname, type) + fn.inputs.append(v) + + def addAttr(self, fn, varname, type): + v = Var(varname, type) + fn.attrs.append(v) + + def addOutput(self, fn, varname, type): + v = Var(varname, type) + fn.outputs.append(v) + + def attr(self, attrname, attrval): + if (isinstance(attrval, Function)): + attrval = str(attrval) # TODO + return Attr(helper.make_attribute(attrname, attrval)) + + def attr_ref(self, attrname, refname, pytype): + a = onnx.AttributeProto() + a.name = attrname + a.ref_attr_name = refname + a.type = ta.pytype_to_attrtype_map[pytype] # onnx.AttributeProto.FLOAT + return Attr(a) + # TODO: attr_type? diff --git a/onnxscript/onnxscript/__init__.py b/onnxscript/onnxscript/__init__.py new file mode 100644 index 0000000000..6fef22f00f --- /dev/null +++ b/onnxscript/onnxscript/__init__.py @@ -0,0 +1,35 @@ +from argparse import ArgumentError +from typing import Union +from onnx.helper import make_tensor +import onnx +import onnx.helper + +# ElementType: an enumeration encoding the allowed element types in a tensor +# Corresponds to TensorProto::DataType +class ElementType: + UNDEFINED = 0 + + FLOAT = 1 # float + UINT8 = 2 # uint8_t + INT8 = 3 # int8_t + UINT16 = 4 # uint16_t + INT16 = 5 # int16_t + INT32 = 6 # int32_t + INT64 = 7 # int64_t + STRING = 8 # string + BOOL = 9 # bool + + # IEEE754 half-precision floating-point format (16 bits wide). + # This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10 + + DOUBLE = 11 + UINT32 = 12 + UINT64 = 13 + COMPLEX64 = 14 # complex with float32 real and imaginary components + COMPLEX128 = 15 # complex with float64 real and imaginary components + + # Non-IEEE floating-point format based on IEEE754 single-precision + # floating-point number truncated to 16 bits. + # This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16 diff --git a/onnxscript/onnxscript/types/__init__.py b/onnxscript/onnxscript/types/__init__.py new file mode 100644 index 0000000000..ef99ee1e2b --- /dev/null +++ b/onnxscript/onnxscript/types/__init__.py @@ -0,0 +1,69 @@ +import onnx +import onnx.helper + +class Tensor: + # Reference implementation placeholder + # represents a generic ONNX tensor type + def __init__(self, dtype = onnx.TensorProto.UNDEFINED, shape = None) -> None: + self.dtype = dtype + self.shape = shape + + def __str__(self) -> str: + shapestr = str(self.shape) if self.shape else "[...]" + return onnx.TensorProto.DataType.Name(self.dtype) + shapestr + + def toTypeProto(self): + # TODO: handle None + return onnx.helper.make_tensor_type_proto(self.dtype, self.shape) + +# Utilities used to create parametrized type-annotations for tensors. +# Example type annotations: +# x : FLOAT (a scalar-tensor of rank 0) +# x : FLOAT[...] (a tensor of unknown rank) +# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names) +# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions) + +class ParametricTensor: + def __init__(self, dtype) -> None: + self.dtype = dtype + + def __getitem__ (self, shape): + def mk_dim(dim): + r = onnx.TensorShapeProto.Dimension() + if (isinstance(dim, int)): + r.dim_value = dim + elif (isinstance(dim, str)): + r.dim_param = dim + elif (dim is not None): + raise TypeError("Invalid dimension") + return r + if (isinstance(shape, tuple)): + s = shape + elif (shape == Ellipsis): + s = None + else: + s = [shape] + return Tensor(self.dtype, s) + + def toTypeProto(self): + return onnx.helper.make_tensor_type_proto(self.dtype, ()) + + def __str__(self) -> str: + return onnx.TensorProto.DataType.Name(self.dtype) + +FLOAT = ParametricTensor (onnx.TensorProto.FLOAT) +UINT8 = ParametricTensor (onnx.TensorProto.UINT8) +INT8 = ParametricTensor (onnx.TensorProto.INT8) +UINT16 = ParametricTensor (onnx.TensorProto.UINT16) +INT16 = ParametricTensor (onnx.TensorProto.INT16) +INT32 = ParametricTensor (onnx.TensorProto.INT32) +INT64 = ParametricTensor (onnx.TensorProto.INT64) +STRING = ParametricTensor (onnx.TensorProto.STRING) +BOOL = ParametricTensor (onnx.TensorProto.BOOL) +FLOAT16 = ParametricTensor (onnx.TensorProto.FLOAT16) +DOUBLE = ParametricTensor (onnx.TensorProto.DOUBLE) +UINT32 = ParametricTensor (onnx.TensorProto.UINT32) +UINT64 = ParametricTensor (onnx.TensorProto.UINT64) +COMPLEX64 = ParametricTensor (onnx.TensorProto.COMPLEX64) +COMPLEX128 = ParametricTensor (onnx.TensorProto.COMPLEX128) +BFLOAT16 = ParametricTensor (onnx.TensorProto.BFLOAT16) \ No newline at end of file diff --git a/onnxscript/test/__init__.py b/onnxscript/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onnxscript/test/attrref.py b/onnxscript/test/attrref.py new file mode 100644 index 0000000000..722d85f055 --- /dev/null +++ b/onnxscript/test/attrref.py @@ -0,0 +1,9 @@ + +def float_attr_ref_test (X, alpha: float) : + return onnx.Foo(X, alpha) + +def int_attr_ref_test (X, alpha: int) : + return onnx.Foo(X, alpha) + +def str_attr_ref_test (X, alpha: str) : + return onnx.Foo(X, alpha) \ No newline at end of file diff --git a/onnxscript/test/converter_test.py b/onnxscript/test/converter_test.py new file mode 100644 index 0000000000..976c806131 --- /dev/null +++ b/onnxscript/test/converter_test.py @@ -0,0 +1,46 @@ +import unittest +import os +import onnx +from converter import Converter + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + +class TestConverter(unittest.TestCase): + def _convert (self, script): + converter = Converter() + converter.convert(script) + + def _convert_and_save (self, script): + converter = Converter() + fnlist = converter.convert(script) + TEST_OUTPUT_DIR = os.path.join(CURRENT_DIR, "testoutputs") + if (not os.path.exists(TEST_OUTPUT_DIR)): + os.makedirs(TEST_OUTPUT_DIR) + for f in fnlist: + graph = f.toGraph() + model = onnx.helper.make_model(graph, producer_name='p2o', opset_imports=[onnx.helper.make_opsetid("", 15)]) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model) + onnx.save(model, os.path.join(TEST_OUTPUT_DIR, f.name + ".onnx")) + + def test_source_input(self): + script = """ +def square(x): + return onnx.Mul(x, x) +""" + self._convert(script) + + def test_onnxfns(self): + self._convert(os.path.join(CURRENT_DIR, "onnxfns.py")) + + def test_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "onnxmodels.py")) + + def test_if_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "if.py")) + + def test_loop_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "loop.py")) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/onnxscript/test/dropout.py b/onnxscript/test/dropout.py new file mode 100644 index 0000000000..76135d31fa --- /dev/null +++ b/onnxscript/test/dropout.py @@ -0,0 +1,10 @@ + +def Dropout(data, ratio, training_mode, seed : int): + if (training_mode): + mask = onnx.ConstantOfShape(onnx.Shape(data), value = True) + output = data + else: + rand = onnx.RandomUniformLike(data, dtype=1, seed=seed) + mask = (rand >= ratio) + output = onnx.Where(mask, data, 0) / (1.0 - ratio) + return (output, mask) \ No newline at end of file diff --git a/onnxscript/test/eg1.py b/onnxscript/test/eg1.py new file mode 100644 index 0000000000..eac4258048 --- /dev/null +++ b/onnxscript/test/eg1.py @@ -0,0 +1,27 @@ + +from onnxscript.types import FLOAT + +# tensor inputs can have ONNX-like type annotations +def gemm (A : FLOAT[2048,124], W : FLOAT[124,4096], Bias : FLOAT[4096]) -> FLOAT[2048,4096] : + return onnx.MatMul(A, W) + Bias + +# tensors and attributes distinguished by their types +def scale (A : FLOAT[...], alpha: float, beta: float) -> FLOAT[...] : + return alpha * A + beta + +# can return multiple-values +def prodsum (A: FLOAT['N'], B: FLOAT['N']) -> (FLOAT['N'], FLOAT['N']) : + prod = A * B + sum = A + B + return prod, sum + +# can call ops/functions that return multiple-values +def dropout_eg(A: FLOAT[...]) -> FLOAT[...]: + output, mask = onnx.Dropout(A, 0.7, True, seed=1729) + return output + +# will rename variable assigned multiple times +def renaming(A : FLOAT["N"]) -> FLOAT["N"] : + T = onnx.Abs(A) + T = onnx.Neg(T) + return T \ No newline at end of file diff --git a/onnxscript/test/if.py b/onnxscript/test/if.py new file mode 100644 index 0000000000..902e8df85a --- /dev/null +++ b/onnxscript/test/if.py @@ -0,0 +1,31 @@ +from onnxscript.types import FLOAT + +def maxsum (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + if (sum1 < sum2): + result = onnx.Identity(B) + else: + result = onnx.Identity(A) + return result + +# Test inference of inputs/outputs for then/else blocks: +def maxsum (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + if (sum1 < sum2): + temp = onnx.Identity(B) + result = onnx.Identity(temp) + else: + temp = onnx.Identity(A) + result = onnx.Identity(temp) + return result + +# test variables assigned only in one branch +def maxsum2 (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + result = onnx.Identity(A) + if (sum1 < sum2): + result = onnx.Identity(B) + return result diff --git a/onnxscript/test/loop.py b/onnxscript/test/loop.py new file mode 100644 index 0000000000..89a8e6f512 --- /dev/null +++ b/onnxscript/test/loop.py @@ -0,0 +1,19 @@ +from onnxscript.types import FLOAT, INT64 + +# simple loop +def sumprod (x : FLOAT['N'], N : INT64) -> (FLOAT['N'], FLOAT['N']) : + sum = onnx.Identity(x) + prod = onnx.Identity(x) + for i in range(N): + sum = sum + x + prod = prod * x + return sum, prod + +# loop: loop-bound as an expression +def sumprod2 (x : FLOAT['N'], N : INT64) -> (FLOAT['N'], FLOAT['N']) : + sum = onnx.Identity(x) + prod = onnx.Identity(x) + for i in range(2*N+1): + sum = sum + x + prod = prod * x + return sum, prod \ No newline at end of file diff --git a/onnxscript/test/multi.py b/onnxscript/test/multi.py new file mode 100644 index 0000000000..1967e4e9a2 --- /dev/null +++ b/onnxscript/test/multi.py @@ -0,0 +1,7 @@ + +from onnxscript.types import FLOAT + +def multi (A : FLOAT["N"]) -> FLOAT["N"] : + x, y = onnx.Foo (A) + x, y = onnx.Bar (A) + return x + y \ No newline at end of file diff --git a/onnxscript/test/onnxfns.py b/onnxscript/test/onnxfns.py new file mode 100644 index 0000000000..201442c8b2 --- /dev/null +++ b/onnxscript/test/onnxfns.py @@ -0,0 +1,39 @@ + +def Relu(X): + return onnx.Max (X, 0) + +# TODO: default-values of attributes +def Selu(X, alpha : float = 1.67326319217681884765625, gamma : float = 1.05070102214813232421875): + neg = gamma * (alpha * onnx.Exp(X) - alpha) + pos = gamma * X + return onnx.Where (X <= 0, neg, pos) + +def Elu(X, alpha : float = 1.0): + return onnx.Where (X < 0, alpha * (onnx.Exp(X) - 1.0), X) + +def ThresholdedRelu (X, alpha : float = 1.0): + return onnx.Where (X > alpha, X, 0) + +def LeakyRelu(X, alpha : float = 0.01): + return onnx.Where(X < 0, alpha * X, X) + +def PRelu(X, slope): + # future-work: capturing extra requirements such as: + # slope must be unidirectionally broadcastable to X's shape. + return onnx.Where(X < 0, slope * X, X) + +def HardSigmoid (X, alpha : float = 0.2, beta : float = 0.5): + return onnx.Max(0, onnx.Min(1, alpha * X + beta)) + +def Shrink(x, bias : float = 0.0, lambd : float = 0.5): + return onnx.Where(x < -lambd, x + bias, Onnx.Where(x > lambd, x - bias, 0)) + +def Softplus(X): + return onnx.Log(onnx.Exp(X) + 1) + +def Softsign(X): + return X / (1 + onnx.Abs(X)) + +def Clip(input, min, max): + # TODO: default values specified for min/max + return onnx.Where (input < min, min, onnx.Where(input > max, max, input)) \ No newline at end of file diff --git a/onnxscript/test/onnxmodels.py b/onnxscript/test/onnxmodels.py new file mode 100644 index 0000000000..3c4dd6b02c --- /dev/null +++ b/onnxscript/test/onnxmodels.py @@ -0,0 +1,71 @@ +from onnxscript.types import FLOAT + +# type annotation tests: + +def ta_test1(A : FLOAT[100], B: FLOAT[100]) -> FLOAT[100] : + return A + B + +def ta_test2(A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + return A + B + +# feature: constant promotion + +def plus1 (A: FLOAT["N"]) -> FLOAT["N"] : + return A + 1.0 + +# example models + +def eg1 (A : FLOAT[100, 200], W : FLOAT[200, 96], B : FLOAT[96]) -> FLOAT[100, 96]: + return onnx.MatMul(A, W) + B + +def gemmgelu ( + A : FLOAT[2048,16], + W : FLOAT[16,4096], + Bias : FLOAT[4096] + ) -> FLOAT[2048,4096] : + + a = onnx.Constant(value_float = 0.5) + b = onnx.Constant(value_float = 0.797885) + c = onnx.Constant(value_float = 0.035677) + one = onnx.Constant(value_float = 1.0) + P1 = onnx.MatMul(A, W) + X = onnx.Add(P1, Bias) + T1 = onnx.Mul(X, X) + T2 = onnx.Mul(c, T1) + T3 = onnx.Add(b, T2) + T4 = onnx.Mul(X, T3) + T5 = onnx.Tanh(T4) + T6 = onnx.Add(one, T5) + T7 = onnx.Mul(X, T6) + Y = onnx.Mul(a, T7) + return Y + +def gemmgelu2 ( + A : FLOAT[2048,16], + W : FLOAT[16,4096], + Bias : FLOAT[4096] + ) -> FLOAT[2048,4096] : + + a = onnx.Constant(value_float = 0.5) + b = onnx.Constant(value_float = 0.797885) + c = onnx.Constant(value_float = 0.035677) + one = onnx.Constant(value_float = 1.0) + P1 = onnx.MatMul(A, W) + X = onnx.MatMul(A, W) + Bias + T4 = X * (b + c * X * X) + Y = a * X * (onnx.Tanh(T4) + one) + return Y + +def gemmgelu3 ( + A : FLOAT[2048,16], + W : FLOAT[16,4096], + Bias : FLOAT[4096] + ) -> FLOAT[2048,4096] : + a = 0.5 + b = 0.797885 + c = 0.035677 + P1 = onnx.MatMul(A, W) + X = onnx.MatMul(A, W) + Bias + T4 = X * (b + c * X * X) + Y = a * X * (onnx.Tanh(T4) + 1.0) + return Y \ No newline at end of file diff --git a/onnxscript/test/renaming.py b/onnxscript/test/renaming.py new file mode 100644 index 0000000000..c1afcae34c --- /dev/null +++ b/onnxscript/test/renaming.py @@ -0,0 +1,14 @@ +from onnxscript.types import FLOAT + +# same variable assigned multiple times +def renaming(A : FLOAT["N"]) -> FLOAT["N"] : + T = onnx.Abs(A) + T = onnx.Neg(A) + return T + +# clash between generated-name and pre-existing name +def renaming2(A : FLOAT["N"]) -> FLOAT["N"] : + T_0 = onnx.Relu(A) + T = onnx.Abs(A) + T = onnx.Neg(A) + return T \ No newline at end of file diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py new file mode 100644 index 0000000000..391b49e046 --- /dev/null +++ b/onnxscript/type_annotation.py @@ -0,0 +1,20 @@ + +import onnx +import onnxscript + +pytype_to_attrtype_map = { + float : onnx.AttributeProto.FLOAT, + int : onnx.AttributeProto.INT, + str : onnx.AttributeProto.STRING, +} + +def is_attr(typeinfo): + return typeinfo in {float, int, str} + # (typeinfo is float) or (typeinfo is str) or (typeinfo is int) + +def is_tensor(typeinfo): + return hasattr(typeinfo, "toTypeProto") + # return isinstance(typeinfo, onnxscript.Tensor) # TODO + +def is_valid(typeinfo): + return is_attr(typeinfo) or is_tensor(typeinfo) \ No newline at end of file diff --git a/onnxscript/utils.py b/onnxscript/utils.py new file mode 100644 index 0000000000..42debe94e4 --- /dev/null +++ b/onnxscript/utils.py @@ -0,0 +1,23 @@ +from typing import Any +import onnxruntime as ort + +class Model: + def __init__(self, onnxfile) -> None: + self.session = ort.InferenceSession(onnxfile) + + def __call__(self, *args: Any, **kwds: Any) -> Any: + inputs = self.session.get_inputs() + for i, arg in enumerate(args): + kwds[inputs[i].name] = arg + return self.session.run(None, kwds) + +def make_ort_value(v): + pass + +def make_attrs(**kwds): + pass + +def call_ort (opname, *args, **kwds): + model = Model("todo") + ort_args = [make_ort_value(x) for x in args] + return model(ort_args) diff --git a/onnxscript/values.py b/onnxscript/values.py new file mode 100644 index 0000000000..45b7f98896 --- /dev/null +++ b/onnxscript/values.py @@ -0,0 +1,59 @@ +import onnx + +# ONNX opsets (correspond to python modules in reference-mode) +# Have a domain-name, version, and a list of ops + +class Opset: + def __init__(self, domain, version) -> None: + self.domain = domain + self.version = version + + def __getitem__ (self, opname): + return onnx.defs.get_schema (opname, self.version, self.domain) + + def __contains__ (self, opname): + try: + onnx.defs.get_schema (opname, self.version, self.domain) + return True + except: + return False + + def __str__(self) -> str: + return self.domain + +opset15 = Opset("", 15) + +# ONNX ops + +class Op: + def __init__(self, opset, opname) -> None: + self.opset = opset + self.opname = opname + + def get_schema(self): + return self.opset[self.opname] + + def has_schema(self): + return (self.opname in self.opset) + +# Values fall into the following categories: +# ConstValue: values known at translation-time, mapped to ONNX attributes +# AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes +# Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs + +class Value: + def __init__(self, val) -> None: + self.value = val + +class ConstValue(Value): + def __init__(self, val) -> None: + super().__init__(val) + +class AttrRef(Value): + def __init__(self, name: str, type) -> None: + super().__init__(name) + self.type = type + +class Dynamic(Value): + def __init__(self, val) -> None: + super().__init__(val) From 13753ce4c824ed3efd8385d3ad8806c1e34aeed3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 28 Feb 2022 16:42:43 -0800 Subject: [PATCH 2/6] Patch for msdomain Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 2 +- onnxscript/irbuilder.py | 2 +- onnxscript/values.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 62f72ee845..f67f5ecce3 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -72,7 +72,7 @@ class Converter: def __init__(self, ir_builder = IRBuilder()): self.ir_builder = ir_builder self.known_modules = { 'onnxscript' : onnxscript, 'onnxscript.types' : types, 'onnx.opset15' : values.opset15 } - self.globals = { "int" : int, "float" : float, "str" : str, "onnx" : values.opset15, "Onnx" : values.opset15 } # 'os' : onnxscript + self.globals = { "int" : int, "float" : float, "str" : str, "onnx" : values.opset15, "Onnx" : values.opset15, "msdomain" : values.msdomain1 } # 'os' : onnxscript self.pure_modules = ["onnxscript"] self.default_type = types.FLOAT[...] diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 60b88ecaa0..a9449db190 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -67,7 +67,7 @@ def __str__(self): args = format (self.args, "(", ", ", ")") module = str(self.module) callee = module + "." + self.opname if (module != '') else self.opname - return (lhs + " = " + self.opname + " " + attrs + args) + return (lhs + " = " + callee + " " + attrs + args) def print(self): print (str(self)) diff --git a/onnxscript/values.py b/onnxscript/values.py index 45b7f98896..708637a297 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -23,6 +23,8 @@ def __str__(self) -> str: opset15 = Opset("", 15) +msdomain1 = Opset("com.microsoft", 1) + # ONNX ops class Op: From b5c7ee816fbf9d844ee14ea364ee58b55769824e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 1 Mar 2022 15:35:15 -0800 Subject: [PATCH 3/6] switch to uniform casing convention Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 233 ++++++++++++++---------------- onnxscript/test/converter_test.py | 7 + 2 files changed, 119 insertions(+), 121 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index f67f5ecce3..555107de9d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1,18 +1,16 @@ -from numpy import isin -import onnxscript -import onnxscript.types as types import os import inspect import ast -from ast import * -import irbuilder +import onnxscript +import onnxscript.types as types + from irbuilder import IRBuilder import analysis import type_annotation as ta import values -from values import Value, ConstValue, AttrRef, Dynamic, Op +from values import ConstValue, AttrRef, Dynamic, Op import onnx import onnx.helper as helper @@ -42,6 +40,8 @@ def ignore (cond, msg): if cond: warn (msg) +# Utility to convert a python value to TensorProto: + def pyvalue_to_tensor(tensor_name : str, pyvalue): if isinstance(pyvalue, bool): return helper.make_tensor(tensor_name, onnx.TensorProto.BOOL, [], [int(pyvalue)]) @@ -50,6 +50,7 @@ def pyvalue_to_tensor(tensor_name : str, pyvalue): elif isinstance(pyvalue, float): return helper.make_tensor(tensor_name, onnx.TensorProto.FLOAT, [], [pyvalue]) else: + # TODO: str, sequences of values fail("Unimplemented") # map from python operators to ONNX ops @@ -76,26 +77,27 @@ def __init__(self, ir_builder = IRBuilder()): self.pure_modules = ["onnxscript"] self.default_type = types.FLOAT[...] - def initFunctionTranslation(self): + def init_function_translation(self): + """Initialize self for translating a new function.""" self.outer = [] self.current_fn = None self.nextvar = 0 self.used_vars = set() self.locals = [{}] - def enterScope(self, name): + def enter_scope(self, name): self.outer.insert(0,self.current_fn) self.current_fn = self.ir_builder.newFunction(name) self.locals.insert(0, {}) - def exitScope(self): + def exit_scope(self): graph = self.current_fn self.current_fn = self.outer[0] self.outer.pop(0) self.locals.pop(0) return graph - def currentScope(self): + def current_scope(self): return self.locals[0] def bind(self, name, val): @@ -107,7 +109,7 @@ def lookup(self, name): if (name in self.globals): return self.globals[name] raise ValueError("Unbound name: " + name) - def generateUniqueName(self, candidate = "tmp"): + def generate_unique_name(self, candidate = "tmp"): r = candidate while (r in self.used_vars): r = candidate + "_" + str(self.nextvar) @@ -123,13 +125,13 @@ def to_onnx_attr_ref(self, val : AttrRef): def to_onnx_var(self, val, target = None): if (isinstance(val, AttrRef)): # promote attribute to value - result = self.generateUniqueName(target if target else "tmp") + result = self.generate_unique_name(target if target else "tmp") attr = self.to_onnx_attr_ref(val) self.emit ([result], Op("", "Constant"), [], [attr]) return result elif (isinstance(val, ConstValue) and isinstance(val.value, float)): # TODO - result = self.generateUniqueName(target if target else "tmp") - return self.emitConst(val.value, result) + result = self.generate_unique_name(target if target else "tmp") + return self.emit_const(val.value, result) elif isinstance(val, Dynamic): return val.value else: @@ -142,33 +144,33 @@ def emit(self, outputs, callee, inputs, attrs): def emit2(self, outputs, callee, inputs, attrs): def rename(x): - r = self.generateUniqueName(x) + r = self.generate_unique_name(x) self.bind(x, Dynamic(r)) return r onnx_inputs = inputs # [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ] onnx_outputs = [ rename(x) for x in outputs ] self.emit(onnx_outputs, Op("", callee), onnx_inputs, attrs) - def emitConst (self, pyvalue, suggested_name): - ovar = self.generateUniqueName(suggested_name) + def emit_const (self, pyvalue, suggested_name): + ovar = self.generate_unique_name(suggested_name) tensor = pyvalue_to_tensor (ovar, pyvalue) attr = self.ir_builder.attr("value", tensor) self.emit([ovar], Op("", "Constant"), [], [attr]) return ovar - def isPureModule(self, m): + def is_pure_module(self, m): return (m in self.pure_modules) - def isConstantExpr (self, node): + def is_constant_expr (self, node): if (isinstance(node, ast.Name)): val = self.lookup(node.id) - return isinstance(val, ConstValue) and self.isPureModule(val.value) + return isinstance(val, ConstValue) and self.is_pure_module(val.value) if isinstance(node, (ast.Call, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Num, ast.Str, ast.Attribute)): - return all([self.isConstantExpr(c) for c in ast.iter_child_nodes(node)]) + return all([self.is_constant_expr(c) for c in ast.iter_child_nodes(node)]) return False - def evalConstantExpr(self, node): - # TODO: assert (self.isConstantExpr(node)) + def eval_constant_expr(self, node): + # TODO: assert (self.is_constant_expr(node)) locals = {} # TODO return eval(compile(ast.Expression(node), filename="", mode="eval"), self.globals, locals) @@ -187,11 +189,11 @@ def eval_attr(self, node): elif (isinstance(node, ast.List)): return [self.eval_attr(x) for x in node.elts] elif (isinstance(node, (ast.Call, ast.Attribute))): - return self.evalConstantExpr(node) + return self.eval_constant_expr(node) else: raise ValueError("Unsupported attribute type: " + type(node).__name__) - def translateAttr(self, attr_name, node): + def translate_attr(self, attr_name, node): if (isinstance(node, ast.Name)): val = self.lookup(node.id) if (isinstance(val, AttrRef)): @@ -203,79 +205,74 @@ def translateAttr(self, attr_name, node): # Expression-translation generates "IR statements/nodes" that compute the value of # the expression into a target-variable, and returns the variable that is assigned this value. - def translateExpr(self, node, target="tmp"): + def translate_expr(self, node, target="tmp"): if (isinstance(node, ast.Call)): - r = self.translateCall(node) + r = self.translate_call_expr(node) elif (isinstance(node, ast.BinOp)): - r = self.translateBinOp (node) + r = self.translate_bin_op_expr (node) elif (isinstance(node, ast.UnaryOp)): - r = self.translateUnaryOp (node) + r = self.translate_unary_op_expr (node) elif (isinstance(node, ast.Compare)): - r = self.translateCompare (node) + r = self.translate_compare_expr (node) elif (isinstance(node, ast.Name)): - r = self.translateName (node) + r = self.translate_name_expr (node) elif (isinstance(node, ast.Num)): - r = self.emitConst(node.n, target) + r = self.emit_const(node.n, target) elif isinstance(node, ast.NameConstant): - r = self.emitConst(node.value, target) - # elif (isinstance(node, ast.Attribute)): - # r = self.translateAttribute (node) + r = self.emit_const(node.value, target) else: raise ValueError("Unsupported expression type: " + type(node).__name__) if (isinstance(r, tuple)): if isinstance(target, str): - result = self.generateUniqueName(target) + result = self.generate_unique_name(target) callee, args, attrs = r self.emit([result], callee, args, attrs) return result else: assert isinstance(target, list) - results = [self.generateUniqueName(x) for x in target] + results = [self.generate_unique_name(x) for x in target] callee, args, attrs = r self.emit(results, callee, args, attrs) return results return r - def translateCall(self, node): + def translate_call_expr(self, node): # TODO: for now, we map named arguments to attributes, and positional arguments to inputs. - callee = self.translateCallee(node.func) - args = [self.translateExpr(x) for x in node.args] - attrs = [self.translateAttr(x.arg, x.value) for x in node.keywords] + callee = self.translate_callee_expr(node.func) + args = [self.translate_expr(x) for x in node.args] + attrs = [self.translate_attr(x.arg, x.value) for x in node.keywords] return (callee, args, attrs) - def translateBinOp(self, node): + def translate_bin_op_expr(self, node): op = type(node.op) assert (op in primop_map) opname = primop_map[op] - left = self.translateExpr(node.left) - right = self.translateExpr(node.right) + left = self.translate_expr(node.left) + right = self.translate_expr(node.right) return (Op("", opname), [left, right], []) - def translateUnaryOp(self, node): + def translate_unary_op_expr(self, node): op = type(node.op) assert (op in primop_map) opname = primop_map[op] - operand = self.translateExpr(node.operand) + operand = self.translate_expr(node.operand) return (Op("", opname), [operand], []) - def translateCompare(self, node): + def translate_compare_expr(self, node): assert (len(node.ops) == 1) # TODO: handle multiple comparisons in one expression assert (len(node.comparators) == 1) op = type(node.ops[0]) assert (op in primop_map) opname = primop_map[op] - left = self.translateExpr(node.left) - right = self.translateExpr(node.comparators[0]) + left = self.translate_expr(node.left) + right = self.translate_expr(node.comparators[0]) return (Op("", opname), [left, right], []) - # TODO: returns??? - def translateName(self, node): + def translate_name_expr(self, node): return self.py_var_to_onnx_var(node.id) - def translateNum(self, node): - return self.emitConst(node.n, "Const") - - def translateModule(self, node): + def translate_opset_expr(self, node) -> values.Opset : + """Return an Opset""" if isinstance(node, ast.Name): try: val = self.lookup(node.id) @@ -283,18 +280,19 @@ def translateModule(self, node): val = val.value if isinstance(val, values.Opset): return val - fail(f"{node.id} has value of type {type(node.id)} and used as module") + fail(f"{node.id} has value of type {type(node.id)} and used as opset.") except: - warn(f"Unknown module name {node.id}.") + warn(f"Unknown opset name {node.id}.") return values.Opset(node.id, 1) elif isinstance (node, ast.Attribute): - fail("Nested module unimplemented") + fail("Nested module unimplemented") # TODO else: - fail("Invalid module.") + fail("Invalid opset expression.") - def translateCallee(self, node): + def translate_callee_expr(self, node) -> values.Op : + """Return an Op""" if isinstance(node, ast.Attribute): - module = self.translateModule(node.value) + module = self.translate_opset_expr(node.value) opname = node.attr if (opname not in module): warn (f"{opname} is not a known op in {str(module)}") return Op(module, node.attr) @@ -311,40 +309,40 @@ def translateCallee(self, node): # Statement translation: A single Python statement is mapped into a sequence of IR statements. - def translateStmt(self, node): + def translate_stmt(self, node): if (isinstance(node, ast.Assign)): - self.translateAssign(node) + self.translate_assign_stmt(node) elif (isinstance(node, ast.Return)): - self.translateReturn(node) + self.translate_return_stmt(node) elif (isinstance(node, ast.If)): - self.translateIf(node) + self.translate_if_stmt(node) elif isinstance(node, ast.For): - self.translateFor(node) + self.translate_for_stmt(node) else: raise ValueError("Unsupported statement type: " + type(node).__name__) - def translateAssign(self, node: Assign): + def translate_assign_stmt(self, stmt: ast.Assign): def assign(lhs, rhs): if (isinstance(lhs, ast.Name)): lhs = lhs.id - if (self.isConstantExpr(rhs)): - self.bind(lhs, ConstValue(self.evalConstantExpr(rhs))) + if (self.is_constant_expr(rhs)): + self.bind(lhs, ConstValue(self.eval_constant_expr(rhs))) else: - t = self.translateExpr(rhs, lhs) + t = self.translate_expr(rhs, lhs) self.bind(lhs, Dynamic(t)) elif isinstance(lhs, ast.Tuple): def id(x): assert isinstance(x, ast.Name) return x.id ids = [id(x) for x in lhs.elts] - onnxids = self.translateExpr(rhs, ids) + onnxids = self.translate_expr(rhs, ids) for x, y in zip(ids, onnxids): self.bind(x, Dynamic(y)) else: fail("Unsupported construct in LHS of assignment.") - assert (len(node.targets) == 1), "Multi-assignment not supported." - lhs = node.targets[0] - rhs = node.value + assert (len(stmt.targets) == 1), "Multi-assignment not supported." + lhs = stmt.targets[0] + rhs = stmt.value if (isinstance(rhs, ast.Tuple)): assert isinstance(lhs, ast.Tuple) assert len(lhs.elts) == len(rhs.elts), "Expected same number of elements on lhs and rhs of assignments." @@ -353,9 +351,9 @@ def id(x): else: assign(lhs, rhs) - def translateReturn(self, node): + def translate_return_stmt(self, stmt: ast.Return): def ret(exp, suffix=""): - ovar = self.translateExpr(exp, "return_val" + suffix) + ovar = self.translate_expr(exp, "return_val" + suffix) # if hasattr(self, returntype) and self.num_outputs < len(self.returntype): try: t = self.returntype[self.num_outputs] @@ -365,29 +363,29 @@ def ret(exp, suffix=""): self.num_outputs += 1 return ovar - val = node.value + val = stmt.value assert (val != None), "Return statement without return-value not supported." if (isinstance(val, ast.Tuple)): return [ret(exp,str(i)) for i,exp in enumerate(val.elts)] else: return ret(val) - def translateIf(self, node): - live_defs = list(node.live_out.intersection(analysis.defs(node))) + def translate_if_stmt(self, stmt: ast.If): + live_defs = list(stmt.live_out.intersection(analysis.defs(stmt))) # print(live_defs) - test = self.translateExpr(node.test, "cond") - thenGraph = self.translateBlock(node.body, "thenGraph", live_defs) + test = self.translate_expr(stmt.test, "cond") + thenGraph = self.translate_block(stmt.body, "thenGraph", live_defs) thenAttr = self.ir_builder.attr("then_branch", thenGraph) - elseGraph = self.translateBlock(node.orelse, "elseGraph", live_defs) + elseGraph = self.translate_block(stmt.orelse, "elseGraph", live_defs) elseAttr = self.ir_builder.attr("else_branch", elseGraph) def rename(x): - r = self.generateUniqueName(x) + r = self.generate_unique_name(x) self.bind(x, Dynamic(r)) return r renamed = [ rename(x) for x in live_defs ] self.emit(renamed, Op("", "If"), [test], [thenAttr, elseAttr]) - def translateFor(self, for_stmt: ast.For): + def translate_for_stmt(self, for_stmt: ast.For): # loop-variable assert isinstance(for_stmt.target, ast.Name), "For loop target must be a single variable." p_loop_var = for_stmt.target.id @@ -398,7 +396,7 @@ def translateFor(self, for_stmt: ast.For): assert iter.func.id == "range", "Unsupported loop bound." assert iter.args and len(iter.args) == 1, "Unsupported loop bound." assert not iter.keywords, "Unsupported loop bound." - o_loop_bound = self.translateExpr(iter.args[0], "loop_bound") + o_loop_bound = self.translate_expr(iter.args[0], "loop_bound") # analyze loop body exposed_uses = analysis.exposed_uses(for_stmt.body) vars_def_in_loop = analysis.defs(for_stmt.body) @@ -407,75 +405,72 @@ def translateFor(self, for_stmt: ast.For): outputs = list(loop_state_vars | scan_outputs) # loop-condition: - o_true = self.emitConst(True, "true") - # o_loop_bound = self.emitConst(3, "loop_bound") + o_true = self.emit_const(True, "true") + # o_loop_bound = self.emit_const(3, "loop_bound") # build loop_body - self.enterScope("loop_body") - o_loop_var = self.generateUniqueName(p_loop_var) + self.enter_scope("loop_body") + o_loop_var = self.generate_unique_name(p_loop_var) self.ir_builder.addInput(self.current_fn, o_loop_var, types.INT64) self.bind(p_loop_var, Dynamic(o_loop_var)) - o_cond_var = self.generateUniqueName("cond_in") + o_cond_var = self.generate_unique_name("cond_in") self.ir_builder.addInput(self.current_fn, o_cond_var, types.BOOL) for pv in loop_state_vars: - ov = self.generateUniqueName(pv) + ov = self.generate_unique_name(pv) self.ir_builder.addInput(self.current_fn, ov, self.default_type) self.bind(pv, Dynamic(ov)) for s in for_stmt.body: - self.translateStmt(s) - o_cond_out = self.generateUniqueName("cond_out") + self.translate_stmt(s) + o_cond_out = self.generate_unique_name("cond_out") self.emit([o_cond_out], Op("", "Identity"), [o_cond_var], []) self.ir_builder.addOutput(self.current_fn, o_cond_out, types.BOOL) for pv in loop_state_vars: ov = self.py_var_to_onnx_var(pv) self.ir_builder.addOutput(self.current_fn, ov, self.default_type) # TODO: type - body = self.exitScope() - # if (print_flag): - # print("Generated loop body:") - # body.print() + body = self.exit_scope() inputs = [o_loop_bound, o_true] + [self.py_var_to_onnx_var(pv) for pv in loop_state_vars] attrs = [self.ir_builder.attr("body", body.toGraph())] self.emit2(outputs, "Loop", inputs, attrs) # Translation of a statement-block to GraphProto attribute - def translateBlock(self, stmts, name, live_defs): - self.enterScope(name) + def translate_block(self, stmts, name, live_defs): + self.enter_scope(name) for s in stmts: - self.translateStmt(s) + self.translate_stmt(s) for pvar in live_defs: - if (pvar in self.currentScope()): - pv_val = self.currentScope()[pvar] + if (pvar in self.current_scope()): + pv_val = self.current_scope()[pvar] output = self.to_onnx_var(pv_val, pvar) self.ir_builder.addOutput(self.current_fn, output, self.default_type) # TODO: need type! else: pv_val = None - for scope in self.locals: # TODO: skip currentScope + for scope in self.locals: # TODO: skip current_scope if (pvar in scope): pv_val = scope[pvar] break if (pv_val is None): fail (f"Variable {pvar} is not assigned a value along a conditional branch.") # introduce a copy - ovar = self.generateUniqueName(pvar) + ovar = self.generate_unique_name(pvar) self.emit([ovar], Op("", "Identity"), [self.to_onnx_var(pv_val, pvar)], []) self.ir_builder.addOutput(self.current_fn, ovar, self.default_type) # TODO: need type! - graph = self.exitScope() + graph = self.exit_scope() # if print_flag: # print ("Generated block") # graph.print() return graph.toGraph() - def convert_FunctionDef(self, node): - args = node.args + def translate_function_def(self, fn: ast.FunctionDef): + args = fn.args if (args.defaults): - warn (f"{node.name}: Default values not yet implemented.") + warn (f"{fn.name}: Default values not yet implemented.") if (args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg): - warn (f"{node.name}: Unsupported feature in function signature.") - self.current_fn = self.ir_builder.newFunction(node.name) + warn (f"{fn.name}: Unsupported feature in function signature.") + self.current_fn = self.ir_builder.newFunction(fn.name) for x in args.args: if x.annotation: - typeinfo = self.evalConstantExpr(x.annotation) + typeinfo = self.eval_constant_expr(x.annotation) else: typeinfo = self.default_type assert ta.is_valid(typeinfo) @@ -485,8 +480,8 @@ def convert_FunctionDef(self, node): else: self.ir_builder.addInput(self.current_fn, x.arg, typeinfo) self.bind(x.arg, Dynamic(x.arg)) - if node.returns: - returntype = self.evalConstantExpr(node.returns) + if fn.returns: + returntype = self.eval_constant_expr(fn.returns) if isinstance(returntype, tuple): assert all([ta.is_valid(t) for t in returntype]) self.returntype = returntype @@ -496,8 +491,8 @@ def convert_FunctionDef(self, node): else: self.returntype = None self.num_outputs = 0 - for s in node.body: - self.translateStmt(s) + for s in fn.body: + self.translate_stmt(s) if self.returntype is not None: assert (self.num_outputs == len(self.returntype)), "Mismatch in number of return values and types" return self.current_fn @@ -511,9 +506,9 @@ def do_import(self, alias): def top_level_stmt (self, stmt): if isinstance(stmt, ast.FunctionDef): - self.initFunctionTranslation() + self.init_function_translation() analysis.do_liveness_analysis(stmt) - fn_ir = self.convert_FunctionDef(stmt) + fn_ir = self.translate_function_def(stmt) if print_flag: print("============== OUTPUT =============") fn_ir.print() @@ -533,10 +528,6 @@ def top_level_stmt (self, stmt): raise ValueError("Unsupported top-level statement type: " + type(stmt).__name__) def convert_source(self, src): - if print_flag: - print("============== INPUT =============") - print(src) - print() module = ast.parse(src) assert type(module) == ast.Module converted = [self.top_level_stmt(d) for d in module.body] diff --git a/onnxscript/test/converter_test.py b/onnxscript/test/converter_test.py index 976c806131..cc46a156e3 100644 --- a/onnxscript/test/converter_test.py +++ b/onnxscript/test/converter_test.py @@ -30,6 +30,13 @@ def square(x): """ self._convert(script) + def test_msdomain(self): + # Temporary patch to use com.microsoft domain + script = """ +def foo(x): + return msdomain.bar(x, x) +""" + self._convert(script) def test_onnxfns(self): self._convert(os.path.join(CURRENT_DIR, "onnxfns.py")) From f1654025bfdf0ca289a1acd10af2582dcac07780 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 1 Mar 2022 16:21:47 -0800 Subject: [PATCH 4/6] cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/analysis.py | 2 -- onnxscript/convert.py | 2 ++ onnxscript/converter.py | 33 +++++++++++-------------- onnxscript/irbuilder.py | 32 ++++++++++++------------ onnxscript/onnxscript/types/__init__.py | 4 +-- onnxscript/test/converter_test.py | 2 +- onnxscript/type_annotation.py | 2 +- onnxscript/utils.py | 2 ++ 8 files changed, 39 insertions(+), 40 deletions(-) diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index 854eb00d5a..c99cc75685 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -1,7 +1,5 @@ import ast -from numpy import isin - def used_vars(expr): ''' Return set of all variables used with an expression.''' if (isinstance(expr, ast.Name)): diff --git a/onnxscript/convert.py b/onnxscript/convert.py index bd9067c19e..b33a98580d 100644 --- a/onnxscript/convert.py +++ b/onnxscript/convert.py @@ -1,6 +1,8 @@ import sys from converter import convert +# command-line utility to invoke converter on a python file + if __name__ == '__main__': for i in range(1, len(sys.argv)): convert(sys.argv[i]) \ No newline at end of file diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 555107de9d..766bd45599 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -87,7 +87,7 @@ def init_function_translation(self): def enter_scope(self, name): self.outer.insert(0,self.current_fn) - self.current_fn = self.ir_builder.newFunction(name) + self.current_fn = self.ir_builder.new_function(name) self.locals.insert(0, {}) def exit_scope(self): @@ -140,7 +140,7 @@ def to_onnx_var(self, val, target = None): def py_var_to_onnx_var(self, py_var): return self.to_onnx_var(self.lookup(py_var)) def emit(self, outputs, callee, inputs, attrs): - self.ir_builder.addStmt(self.current_fn, outputs, callee.opset, callee.opname, inputs, attrs) + self.ir_builder.add_stmt(self.current_fn, outputs, callee.opset, callee.opname, inputs, attrs) def emit2(self, outputs, callee, inputs, attrs): def rename(x): @@ -359,7 +359,7 @@ def ret(exp, suffix=""): t = self.returntype[self.num_outputs] except Exception as e: t = self.default_type - self.ir_builder.addOutput(self.current_fn, ovar, t) + self.ir_builder.add_output(self.current_fn, ovar, t) self.num_outputs += 1 return ovar @@ -411,26 +411,26 @@ def translate_for_stmt(self, for_stmt: ast.For): # build loop_body self.enter_scope("loop_body") o_loop_var = self.generate_unique_name(p_loop_var) - self.ir_builder.addInput(self.current_fn, o_loop_var, types.INT64) + self.ir_builder.add_input(self.current_fn, o_loop_var, types.INT64) self.bind(p_loop_var, Dynamic(o_loop_var)) o_cond_var = self.generate_unique_name("cond_in") - self.ir_builder.addInput(self.current_fn, o_cond_var, types.BOOL) + self.ir_builder.add_input(self.current_fn, o_cond_var, types.BOOL) for pv in loop_state_vars: ov = self.generate_unique_name(pv) - self.ir_builder.addInput(self.current_fn, ov, self.default_type) + self.ir_builder.add_input(self.current_fn, ov, self.default_type) self.bind(pv, Dynamic(ov)) for s in for_stmt.body: self.translate_stmt(s) o_cond_out = self.generate_unique_name("cond_out") self.emit([o_cond_out], Op("", "Identity"), [o_cond_var], []) - self.ir_builder.addOutput(self.current_fn, o_cond_out, types.BOOL) + self.ir_builder.add_output(self.current_fn, o_cond_out, types.BOOL) for pv in loop_state_vars: ov = self.py_var_to_onnx_var(pv) - self.ir_builder.addOutput(self.current_fn, ov, self.default_type) # TODO: type + self.ir_builder.add_output(self.current_fn, ov, self.default_type) # TODO: type body = self.exit_scope() inputs = [o_loop_bound, o_true] + [self.py_var_to_onnx_var(pv) for pv in loop_state_vars] - attrs = [self.ir_builder.attr("body", body.toGraph())] + attrs = [self.ir_builder.attr("body", body.to_graph_proto())] self.emit2(outputs, "Loop", inputs, attrs) # Translation of a statement-block to GraphProto attribute @@ -442,7 +442,7 @@ def translate_block(self, stmts, name, live_defs): if (pvar in self.current_scope()): pv_val = self.current_scope()[pvar] output = self.to_onnx_var(pv_val, pvar) - self.ir_builder.addOutput(self.current_fn, output, self.default_type) # TODO: need type! + self.ir_builder.add_output(self.current_fn, output, self.default_type) # TODO: need type! else: pv_val = None for scope in self.locals: # TODO: skip current_scope @@ -454,12 +454,9 @@ def translate_block(self, stmts, name, live_defs): # introduce a copy ovar = self.generate_unique_name(pvar) self.emit([ovar], Op("", "Identity"), [self.to_onnx_var(pv_val, pvar)], []) - self.ir_builder.addOutput(self.current_fn, ovar, self.default_type) # TODO: need type! + self.ir_builder.add_output(self.current_fn, ovar, self.default_type) # TODO: need type! graph = self.exit_scope() - # if print_flag: - # print ("Generated block") - # graph.print() - return graph.toGraph() + return graph.to_graph_proto() def translate_function_def(self, fn: ast.FunctionDef): args = fn.args @@ -467,7 +464,7 @@ def translate_function_def(self, fn: ast.FunctionDef): warn (f"{fn.name}: Default values not yet implemented.") if (args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg): warn (f"{fn.name}: Unsupported feature in function signature.") - self.current_fn = self.ir_builder.newFunction(fn.name) + self.current_fn = self.ir_builder.new_function(fn.name) for x in args.args: if x.annotation: typeinfo = self.eval_constant_expr(x.annotation) @@ -475,10 +472,10 @@ def translate_function_def(self, fn: ast.FunctionDef): typeinfo = self.default_type assert ta.is_valid(typeinfo) if (ta.is_attr(typeinfo)): - self.ir_builder.addAttr(self.current_fn, x.arg, typeinfo) + self.ir_builder.add_attr(self.current_fn, x.arg, typeinfo) self.bind(x.arg, AttrRef(x.arg, typeinfo)) else: - self.ir_builder.addInput(self.current_fn, x.arg, typeinfo) + self.ir_builder.add_input(self.current_fn, x.arg, typeinfo) self.bind(x.arg, Dynamic(x.arg)) if fn.returns: returntype = self.eval_constant_expr(fn.returns) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index a9449db190..60dfd75469 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -15,7 +15,7 @@ def __init__(self) -> None: self.onnx_type = tp # helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [10]) - def toTypeProto(self): + def to_type_proto(self): return self.onnx_type def __str__(self) -> str: @@ -29,11 +29,11 @@ def __init__(self, varname, type = None) -> None: def __str__(self): return self.name - def typedstr(self): + def typed_str(self): return self.name + " : " + str(self.type) - def toValueInfo(self): - tp = self.type.toTypeProto() + def to_value_info(self): + tp = self.type.to_type_proto() # if (not tp.tensor_type.HasField('shape')): # # TODO: temporary patch to export a function as a graph # tp = helper.make_tensor_type_proto(tp.tensor_type.elem_type, [10]) @@ -72,7 +72,7 @@ def __str__(self): def print(self): print (str(self)) - def toNode(self): + def to_node_proto(self): n = helper.make_node(self.opname, [str(x) for x in self.args], [str(x) for x in self.result]) @@ -90,8 +90,8 @@ def __init__(self, name) -> None: def __str__(self): attrs = format (self.attrs, "<", ", ", ">") if self.attrs else "" - inputs = format ([x.typedstr() for x in self.inputs], "(", ", ", ")") - outputs = format ([x.typedstr() for x in self.outputs], "(", ", ", ")") + inputs = format ([x.typed_str() for x in self.inputs], "(", ", ", ")") + outputs = format ([x.typed_str() for x in self.outputs], "(", ", ", ")") stmts = format (self.stmts, "\n{\n ", "\n ", "\n}\n") return (self.name + " " + attrs + inputs + " => " + outputs + stmts) @@ -103,32 +103,32 @@ def print(self): print(helper.printable_graph(attr.attr_proto.g)) - def toGraph(self): - return helper.make_graph([s.toNode() for s in self.stmts], + def to_graph_proto(self): + return helper.make_graph([s.to_node_proto() for s in self.stmts], self.name, - [x.toValueInfo() for x in self.inputs], - [y.toValueInfo() for y in self.outputs] + [x.to_value_info() for x in self.inputs], + [y.to_value_info() for y in self.outputs] ) # IRBuilder: abstracts out details of the IR in the python-to-IR converter class IRBuilder: - def newFunction(self, name): + def new_function(self, name): return Function(name) - def addStmt(self, fn, results, module, opname, args, attrs): + def add_stmt(self, fn, results, module, opname, args, attrs): s = Stmt(results, module, opname, args, attrs) fn.stmts.append(s) - def addInput(self, fn, varname, type): + def add_input(self, fn, varname, type): v = Var(varname, type) fn.inputs.append(v) - def addAttr(self, fn, varname, type): + def add_attr(self, fn, varname, type): v = Var(varname, type) fn.attrs.append(v) - def addOutput(self, fn, varname, type): + def add_output(self, fn, varname, type): v = Var(varname, type) fn.outputs.append(v) diff --git a/onnxscript/onnxscript/types/__init__.py b/onnxscript/onnxscript/types/__init__.py index ef99ee1e2b..cf5f00003d 100644 --- a/onnxscript/onnxscript/types/__init__.py +++ b/onnxscript/onnxscript/types/__init__.py @@ -12,7 +12,7 @@ def __str__(self) -> str: shapestr = str(self.shape) if self.shape else "[...]" return onnx.TensorProto.DataType.Name(self.dtype) + shapestr - def toTypeProto(self): + def to_type_proto(self): # TODO: handle None return onnx.helper.make_tensor_type_proto(self.dtype, self.shape) @@ -45,7 +45,7 @@ def mk_dim(dim): s = [shape] return Tensor(self.dtype, s) - def toTypeProto(self): + def to_type_proto(self): return onnx.helper.make_tensor_type_proto(self.dtype, ()) def __str__(self) -> str: diff --git a/onnxscript/test/converter_test.py b/onnxscript/test/converter_test.py index cc46a156e3..83b758fdd6 100644 --- a/onnxscript/test/converter_test.py +++ b/onnxscript/test/converter_test.py @@ -17,7 +17,7 @@ def _convert_and_save (self, script): if (not os.path.exists(TEST_OUTPUT_DIR)): os.makedirs(TEST_OUTPUT_DIR) for f in fnlist: - graph = f.toGraph() + graph = f.to_graph_proto() model = onnx.helper.make_model(graph, producer_name='p2o', opset_imports=[onnx.helper.make_opsetid("", 15)]) model = onnx.shape_inference.infer_shapes(model) onnx.checker.check_model(model) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 391b49e046..e19eec2987 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -13,7 +13,7 @@ def is_attr(typeinfo): # (typeinfo is float) or (typeinfo is str) or (typeinfo is int) def is_tensor(typeinfo): - return hasattr(typeinfo, "toTypeProto") + return hasattr(typeinfo, "to_type_proto") # return isinstance(typeinfo, onnxscript.Tensor) # TODO def is_valid(typeinfo): diff --git a/onnxscript/utils.py b/onnxscript/utils.py index 42debe94e4..b1fda109dc 100644 --- a/onnxscript/utils.py +++ b/onnxscript/utils.py @@ -1,6 +1,8 @@ from typing import Any import onnxruntime as ort +# TODO: enable invocation of ORT kernels + class Model: def __init__(self, onnxfile) -> None: self.session = ort.InferenceSession(onnxfile) From d96818d65d00b7d9e0cc1c4dcbe86f56d9e16f78 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Mar 2022 10:39:01 -0800 Subject: [PATCH 5/6] Python formatting Signed-off-by: Ganesan Ramalingam --- onnxscript/analysis.py | 16 ++-- onnxscript/convert.py | 2 +- onnxscript/converter.py | 149 +++++++++++++++++++--------------- onnxscript/irbuilder.py | 64 ++++++++------- onnxscript/type_annotation.py | 11 ++- onnxscript/utils.py | 8 +- onnxscript/values.py | 25 ++++-- 7 files changed, 157 insertions(+), 118 deletions(-) diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index c99cc75685..b9c8e6e2b6 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -1,5 +1,6 @@ import ast + def used_vars(expr): ''' Return set of all variables used with an expression.''' if (isinstance(expr, ast.Name)): @@ -9,6 +10,7 @@ def used_vars(expr): result = result | used_vars(c) return result + def local_defs(lhs): '''Utility function to return set of assigned/defined variables in the lhs of an assignment statement.''' def get_id(e): @@ -19,6 +21,7 @@ def get_id(e): else: return set([get_id(lhs)]) + def defs(stmt): ''' Return the set of all variables that may be defined (assigned to) in an @@ -38,7 +41,7 @@ def block_defs(block): elif isinstance(stmt, list): return block_defs(stmt) else: - raise ValueError("Unsupported statement type: " + type(stmt).__name__) + raise ValueError("Unsupported statement type: " + type(stmt).__name__) def do_liveness_analysis(fun): @@ -47,7 +50,7 @@ def do_liveness_analysis(fun): analysis are stored directly with each statement-ast `s` as attributes `s.live_in` and `s.live_out`. ''' - def visit (stmt, live_out): + def visit(stmt, live_out): def visitBlock(block, live_out): for s in reversed(block): live_out = visit(s, live_out) @@ -61,9 +64,9 @@ def visitBlock(block, live_out): live2 = visitBlock(stmt.orelse, live_out) return (live1 | live2) | used_vars(stmt.test) elif isinstance(stmt, ast.For): - return live_out # TODO + return live_out # TODO else: - raise ValueError("Unsupported statement type: " + type(stmt).__name__) + raise ValueError("Unsupported statement type: " + type(stmt).__name__) assert type(fun) == ast.FunctionDef live = set() for s in reversed(fun.body): @@ -73,6 +76,7 @@ def visitBlock(block, live_out): # print(ast.dump(s)) # print("Live-In = ", live) + def exposed_uses(stmts): ''' Return the set of variables that are used before being defined by given block. @@ -82,7 +86,7 @@ def visitBlock(block, live_out): live_out = visit(stmt, live_out) return live_out - def visit (stmt, live_out): + def visit(stmt, live_out): if (isinstance(stmt, ast.Assign)): return (live_out.difference(local_defs(stmt.targets[0]))) | used_vars(stmt.value) elif (isinstance(stmt, ast.Return)): @@ -94,4 +98,4 @@ def visit (stmt, live_out): else: raise ValueError("Unsupported statement type: " + type(stmt).__name__) - return visitBlock(stmts, set()) \ No newline at end of file + return visitBlock(stmts, set()) diff --git a/onnxscript/convert.py b/onnxscript/convert.py index b33a98580d..dd1410825e 100644 --- a/onnxscript/convert.py +++ b/onnxscript/convert.py @@ -5,4 +5,4 @@ if __name__ == '__main__': for i in range(1, len(sys.argv)): - convert(sys.argv[i]) \ No newline at end of file + convert(sys.argv[i]) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 766bd45599..85acc9b97d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -19,30 +19,37 @@ # Python-to-IR converter: + def not_allowed(construct): return construct + "not supported." + class TranslationError(Exception): def __init__(self, *args: object) -> None: super().__init__(*args) + def warn(msg): - print ("Warning: " + msg) + print("Warning: " + msg) + def fail(msg): raise TranslationError(msg) + def fail_if(cond, msg): if cond: raise TranslationError(msg) -def ignore (cond, msg): + +def ignore(cond, msg): if cond: - warn (msg) + warn(msg) # Utility to convert a python value to TensorProto: -def pyvalue_to_tensor(tensor_name : str, pyvalue): + +def pyvalue_to_tensor(tensor_name: str, pyvalue): if isinstance(pyvalue, bool): return helper.make_tensor(tensor_name, onnx.TensorProto.BOOL, [], [int(pyvalue)]) elif isinstance(pyvalue, int): @@ -53,6 +60,7 @@ def pyvalue_to_tensor(tensor_name : str, pyvalue): # TODO: str, sequences of values fail("Unimplemented") + # map from python operators to ONNX ops primop_map = { ast.Add: "Add", @@ -63,17 +71,19 @@ def pyvalue_to_tensor(tensor_name : str, pyvalue): ast.Lt: "Less", ast.Gt: "Greater", ast.LtE: "LessOrEqual", - ast.GtE : "GreaterOrEqual", - ast.MatMult : "MatMul", - ast.Mod : "Mod", - ast.Pow : "Pow" + ast.GtE: "GreaterOrEqual", + ast.MatMult: "MatMul", + ast.Mod: "Mod", + ast.Pow: "Pow" } + class Converter: - def __init__(self, ir_builder = IRBuilder()): + def __init__(self, ir_builder=IRBuilder()): self.ir_builder = ir_builder - self.known_modules = { 'onnxscript' : onnxscript, 'onnxscript.types' : types, 'onnx.opset15' : values.opset15 } - self.globals = { "int" : int, "float" : float, "str" : str, "onnx" : values.opset15, "Onnx" : values.opset15, "msdomain" : values.msdomain1 } # 'os' : onnxscript + self.known_modules = {'onnxscript': onnxscript, 'onnxscript.types': types, 'onnx.opset15': values.opset15} + self.globals = {"int": int, "float": float, "str": str, "onnx": values.opset15, + "Onnx": values.opset15, "msdomain": values.msdomain1} # 'os' : onnxscript self.pure_modules = ["onnxscript"] self.default_type = types.FLOAT[...] @@ -86,30 +96,32 @@ def init_function_translation(self): self.locals = [{}] def enter_scope(self, name): - self.outer.insert(0,self.current_fn) + self.outer.insert(0, self.current_fn) self.current_fn = self.ir_builder.new_function(name) self.locals.insert(0, {}) - + def exit_scope(self): graph = self.current_fn self.current_fn = self.outer[0] self.outer.pop(0) self.locals.pop(0) return graph - + def current_scope(self): return self.locals[0] def bind(self, name, val): self.locals[0][name] = val - + def lookup(self, name): for scope in self.locals: - if (name in scope): return scope[name] - if (name in self.globals): return self.globals[name] + if (name in scope): + return scope[name] + if (name in self.globals): + return self.globals[name] raise ValueError("Unbound name: " + name) - def generate_unique_name(self, candidate = "tmp"): + def generate_unique_name(self, candidate="tmp"): r = candidate while (r in self.used_vars): r = candidate + "_" + str(self.nextvar) @@ -117,19 +129,19 @@ def generate_unique_name(self, candidate = "tmp"): self.used_vars.add(r) return r - def to_onnx_attr_ref(self, val : AttrRef): + def to_onnx_attr_ref(self, val: AttrRef): pytype = val.type attrname = "value_float" if (pytype is float) else ("value_int" if (pytype is int) else "value_string") return self.ir_builder.attr_ref(attrname, val.value, pytype) - def to_onnx_var(self, val, target = None): + def to_onnx_var(self, val, target=None): if (isinstance(val, AttrRef)): # promote attribute to value result = self.generate_unique_name(target if target else "tmp") attr = self.to_onnx_attr_ref(val) - self.emit ([result], Op("", "Constant"), [], [attr]) + self.emit([result], Op("", "Constant"), [], [attr]) return result - elif (isinstance(val, ConstValue) and isinstance(val.value, float)): # TODO + elif (isinstance(val, ConstValue) and isinstance(val.value, float)): # TODO result = self.generate_unique_name(target if target else "tmp") return self.emit_const(val.value, result) elif isinstance(val, Dynamic): @@ -138,7 +150,7 @@ def to_onnx_var(self, val, target = None): fail(f"Cannot convert to onnx variable") def py_var_to_onnx_var(self, py_var): return self.to_onnx_var(self.lookup(py_var)) - + def emit(self, outputs, callee, inputs, attrs): self.ir_builder.add_stmt(self.current_fn, outputs, callee.opset, callee.opname, inputs, attrs) @@ -147,13 +159,13 @@ def rename(x): r = self.generate_unique_name(x) self.bind(x, Dynamic(r)) return r - onnx_inputs = inputs # [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ] - onnx_outputs = [ rename(x) for x in outputs ] + onnx_inputs = inputs # [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ] + onnx_outputs = [rename(x) for x in outputs] self.emit(onnx_outputs, Op("", callee), onnx_inputs, attrs) - def emit_const (self, pyvalue, suggested_name): + def emit_const(self, pyvalue, suggested_name): ovar = self.generate_unique_name(suggested_name) - tensor = pyvalue_to_tensor (ovar, pyvalue) + tensor = pyvalue_to_tensor(ovar, pyvalue) attr = self.ir_builder.attr("value", tensor) self.emit([ovar], Op("", "Constant"), [], [attr]) return ovar @@ -161,7 +173,7 @@ def emit_const (self, pyvalue, suggested_name): def is_pure_module(self, m): return (m in self.pure_modules) - def is_constant_expr (self, node): + def is_constant_expr(self, node): if (isinstance(node, ast.Name)): val = self.lookup(node.id) return isinstance(val, ConstValue) and self.is_pure_module(val.value) @@ -171,7 +183,7 @@ def is_constant_expr (self, node): def eval_constant_expr(self, node): # TODO: assert (self.is_constant_expr(node)) - locals = {} # TODO + locals = {} # TODO return eval(compile(ast.Expression(node), filename="", mode="eval"), self.globals, locals) def eval_attr(self, node): @@ -209,13 +221,13 @@ def translate_expr(self, node, target="tmp"): if (isinstance(node, ast.Call)): r = self.translate_call_expr(node) elif (isinstance(node, ast.BinOp)): - r = self.translate_bin_op_expr (node) + r = self.translate_bin_op_expr(node) elif (isinstance(node, ast.UnaryOp)): - r = self.translate_unary_op_expr (node) + r = self.translate_unary_op_expr(node) elif (isinstance(node, ast.Compare)): - r = self.translate_compare_expr (node) + r = self.translate_compare_expr(node) elif (isinstance(node, ast.Name)): - r = self.translate_name_expr (node) + r = self.translate_name_expr(node) elif (isinstance(node, ast.Num)): r = self.emit_const(node.n, target) elif isinstance(node, ast.NameConstant): @@ -233,7 +245,7 @@ def translate_expr(self, node, target="tmp"): results = [self.generate_unique_name(x) for x in target] callee, args, attrs = r self.emit(results, callee, args, attrs) - return results + return results return r def translate_call_expr(self, node): @@ -257,26 +269,26 @@ def translate_unary_op_expr(self, node): opname = primop_map[op] operand = self.translate_expr(node.operand) return (Op("", opname), [operand], []) - + def translate_compare_expr(self, node): - assert (len(node.ops) == 1) # TODO: handle multiple comparisons in one expression + assert (len(node.ops) == 1) # TODO: handle multiple comparisons in one expression assert (len(node.comparators) == 1) op = type(node.ops[0]) assert (op in primop_map) opname = primop_map[op] left = self.translate_expr(node.left) - right = self.translate_expr(node.comparators[0]) + right = self.translate_expr(node.comparators[0]) return (Op("", opname), [left, right], []) def translate_name_expr(self, node): return self.py_var_to_onnx_var(node.id) - def translate_opset_expr(self, node) -> values.Opset : + def translate_opset_expr(self, node) -> values.Opset: """Return an Opset""" if isinstance(node, ast.Name): try: val = self.lookup(node.id) - if isinstance(val, ConstValue): #TODO + if isinstance(val, ConstValue): # TODO val = val.value if isinstance(val, values.Opset): return val @@ -284,17 +296,18 @@ def translate_opset_expr(self, node) -> values.Opset : except: warn(f"Unknown opset name {node.id}.") return values.Opset(node.id, 1) - elif isinstance (node, ast.Attribute): - fail("Nested module unimplemented") # TODO + elif isinstance(node, ast.Attribute): + fail("Nested module unimplemented") # TODO else: fail("Invalid opset expression.") - def translate_callee_expr(self, node) -> values.Op : + def translate_callee_expr(self, node) -> values.Op: """Return an Op""" if isinstance(node, ast.Attribute): module = self.translate_opset_expr(node.value) opname = node.attr - if (opname not in module): warn (f"{opname} is not a known op in {str(module)}") + if (opname not in module): + warn(f"{opname} is not a known op in {str(module)}") return Op(module, node.attr) elif isinstance(node, ast.Name): try: @@ -305,7 +318,7 @@ def translate_callee_expr(self, node) -> values.Op : warn(f"Unknown function name {node.id}.") return Op(default_opset, node.id) else: - fail ("Invalid callee") + fail("Invalid callee") # Statement translation: A single Python statement is mapped into a sequence of IR statements. @@ -324,7 +337,7 @@ def translate_stmt(self, node): def translate_assign_stmt(self, stmt: ast.Assign): def assign(lhs, rhs): if (isinstance(lhs, ast.Name)): - lhs = lhs.id + lhs = lhs.id if (self.is_constant_expr(rhs)): self.bind(lhs, ConstValue(self.eval_constant_expr(rhs))) else: @@ -346,7 +359,7 @@ def id(x): if (isinstance(rhs, ast.Tuple)): assert isinstance(lhs, ast.Tuple) assert len(lhs.elts) == len(rhs.elts), "Expected same number of elements on lhs and rhs of assignments." - for l,r in zip(lhs.elts, rhs.elts): + for l, r in zip(lhs.elts, rhs.elts): assign(l, r) else: assign(lhs, rhs) @@ -366,7 +379,7 @@ def ret(exp, suffix=""): val = stmt.value assert (val != None), "Return statement without return-value not supported." if (isinstance(val, ast.Tuple)): - return [ret(exp,str(i)) for i,exp in enumerate(val.elts)] + return [ret(exp, str(i)) for i, exp in enumerate(val.elts)] else: return ret(val) @@ -378,11 +391,12 @@ def translate_if_stmt(self, stmt: ast.If): thenAttr = self.ir_builder.attr("then_branch", thenGraph) elseGraph = self.translate_block(stmt.orelse, "elseGraph", live_defs) elseAttr = self.ir_builder.attr("else_branch", elseGraph) + def rename(x): r = self.generate_unique_name(x) self.bind(x, Dynamic(r)) return r - renamed = [ rename(x) for x in live_defs ] + renamed = [rename(x) for x in live_defs] self.emit(renamed, Op("", "If"), [test], [thenAttr, elseAttr]) def translate_for_stmt(self, for_stmt: ast.For): @@ -401,7 +415,7 @@ def translate_for_stmt(self, for_stmt: ast.For): exposed_uses = analysis.exposed_uses(for_stmt.body) vars_def_in_loop = analysis.defs(for_stmt.body) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | for_stmt.live_out) - scan_outputs = set() # TODO + scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) # loop-condition: @@ -417,8 +431,8 @@ def translate_for_stmt(self, for_stmt: ast.For): self.ir_builder.add_input(self.current_fn, o_cond_var, types.BOOL) for pv in loop_state_vars: ov = self.generate_unique_name(pv) - self.ir_builder.add_input(self.current_fn, ov, self.default_type) - self.bind(pv, Dynamic(ov)) + self.ir_builder.add_input(self.current_fn, ov, self.default_type) + self.bind(pv, Dynamic(ov)) for s in for_stmt.body: self.translate_stmt(s) o_cond_out = self.generate_unique_name("cond_out") @@ -426,7 +440,7 @@ def translate_for_stmt(self, for_stmt: ast.For): self.ir_builder.add_output(self.current_fn, o_cond_out, types.BOOL) for pv in loop_state_vars: ov = self.py_var_to_onnx_var(pv) - self.ir_builder.add_output(self.current_fn, ov, self.default_type) # TODO: type + self.ir_builder.add_output(self.current_fn, ov, self.default_type) # TODO: type body = self.exit_scope() inputs = [o_loop_bound, o_true] + [self.py_var_to_onnx_var(pv) for pv in loop_state_vars] @@ -442,28 +456,28 @@ def translate_block(self, stmts, name, live_defs): if (pvar in self.current_scope()): pv_val = self.current_scope()[pvar] output = self.to_onnx_var(pv_val, pvar) - self.ir_builder.add_output(self.current_fn, output, self.default_type) # TODO: need type! + self.ir_builder.add_output(self.current_fn, output, self.default_type) # TODO: need type! else: pv_val = None - for scope in self.locals: # TODO: skip current_scope + for scope in self.locals: # TODO: skip current_scope if (pvar in scope): pv_val = scope[pvar] break if (pv_val is None): - fail (f"Variable {pvar} is not assigned a value along a conditional branch.") + fail(f"Variable {pvar} is not assigned a value along a conditional branch.") # introduce a copy ovar = self.generate_unique_name(pvar) self.emit([ovar], Op("", "Identity"), [self.to_onnx_var(pv_val, pvar)], []) - self.ir_builder.add_output(self.current_fn, ovar, self.default_type) # TODO: need type! + self.ir_builder.add_output(self.current_fn, ovar, self.default_type) # TODO: need type! graph = self.exit_scope() return graph.to_graph_proto() def translate_function_def(self, fn: ast.FunctionDef): args = fn.args if (args.defaults): - warn (f"{fn.name}: Default values not yet implemented.") + warn(f"{fn.name}: Default values not yet implemented.") if (args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg): - warn (f"{fn.name}: Unsupported feature in function signature.") + warn(f"{fn.name}: Unsupported feature in function signature.") self.current_fn = self.ir_builder.new_function(fn.name) for x in args.args: if x.annotation: @@ -473,7 +487,7 @@ def translate_function_def(self, fn: ast.FunctionDef): assert ta.is_valid(typeinfo) if (ta.is_attr(typeinfo)): self.ir_builder.add_attr(self.current_fn, x.arg, typeinfo) - self.bind(x.arg, AttrRef(x.arg, typeinfo)) + self.bind(x.arg, AttrRef(x.arg, typeinfo)) else: self.ir_builder.add_input(self.current_fn, x.arg, typeinfo) self.bind(x.arg, Dynamic(x.arg)) @@ -496,12 +510,12 @@ def translate_function_def(self, fn: ast.FunctionDef): def do_import(self, alias): if print_flag: - print (f"Importing {alias.name} as {alias.asname}") - fail_if (alias.name not in self.known_modules, f"Import: unsupported module {alias.name}") + print(f"Importing {alias.name} as {alias.asname}") + fail_if(alias.name not in self.known_modules, f"Import: unsupported module {alias.name}") asname = alias.asname if alias.asname else alias.name self.globals[asname] = self.known_modules[alias.name] - def top_level_stmt (self, stmt): + def top_level_stmt(self, stmt): if isinstance(stmt, ast.FunctionDef): self.init_function_translation() analysis.do_liveness_analysis(stmt) @@ -516,7 +530,7 @@ def top_level_stmt (self, stmt): self.do_import(alias) elif isinstance(stmt, ast.ImportFrom): fail_if(stmt.module is None, "Import: module unspecified.") - fail_if (stmt.module not in self.known_modules, f"Import: unsupported module {stmt.module}") + fail_if(stmt.module not in self.known_modules, f"Import: unsupported module {stmt.module}") module = self.known_modules[stmt.module] for alias in stmt.names: asname = alias.asname if alias.asname else alias.name @@ -532,7 +546,7 @@ def convert_source(self, src): def convert_file(self, filename): src = open(filename).read() - return self.convert_source(src) + return self.convert_source(src) def convert(self, f): if (isinstance(f, str)): @@ -542,10 +556,11 @@ def convert(self, f): return self.convert_source(f) elif (inspect.isfunction(f)): src = inspect.getsource(f) - return self.convert_source(src) + return self.convert_source(src) else: - fail("Unknown type of input to converter.") + fail("Unknown type of input to converter.") + -def convert (script): +def convert(script): converter = Converter() converter.convert(script) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 60dfd75469..9d94d70554 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -4,9 +4,11 @@ # A simple IR (Function, Stmt, Attr, Var): + def format(list, prefix, sep, suffix): return prefix + sep.join([str(x) for x in list]) + suffix + class Type: def __init__(self) -> None: # TODO @@ -14,24 +16,25 @@ def __init__(self) -> None: tp.tensor_type.elem_type = onnx.TensorProto.FLOAT self.onnx_type = tp # helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [10]) - + def to_type_proto(self): return self.onnx_type - + def __str__(self) -> str: return "SomeType" + class Var: - def __init__(self, varname, type = None) -> None: + def __init__(self, varname, type=None) -> None: self.name = varname self.type = type def __str__(self): return self.name - + def typed_str(self): return self.name + " : " + str(self.type) - + def to_value_info(self): tp = self.type.to_type_proto() # if (not tp.tensor_type.HasField('shape')): @@ -39,14 +42,16 @@ def to_value_info(self): # tp = helper.make_tensor_type_proto(tp.tensor_type.elem_type, [10]) return helper.make_value_info(self.name, tp) + class Attr: def __init__(self, attrproto) -> None: self.attr_proto = attrproto - + def __str__(self): if (self.attr_proto.HasField("ref_attr_name")): return self.attr_proto.name + " = @" + self.attr_proto.ref_attr_name - return helper.printable_attribute(self.attr_proto) # self.name + " = " + self.value + return helper.printable_attribute(self.attr_proto) # self.name + " = " + self.value + class Stmt: def __init__(self, result, module, opname, args, attrs) -> None: @@ -55,7 +60,7 @@ def __init__(self, result, module, opname, args, attrs) -> None: self.opname = opname self.args = args self.attrs = attrs - + def __str__(self): if (isinstance(self.result, str)): print('Ooops') @@ -64,22 +69,23 @@ def __str__(self): if (self.attrs): attrs = format(self.attrs, "<", ", ", ">") - args = format (self.args, "(", ", ", ")") + args = format(self.args, "(", ", ", ")") module = str(self.module) - callee = module + "." + self.opname if (module != '') else self.opname + callee = module + "." + self.opname if (module != '') else self.opname return (lhs + " = " + callee + " " + attrs + args) - + def print(self): - print (str(self)) - + print(str(self)) + def to_node_proto(self): n = helper.make_node(self.opname, - [str(x) for x in self.args], - [str(x) for x in self.result]) + [str(x) for x in self.args], + [str(x) for x in self.result]) for a in self.attrs: n.attribute.append(a.attr_proto) return n + class Function: def __init__(self, name) -> None: self.name = name @@ -89,29 +95,29 @@ def __init__(self, name) -> None: self.attrs = [] def __str__(self): - attrs = format (self.attrs, "<", ", ", ">") if self.attrs else "" - inputs = format ([x.typed_str() for x in self.inputs], "(", ", ", ")") - outputs = format ([x.typed_str() for x in self.outputs], "(", ", ", ")") - stmts = format (self.stmts, "\n{\n ", "\n ", "\n}\n") + attrs = format(self.attrs, "<", ", ", ">") if self.attrs else "" + inputs = format([x.typed_str() for x in self.inputs], "(", ", ", ")") + outputs = format([x.typed_str() for x in self.outputs], "(", ", ", ")") + stmts = format(self.stmts, "\n{\n ", "\n ", "\n}\n") return (self.name + " " + attrs + inputs + " => " + outputs + stmts) def print(self): - print (str(self)) + print(str(self)) for s in self.stmts: for attr in s.attrs: if attr.attr_proto.HasField("g"): print(helper.printable_graph(attr.attr_proto.g)) - def to_graph_proto(self): return helper.make_graph([s.to_node_proto() for s in self.stmts], - self.name, - [x.to_value_info() for x in self.inputs], - [y.to_value_info() for y in self.outputs] - ) + self.name, + [x.to_value_info() for x in self.inputs], + [y.to_value_info() for y in self.outputs] + ) # IRBuilder: abstracts out details of the IR in the python-to-IR converter + class IRBuilder: def new_function(self, name): return Function(name) @@ -131,16 +137,16 @@ def add_attr(self, fn, varname, type): def add_output(self, fn, varname, type): v = Var(varname, type) fn.outputs.append(v) - + def attr(self, attrname, attrval): if (isinstance(attrval, Function)): - attrval = str(attrval) # TODO + attrval = str(attrval) # TODO return Attr(helper.make_attribute(attrname, attrval)) - + def attr_ref(self, attrname, refname, pytype): a = onnx.AttributeProto() a.name = attrname a.ref_attr_name = refname - a.type = ta.pytype_to_attrtype_map[pytype] # onnx.AttributeProto.FLOAT + a.type = ta.pytype_to_attrtype_map[pytype] # onnx.AttributeProto.FLOAT return Attr(a) # TODO: attr_type? diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index e19eec2987..bb051a02ff 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -3,18 +3,21 @@ import onnxscript pytype_to_attrtype_map = { - float : onnx.AttributeProto.FLOAT, - int : onnx.AttributeProto.INT, - str : onnx.AttributeProto.STRING, + float: onnx.AttributeProto.FLOAT, + int: onnx.AttributeProto.INT, + str: onnx.AttributeProto.STRING, } + def is_attr(typeinfo): return typeinfo in {float, int, str} # (typeinfo is float) or (typeinfo is str) or (typeinfo is int) + def is_tensor(typeinfo): return hasattr(typeinfo, "to_type_proto") # return isinstance(typeinfo, onnxscript.Tensor) # TODO + def is_valid(typeinfo): - return is_attr(typeinfo) or is_tensor(typeinfo) \ No newline at end of file + return is_attr(typeinfo) or is_tensor(typeinfo) diff --git a/onnxscript/utils.py b/onnxscript/utils.py index b1fda109dc..5cf0b65a1e 100644 --- a/onnxscript/utils.py +++ b/onnxscript/utils.py @@ -3,23 +3,27 @@ # TODO: enable invocation of ORT kernels + class Model: def __init__(self, onnxfile) -> None: self.session = ort.InferenceSession(onnxfile) - + def __call__(self, *args: Any, **kwds: Any) -> Any: inputs = self.session.get_inputs() for i, arg in enumerate(args): kwds[inputs[i].name] = arg return self.session.run(None, kwds) + def make_ort_value(v): pass + def make_attrs(**kwds): pass -def call_ort (opname, *args, **kwds): + +def call_ort(opname, *args, **kwds): model = Model("todo") ort_args = [make_ort_value(x) for x in args] return model(ort_args) diff --git a/onnxscript/values.py b/onnxscript/values.py index 708637a297..309f2eddc0 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -3,17 +3,18 @@ # ONNX opsets (correspond to python modules in reference-mode) # Have a domain-name, version, and a list of ops + class Opset: def __init__(self, domain, version) -> None: self.domain = domain self.version = version - - def __getitem__ (self, opname): - return onnx.defs.get_schema (opname, self.version, self.domain) - - def __contains__ (self, opname): + + def __getitem__(self, opname): + return onnx.defs.get_schema(opname, self.version, self.domain) + + def __contains__(self, opname): try: - onnx.defs.get_schema (opname, self.version, self.domain) + onnx.defs.get_schema(opname, self.version, self.domain) return True except: return False @@ -21,20 +22,22 @@ def __contains__ (self, opname): def __str__(self) -> str: return self.domain + opset15 = Opset("", 15) msdomain1 = Opset("com.microsoft", 1) # ONNX ops + class Op: def __init__(self, opset, opname) -> None: self.opset = opset self.opname = opname - + def get_schema(self): return self.opset[self.opname] - + def has_schema(self): return (self.opname in self.opset) @@ -43,19 +46,23 @@ def has_schema(self): # AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes # Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs + class Value: def __init__(self, val) -> None: self.value = val -class ConstValue(Value): + +class ConstValue(Value): def __init__(self, val) -> None: super().__init__(val) + class AttrRef(Value): def __init__(self, name: str, type) -> None: super().__init__(name) self.type = type + class Dynamic(Value): def __init__(self, val) -> None: super().__init__(val) From d2b80208bdc384a3f3da9b11ccf16c88bfb4886c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Mar 2022 10:42:59 -0800 Subject: [PATCH 6/6] Python formatting Signed-off-by: Ganesan Ramalingam --- onnxscript/onnxscript/__init__.py | 2 + onnxscript/onnxscript/types/__init__.py | 49 +++++++++++++------------ onnxscript/test/converter_test.py | 11 ++++-- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/onnxscript/onnxscript/__init__.py b/onnxscript/onnxscript/__init__.py index 6fef22f00f..c13d808e36 100644 --- a/onnxscript/onnxscript/__init__.py +++ b/onnxscript/onnxscript/__init__.py @@ -6,6 +6,8 @@ # ElementType: an enumeration encoding the allowed element types in a tensor # Corresponds to TensorProto::DataType + + class ElementType: UNDEFINED = 0 diff --git a/onnxscript/onnxscript/types/__init__.py b/onnxscript/onnxscript/types/__init__.py index cf5f00003d..9116ff5968 100644 --- a/onnxscript/onnxscript/types/__init__.py +++ b/onnxscript/onnxscript/types/__init__.py @@ -1,17 +1,18 @@ import onnx import onnx.helper + class Tensor: # Reference implementation placeholder # represents a generic ONNX tensor type - def __init__(self, dtype = onnx.TensorProto.UNDEFINED, shape = None) -> None: + def __init__(self, dtype=onnx.TensorProto.UNDEFINED, shape=None) -> None: self.dtype = dtype self.shape = shape - + def __str__(self) -> str: shapestr = str(self.shape) if self.shape else "[...]" return onnx.TensorProto.DataType.Name(self.dtype) + shapestr - + def to_type_proto(self): # TODO: handle None return onnx.helper.make_tensor_type_proto(self.dtype, self.shape) @@ -23,11 +24,12 @@ def to_type_proto(self): # x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names) # x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions) + class ParametricTensor: def __init__(self, dtype) -> None: self.dtype = dtype - - def __getitem__ (self, shape): + + def __getitem__(self, shape): def mk_dim(dim): r = onnx.TensorShapeProto.Dimension() if (isinstance(dim, int)): @@ -44,26 +46,27 @@ def mk_dim(dim): else: s = [shape] return Tensor(self.dtype, s) - + def to_type_proto(self): return onnx.helper.make_tensor_type_proto(self.dtype, ()) - + def __str__(self) -> str: return onnx.TensorProto.DataType.Name(self.dtype) -FLOAT = ParametricTensor (onnx.TensorProto.FLOAT) -UINT8 = ParametricTensor (onnx.TensorProto.UINT8) -INT8 = ParametricTensor (onnx.TensorProto.INT8) -UINT16 = ParametricTensor (onnx.TensorProto.UINT16) -INT16 = ParametricTensor (onnx.TensorProto.INT16) -INT32 = ParametricTensor (onnx.TensorProto.INT32) -INT64 = ParametricTensor (onnx.TensorProto.INT64) -STRING = ParametricTensor (onnx.TensorProto.STRING) -BOOL = ParametricTensor (onnx.TensorProto.BOOL) -FLOAT16 = ParametricTensor (onnx.TensorProto.FLOAT16) -DOUBLE = ParametricTensor (onnx.TensorProto.DOUBLE) -UINT32 = ParametricTensor (onnx.TensorProto.UINT32) -UINT64 = ParametricTensor (onnx.TensorProto.UINT64) -COMPLEX64 = ParametricTensor (onnx.TensorProto.COMPLEX64) -COMPLEX128 = ParametricTensor (onnx.TensorProto.COMPLEX128) -BFLOAT16 = ParametricTensor (onnx.TensorProto.BFLOAT16) \ No newline at end of file + +FLOAT = ParametricTensor(onnx.TensorProto.FLOAT) +UINT8 = ParametricTensor(onnx.TensorProto.UINT8) +INT8 = ParametricTensor(onnx.TensorProto.INT8) +UINT16 = ParametricTensor(onnx.TensorProto.UINT16) +INT16 = ParametricTensor(onnx.TensorProto.INT16) +INT32 = ParametricTensor(onnx.TensorProto.INT32) +INT64 = ParametricTensor(onnx.TensorProto.INT64) +STRING = ParametricTensor(onnx.TensorProto.STRING) +BOOL = ParametricTensor(onnx.TensorProto.BOOL) +FLOAT16 = ParametricTensor(onnx.TensorProto.FLOAT16) +DOUBLE = ParametricTensor(onnx.TensorProto.DOUBLE) +UINT32 = ParametricTensor(onnx.TensorProto.UINT32) +UINT64 = ParametricTensor(onnx.TensorProto.UINT64) +COMPLEX64 = ParametricTensor(onnx.TensorProto.COMPLEX64) +COMPLEX128 = ParametricTensor(onnx.TensorProto.COMPLEX128) +BFLOAT16 = ParametricTensor(onnx.TensorProto.BFLOAT16) diff --git a/onnxscript/test/converter_test.py b/onnxscript/test/converter_test.py index 83b758fdd6..c7fd4b6e52 100644 --- a/onnxscript/test/converter_test.py +++ b/onnxscript/test/converter_test.py @@ -5,12 +5,13 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + class TestConverter(unittest.TestCase): - def _convert (self, script): + def _convert(self, script): converter = Converter() converter.convert(script) - def _convert_and_save (self, script): + def _convert_and_save(self, script): converter = Converter() fnlist = converter.convert(script) TEST_OUTPUT_DIR = os.path.join(CURRENT_DIR, "testoutputs") @@ -37,9 +38,10 @@ def foo(x): return msdomain.bar(x, x) """ self._convert(script) + def test_onnxfns(self): self._convert(os.path.join(CURRENT_DIR, "onnxfns.py")) - + def test_models(self): self._convert_and_save(os.path.join(CURRENT_DIR, "onnxmodels.py")) @@ -49,5 +51,6 @@ def test_if_models(self): def test_loop_models(self): self._convert_and_save(os.path.join(CURRENT_DIR, "loop.py")) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()