# --- onnxscript/analysis.py ---
"""AST liveness / def-use analysis helpers for the python-to-ONNX converter.

Only the statement forms the converter supports are handled: Assign, Return,
If and (partially) For.  Anything else raises ValueError.
"""

import ast


def used_vars(expr):
    """Return the set of all variable names read inside expression *expr*."""
    if isinstance(expr, ast.Name):
        return {expr.id}
    result = set()
    for child in ast.iter_child_nodes(expr):
        result |= used_vars(child)
    return result


def local_defs(lhs):
    """Return the set of names assigned by *lhs* (a Name, or a Tuple of Names)."""
    def get_id(e):
        assert isinstance(e, ast.Name), "Only simple assignments supported."
        return e.id
    if isinstance(lhs, ast.Tuple):
        return {get_id(x) for x in lhs.elts}
    return {get_id(lhs)}


def defs(stmt):
    """Return the set of variables that may be assigned when *stmt* executes.

    *stmt* may be a single statement node or a list of statements (a block).
    """
    def block_defs(block):
        result = set()
        for s in block:
            result |= defs(s)
        return result

    if isinstance(stmt, ast.Assign):
        return local_defs(stmt.targets[0])
    if isinstance(stmt, ast.Return):
        return set()
    if isinstance(stmt, ast.If):
        return block_defs(stmt.body) | block_defs(stmt.orelse)
    if isinstance(stmt, ast.For):
        # Generalization: a for-loop defines its target variable plus
        # anything assigned in its body (previously raised ValueError).
        return local_defs(stmt.target) | block_defs(stmt.body)
    if isinstance(stmt, list):
        return block_defs(stmt)
    raise ValueError("Unsupported statement type: " + type(stmt).__name__)


def do_liveness_analysis(fun):
    """Annotate each top-level statement of function-AST *fun* with the
    attributes ``live_in`` and ``live_out`` (sets of variable names)."""
    def visit_block(block, live_out):
        for s in reversed(block):
            live_out = visit(s, live_out)
        return live_out

    def visit(stmt, live_out):
        if isinstance(stmt, ast.Assign):
            return (live_out - local_defs(stmt.targets[0])) | used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            return used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            live1 = visit_block(stmt.body, live_out)
            live2 = visit_block(stmt.orelse, live_out)
            return live1 | live2 | used_vars(stmt.test)
        if isinstance(stmt, ast.For):
            return live_out  # TODO: analyze loop bodies
        raise ValueError("Unsupported statement type: " + type(stmt).__name__)

    assert isinstance(fun, ast.FunctionDef)
    live = set()
    for s in reversed(fun.body):
        s.live_out = live
        live = visit(s, live)
        s.live_in = live


def exposed_uses(stmts):
    """Return the set of variables used before being defined by block *stmts*."""
    def visit_block(block, live_out):
        for stmt in reversed(block):
            live_out = visit(stmt, live_out)
        return live_out

    def visit(stmt, live_out):
        if isinstance(stmt, ast.Assign):
            return (live_out - local_defs(stmt.targets[0])) | used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            return used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            live1 = visit_block(stmt.body, live_out)
            live2 = visit_block(stmt.orelse, live_out)
            return live1 | live2 | used_vars(stmt.test)
        raise ValueError("Unsupported statement type: " + type(stmt).__name__)

    return visit_block(stmts, set())
# --- onnxscript/converter.py (module preamble and Converter core state) ---
"""Python-to-IR converter: diagnostics, constant conversion, the python
operator -> ONNX op table, and the Converter scope/naming machinery."""

import os
import inspect
import ast

import onnxscript
import onnxscript.types as types

from irbuilder import IRBuilder
import analysis
import type_annotation as ta
import values
from values import ConstValue, AttrRef, Dynamic, Op

import onnx
import onnx.helper as helper

print_flag = True


def not_allowed(construct):
    """Format a 'not supported' diagnostic for *construct*.

    BUG FIX: the original concatenated without a space ("Xnot supported.").
    """
    return construct + " not supported."


class TranslationError(Exception):
    """Raised when a python construct cannot be translated to ONNX IR."""


def warn(msg):
    """Print a non-fatal diagnostic and continue."""
    print("Warning: " + msg)


def fail(msg):
    """Abort translation with a TranslationError."""
    raise TranslationError(msg)


def fail_if(cond, msg):
    """Abort translation when *cond* holds."""
    if cond:
        raise TranslationError(msg)


def ignore(cond, msg):
    """Warn (but continue) when *cond* holds."""
    if cond:
        warn(msg)


def pyvalue_to_tensor(tensor_name: str, pyvalue):
    """Convert a python scalar (bool/int/float) to a rank-0 TensorProto.

    bool is tested before int because bool is a subclass of int.
    """
    if isinstance(pyvalue, bool):
        return helper.make_tensor(tensor_name, onnx.TensorProto.BOOL, [], [int(pyvalue)])
    if isinstance(pyvalue, int):
        return helper.make_tensor(tensor_name, onnx.TensorProto.INT64, [], [pyvalue])
    if isinstance(pyvalue, float):
        return helper.make_tensor(tensor_name, onnx.TensorProto.FLOAT, [], [pyvalue])
    # TODO: str, sequences of values
    fail("Unimplemented")


# Map from python AST operator types to ONNX op names.
primop_map = {
    ast.Add: "Add",
    ast.Sub: "Sub",
    ast.Mult: "Mul",
    ast.Div: "Div",
    ast.USub: "Neg",
    ast.Lt: "Less",
    ast.Gt: "Greater",
    ast.LtE: "LessOrEqual",
    ast.GtE: "GreaterOrEqual",
    ast.MatMult: "MatMul",
    ast.Mod: "Mod",
    ast.Pow: "Pow",
}


class Converter:
    """Translates a python function AST into ONNX IR (via an IRBuilder)."""

    def __init__(self, ir_builder=None):
        # BUG FIX: the original default was `ir_builder=IRBuilder()` — a
        # single shared instance evaluated once at class-definition time
        # (mutable default argument).  Callers passing an explicit builder
        # are unaffected.
        self.ir_builder = ir_builder if ir_builder is not None else IRBuilder()
        self.known_modules = {
            'onnxscript': onnxscript,
            'onnxscript.types': types,
            'onnx.opset15': values.opset15,
        }
        self.globals = {
            "int": int, "float": float, "str": str,
            "onnx": values.opset15,
            "Onnx": values.opset15,
            "msdomain": values.msdomain1,
        }
        self.pure_modules = ["onnxscript"]
        self.default_type = types.FLOAT[...]

    def init_function_translation(self):
        """Initialize self for translating a new function."""
        self.outer = []
        self.current_fn = None
        self.nextvar = 0
        self.used_vars = set()
        self.locals = [{}]

    def enter_scope(self, name):
        """Push a fresh function/graph scope named *name*."""
        self.outer.insert(0, self.current_fn)
        self.current_fn = self.ir_builder.new_function(name)
        self.locals.insert(0, {})

    def exit_scope(self):
        """Pop the innermost scope; return the function/graph built in it."""
        graph = self.current_fn
        self.current_fn = self.outer.pop(0)
        self.locals.pop(0)
        return graph

    def current_scope(self):
        return self.locals[0]

    def bind(self, name, val):
        """Bind python-level *name* to *val* in the innermost scope."""
        self.locals[0][name] = val

    def lookup(self, name):
        """Resolve *name* through the scope stack, then the globals.

        Raises ValueError when the name is unbound.
        """
        for scope in self.locals:
            if name in scope:
                return scope[name]
        if name in self.globals:
            return self.globals[name]
        raise ValueError("Unbound name: " + name)

    def generate_unique_name(self, candidate="tmp"):
        """Return a fresh ONNX variable name derived from *candidate*."""
        r = candidate
        while r in self.used_vars:
            r = candidate + "_" + str(self.nextvar)
            self.nextvar += 1
        self.used_vars.add(r)
        return r

    def to_onnx_attr_ref(self, val: AttrRef):
        """Build an attribute-reference for attribute-parameter *val*."""
        pytype = val.type
        if pytype is float:
            attrname = "value_float"
        elif pytype is int:
            attrname = "value_int"
        else:
            attrname = "value_string"
        return self.ir_builder.attr_ref(attrname, val.value, pytype)

    def to_onnx_var(self, val, target=None):
        """Convert *val* (AttrRef/ConstValue/Dynamic) to an ONNX variable name."""
        if isinstance(val, AttrRef):
            # Promote the attribute to a value via a Constant node.
            result = self.generate_unique_name(target if target else "tmp")
            attr = self.to_onnx_attr_ref(val)
            self.emit([result], Op("", "Constant"), [], [attr])
            return result
        if isinstance(val, ConstValue) and isinstance(val.value, float):  # TODO
            result = self.generate_unique_name(target if target else "tmp")
            return self.emit_const(val.value, result)
        if isinstance(val, Dynamic):
            return val.value
        # was a placeholder f-string with no interpolation
        fail("Cannot convert to onnx variable")
py_var_to_onnx_var(self, py_var): return self.to_onnx_var(self.lookup(py_var)) + + def emit(self, outputs, callee, inputs, attrs): + self.ir_builder.add_stmt(self.current_fn, outputs, callee.opset, callee.opname, inputs, attrs) + + def emit2(self, outputs, callee, inputs, attrs): + def rename(x): + r = self.generate_unique_name(x) + self.bind(x, Dynamic(r)) + return r + onnx_inputs = inputs # [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ] + onnx_outputs = [rename(x) for x in outputs] + self.emit(onnx_outputs, Op("", callee), onnx_inputs, attrs) + + def emit_const(self, pyvalue, suggested_name): + ovar = self.generate_unique_name(suggested_name) + tensor = pyvalue_to_tensor(ovar, pyvalue) + attr = self.ir_builder.attr("value", tensor) + self.emit([ovar], Op("", "Constant"), [], [attr]) + return ovar + + def is_pure_module(self, m): + return (m in self.pure_modules) + + def is_constant_expr(self, node): + if (isinstance(node, ast.Name)): + val = self.lookup(node.id) + return isinstance(val, ConstValue) and self.is_pure_module(val.value) + if isinstance(node, (ast.Call, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Num, ast.Str, ast.Attribute)): + return all([self.is_constant_expr(c) for c in ast.iter_child_nodes(node)]) + return False + + def eval_constant_expr(self, node): + # TODO: assert (self.is_constant_expr(node)) + locals = {} # TODO + return eval(compile(ast.Expression(node), filename="", mode="eval"), self.globals, locals) + + def eval_attr(self, node): + if (isinstance(node, ast.Num)): + return node.n + elif (isinstance(node, ast.Str)): + return node.s + elif (isinstance(node, ast.NameConstant)): + if (node.value == True): + return 1 + elif (node.value == False): + return 0 + else: + raise ValueError("Unsupported NameConstant attribute : " + str(node.value)) + elif (isinstance(node, ast.List)): + return [self.eval_attr(x) for x in node.elts] + elif (isinstance(node, (ast.Call, ast.Attribute))): + return self.eval_constant_expr(node) + else: + raise 
ValueError("Unsupported attribute type: " + type(node).__name__) + + def translate_attr(self, attr_name, node): + if (isinstance(node, ast.Name)): + val = self.lookup(node.id) + if (isinstance(val, AttrRef)): + return self.to_onnx_attr_ref(val) + else: + # TODO: lookup value; if func.def., compile it to Graph; if a constant; etc. + fail("Unimplemented attribute construct") + return self.ir_builder.attr(attr_name, self.eval_attr(node)) + + # Expression-translation generates "IR statements/nodes" that compute the value of + # the expression into a target-variable, and returns the variable that is assigned this value. + def translate_expr(self, node, target="tmp"): + if (isinstance(node, ast.Call)): + r = self.translate_call_expr(node) + elif (isinstance(node, ast.BinOp)): + r = self.translate_bin_op_expr(node) + elif (isinstance(node, ast.UnaryOp)): + r = self.translate_unary_op_expr(node) + elif (isinstance(node, ast.Compare)): + r = self.translate_compare_expr(node) + elif (isinstance(node, ast.Name)): + r = self.translate_name_expr(node) + elif (isinstance(node, ast.Num)): + r = self.emit_const(node.n, target) + elif isinstance(node, ast.NameConstant): + r = self.emit_const(node.value, target) + else: + raise ValueError("Unsupported expression type: " + type(node).__name__) + if (isinstance(r, tuple)): + if isinstance(target, str): + result = self.generate_unique_name(target) + callee, args, attrs = r + self.emit([result], callee, args, attrs) + return result + else: + assert isinstance(target, list) + results = [self.generate_unique_name(x) for x in target] + callee, args, attrs = r + self.emit(results, callee, args, attrs) + return results + return r + + def translate_call_expr(self, node): + # TODO: for now, we map named arguments to attributes, and positional arguments to inputs. 
+ callee = self.translate_callee_expr(node.func) + args = [self.translate_expr(x) for x in node.args] + attrs = [self.translate_attr(x.arg, x.value) for x in node.keywords] + return (callee, args, attrs) + + def translate_bin_op_expr(self, node): + op = type(node.op) + assert (op in primop_map) + opname = primop_map[op] + left = self.translate_expr(node.left) + right = self.translate_expr(node.right) + return (Op("", opname), [left, right], []) + + def translate_unary_op_expr(self, node): + op = type(node.op) + assert (op in primop_map) + opname = primop_map[op] + operand = self.translate_expr(node.operand) + return (Op("", opname), [operand], []) + + def translate_compare_expr(self, node): + assert (len(node.ops) == 1) # TODO: handle multiple comparisons in one expression + assert (len(node.comparators) == 1) + op = type(node.ops[0]) + assert (op in primop_map) + opname = primop_map[op] + left = self.translate_expr(node.left) + right = self.translate_expr(node.comparators[0]) + return (Op("", opname), [left, right], []) + + def translate_name_expr(self, node): + return self.py_var_to_onnx_var(node.id) + + def translate_opset_expr(self, node) -> values.Opset: + """Return an Opset""" + if isinstance(node, ast.Name): + try: + val = self.lookup(node.id) + if isinstance(val, ConstValue): # TODO + val = val.value + if isinstance(val, values.Opset): + return val + fail(f"{node.id} has value of type {type(node.id)} and used as opset.") + except: + warn(f"Unknown opset name {node.id}.") + return values.Opset(node.id, 1) + elif isinstance(node, ast.Attribute): + fail("Nested module unimplemented") # TODO + else: + fail("Invalid opset expression.") + + def translate_callee_expr(self, node) -> values.Op: + """Return an Op""" + if isinstance(node, ast.Attribute): + module = self.translate_opset_expr(node.value) + opname = node.attr + if (opname not in module): + warn(f"{opname} is not a known op in {str(module)}") + return Op(module, node.attr) + elif isinstance(node, 
ast.Name): + try: + val = self.lookup(node.id) + except: + default_opset = values.opset15 + if (node.id not in default_opset): + warn(f"Unknown function name {node.id}.") + return Op(default_opset, node.id) + else: + fail("Invalid callee") + + # Statement translation: A single Python statement is mapped into a sequence of IR statements. + + def translate_stmt(self, node): + if (isinstance(node, ast.Assign)): + self.translate_assign_stmt(node) + elif (isinstance(node, ast.Return)): + self.translate_return_stmt(node) + elif (isinstance(node, ast.If)): + self.translate_if_stmt(node) + elif isinstance(node, ast.For): + self.translate_for_stmt(node) + else: + raise ValueError("Unsupported statement type: " + type(node).__name__) + + def translate_assign_stmt(self, stmt: ast.Assign): + def assign(lhs, rhs): + if (isinstance(lhs, ast.Name)): + lhs = lhs.id + if (self.is_constant_expr(rhs)): + self.bind(lhs, ConstValue(self.eval_constant_expr(rhs))) + else: + t = self.translate_expr(rhs, lhs) + self.bind(lhs, Dynamic(t)) + elif isinstance(lhs, ast.Tuple): + def id(x): + assert isinstance(x, ast.Name) + return x.id + ids = [id(x) for x in lhs.elts] + onnxids = self.translate_expr(rhs, ids) + for x, y in zip(ids, onnxids): + self.bind(x, Dynamic(y)) + else: + fail("Unsupported construct in LHS of assignment.") + assert (len(stmt.targets) == 1), "Multi-assignment not supported." + lhs = stmt.targets[0] + rhs = stmt.value + if (isinstance(rhs, ast.Tuple)): + assert isinstance(lhs, ast.Tuple) + assert len(lhs.elts) == len(rhs.elts), "Expected same number of elements on lhs and rhs of assignments." 
+ for l, r in zip(lhs.elts, rhs.elts): + assign(l, r) + else: + assign(lhs, rhs) + + def translate_return_stmt(self, stmt: ast.Return): + def ret(exp, suffix=""): + ovar = self.translate_expr(exp, "return_val" + suffix) + # if hasattr(self, returntype) and self.num_outputs < len(self.returntype): + try: + t = self.returntype[self.num_outputs] + except Exception as e: + t = self.default_type + self.ir_builder.add_output(self.current_fn, ovar, t) + self.num_outputs += 1 + return ovar + + val = stmt.value + assert (val != None), "Return statement without return-value not supported." + if (isinstance(val, ast.Tuple)): + return [ret(exp, str(i)) for i, exp in enumerate(val.elts)] + else: + return ret(val) + + def translate_if_stmt(self, stmt: ast.If): + live_defs = list(stmt.live_out.intersection(analysis.defs(stmt))) + # print(live_defs) + test = self.translate_expr(stmt.test, "cond") + thenGraph = self.translate_block(stmt.body, "thenGraph", live_defs) + thenAttr = self.ir_builder.attr("then_branch", thenGraph) + elseGraph = self.translate_block(stmt.orelse, "elseGraph", live_defs) + elseAttr = self.ir_builder.attr("else_branch", elseGraph) + + def rename(x): + r = self.generate_unique_name(x) + self.bind(x, Dynamic(r)) + return r + renamed = [rename(x) for x in live_defs] + self.emit(renamed, Op("", "If"), [test], [thenAttr, elseAttr]) + + def translate_for_stmt(self, for_stmt: ast.For): + # loop-variable + assert isinstance(for_stmt.target, ast.Name), "For loop target must be a single variable." + p_loop_var = for_stmt.target.id + # iter + iter = for_stmt.iter + assert isinstance(iter, ast.Call), "Loop bound not a call." + assert isinstance(iter.func, ast.Name), "Unsupported loop bound." + assert iter.func.id == "range", "Unsupported loop bound." + assert iter.args and len(iter.args) == 1, "Unsupported loop bound." + assert not iter.keywords, "Unsupported loop bound." 
+ o_loop_bound = self.translate_expr(iter.args[0], "loop_bound") + # analyze loop body + exposed_uses = analysis.exposed_uses(for_stmt.body) + vars_def_in_loop = analysis.defs(for_stmt.body) + loop_state_vars = vars_def_in_loop.intersection(exposed_uses | for_stmt.live_out) + scan_outputs = set() # TODO + outputs = list(loop_state_vars | scan_outputs) + + # loop-condition: + o_true = self.emit_const(True, "true") + # o_loop_bound = self.emit_const(3, "loop_bound") + + # build loop_body + self.enter_scope("loop_body") + o_loop_var = self.generate_unique_name(p_loop_var) + self.ir_builder.add_input(self.current_fn, o_loop_var, types.INT64) + self.bind(p_loop_var, Dynamic(o_loop_var)) + o_cond_var = self.generate_unique_name("cond_in") + self.ir_builder.add_input(self.current_fn, o_cond_var, types.BOOL) + for pv in loop_state_vars: + ov = self.generate_unique_name(pv) + self.ir_builder.add_input(self.current_fn, ov, self.default_type) + self.bind(pv, Dynamic(ov)) + for s in for_stmt.body: + self.translate_stmt(s) + o_cond_out = self.generate_unique_name("cond_out") + self.emit([o_cond_out], Op("", "Identity"), [o_cond_var], []) + self.ir_builder.add_output(self.current_fn, o_cond_out, types.BOOL) + for pv in loop_state_vars: + ov = self.py_var_to_onnx_var(pv) + self.ir_builder.add_output(self.current_fn, ov, self.default_type) # TODO: type + body = self.exit_scope() + + inputs = [o_loop_bound, o_true] + [self.py_var_to_onnx_var(pv) for pv in loop_state_vars] + attrs = [self.ir_builder.attr("body", body.to_graph_proto())] + self.emit2(outputs, "Loop", inputs, attrs) + + # Translation of a statement-block to GraphProto attribute + def translate_block(self, stmts, name, live_defs): + self.enter_scope(name) + for s in stmts: + self.translate_stmt(s) + for pvar in live_defs: + if (pvar in self.current_scope()): + pv_val = self.current_scope()[pvar] + output = self.to_onnx_var(pv_val, pvar) + self.ir_builder.add_output(self.current_fn, output, self.default_type) # TODO: 
need type! + else: + pv_val = None + for scope in self.locals: # TODO: skip current_scope + if (pvar in scope): + pv_val = scope[pvar] + break + if (pv_val is None): + fail(f"Variable {pvar} is not assigned a value along a conditional branch.") + # introduce a copy + ovar = self.generate_unique_name(pvar) + self.emit([ovar], Op("", "Identity"), [self.to_onnx_var(pv_val, pvar)], []) + self.ir_builder.add_output(self.current_fn, ovar, self.default_type) # TODO: need type! + graph = self.exit_scope() + return graph.to_graph_proto() + + def translate_function_def(self, fn: ast.FunctionDef): + args = fn.args + if (args.defaults): + warn(f"{fn.name}: Default values not yet implemented.") + if (args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg): + warn(f"{fn.name}: Unsupported feature in function signature.") + self.current_fn = self.ir_builder.new_function(fn.name) + for x in args.args: + if x.annotation: + typeinfo = self.eval_constant_expr(x.annotation) + else: + typeinfo = self.default_type + assert ta.is_valid(typeinfo) + if (ta.is_attr(typeinfo)): + self.ir_builder.add_attr(self.current_fn, x.arg, typeinfo) + self.bind(x.arg, AttrRef(x.arg, typeinfo)) + else: + self.ir_builder.add_input(self.current_fn, x.arg, typeinfo) + self.bind(x.arg, Dynamic(x.arg)) + if fn.returns: + returntype = self.eval_constant_expr(fn.returns) + if isinstance(returntype, tuple): + assert all([ta.is_valid(t) for t in returntype]) + self.returntype = returntype + else: + assert ta.is_valid(returntype) + self.returntype = (returntype,) + else: + self.returntype = None + self.num_outputs = 0 + for s in fn.body: + self.translate_stmt(s) + if self.returntype is not None: + assert (self.num_outputs == len(self.returntype)), "Mismatch in number of return values and types" + return self.current_fn + + def do_import(self, alias): + if print_flag: + print(f"Importing {alias.name} as {alias.asname}") + fail_if(alias.name not in self.known_modules, f"Import: unsupported module 
{alias.name}") + asname = alias.asname if alias.asname else alias.name + self.globals[asname] = self.known_modules[alias.name] + + def top_level_stmt(self, stmt): + if isinstance(stmt, ast.FunctionDef): + self.init_function_translation() + analysis.do_liveness_analysis(stmt) + fn_ir = self.translate_function_def(stmt) + if print_flag: + print("============== OUTPUT =============") + fn_ir.print() + print() + return fn_ir + elif isinstance(stmt, ast.Import): + for alias in stmt.names: + self.do_import(alias) + elif isinstance(stmt, ast.ImportFrom): + fail_if(stmt.module is None, "Import: module unspecified.") + fail_if(stmt.module not in self.known_modules, f"Import: unsupported module {stmt.module}") + module = self.known_modules[stmt.module] + for alias in stmt.names: + asname = alias.asname if alias.asname else alias.name + self.globals[asname] = getattr(module, alias.name) + else: + raise ValueError("Unsupported top-level statement type: " + type(stmt).__name__) + + def convert_source(self, src): + module = ast.parse(src) + assert type(module) == ast.Module + converted = [self.top_level_stmt(d) for d in module.body] + return [x for x in converted if x is not None] + + def convert_file(self, filename): + src = open(filename).read() + return self.convert_source(src) + + def convert(self, f): + if (isinstance(f, str)): + if (os.path.exists(f)): + return self.convert_file(f) + else: + return self.convert_source(f) + elif (inspect.isfunction(f)): + src = inspect.getsource(f) + return self.convert_source(src) + else: + fail("Unknown type of input to converter.") + + +def convert(script): + converter = Converter() + converter.convert(script) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py new file mode 100644 index 0000000000..9d94d70554 --- /dev/null +++ b/onnxscript/irbuilder.py @@ -0,0 +1,152 @@ +import onnx +import onnx.helper as helper +import type_annotation as ta + +# A simple IR (Function, Stmt, Attr, Var): + + +def format(list, prefix, sep, 
# --- onnxscript/irbuilder.py ---
"""A simple IR (Function, Stmt, Attr, Var) and the IRBuilder facade used by
the python-to-IR converter."""

import onnx
import onnx.helper as helper
import type_annotation as ta


def format(seq, prefix, sep, suffix):
    """Join *seq* items with *sep*, wrapped in *prefix*/*suffix*.

    NOTE(review): shadows the builtin ``format``; name kept for compatibility
    with existing callers.
    """
    return prefix + sep.join(str(x) for x in seq) + suffix


class Type:
    """Placeholder type: currently always an untyped FLOAT tensor.  TODO."""

    def __init__(self) -> None:
        tp = onnx.TypeProto()
        tp.tensor_type.elem_type = onnx.TensorProto.FLOAT
        self.onnx_type = tp

    def to_type_proto(self):
        return self.onnx_type

    def __str__(self) -> str:
        return "SomeType"


class Var:
    """A named IR variable with an optional type."""

    def __init__(self, varname, type=None) -> None:
        self.name = varname
        self.type = type

    def __str__(self):
        return self.name

    def typed_str(self):
        return self.name + " : " + str(self.type)

    def to_value_info(self):
        tp = self.type.to_type_proto()
        return helper.make_value_info(self.name, tp)


class Attr:
    """Wraps an AttributeProto (a concrete value or an attribute-reference)."""

    def __init__(self, attrproto) -> None:
        self.attr_proto = attrproto

    def __str__(self):
        if self.attr_proto.HasField("ref_attr_name"):
            return self.attr_proto.name + " = @" + self.attr_proto.ref_attr_name
        return helper.printable_attribute(self.attr_proto)


class Stmt:
    """One IR statement: results = module.opname<attrs>(args)."""

    def __init__(self, result, module, opname, args, attrs) -> None:
        self.result = result
        self.module = module
        self.opname = opname
        self.args = args
        self.attrs = attrs

    def __str__(self):
        # BUG FIX: when result was a bare string the original printed 'Ooops'
        # (a side effect inside __str__) and then join()ed it character by
        # character; a string result is now rendered directly.
        lhs = self.result if isinstance(self.result, str) else ", ".join(self.result)
        attrs = format(self.attrs, "<", ", ", ">") if self.attrs else ""
        args = format(self.args, "(", ", ", ")")
        module = str(self.module)
        callee = module + "." + self.opname if module != '' else self.opname
        return lhs + " = " + callee + " " + attrs + args

    def print(self):
        print(str(self))

    def to_node_proto(self):
        n = helper.make_node(self.opname,
                             [str(x) for x in self.args],
                             [str(x) for x in self.result])
        for a in self.attrs:
            n.attribute.append(a.attr_proto)
        return n


class Function:
    """An IR function: inputs/outputs/attrs plus a statement list."""

    def __init__(self, name) -> None:
        self.name = name
        self.inputs = []
        self.outputs = []
        self.stmts = []
        self.attrs = []

    def __str__(self):
        attrs = format(self.attrs, "<", ", ", ">") if self.attrs else ""
        inputs = format([x.typed_str() for x in self.inputs], "(", ", ", ")")
        outputs = format([x.typed_str() for x in self.outputs], "(", ", ", ")")
        stmts = format(self.stmts, "\n{\n ", "\n ", "\n}\n")
        return self.name + " " + attrs + inputs + " => " + outputs + stmts

    def print(self):
        print(str(self))
        # Also print any nested subgraphs (If/Loop bodies).
        for s in self.stmts:
            for attr in s.attrs:
                if attr.attr_proto.HasField("g"):
                    print(helper.printable_graph(attr.attr_proto.g))

    def to_graph_proto(self):
        return helper.make_graph([s.to_node_proto() for s in self.stmts],
                                 self.name,
                                 [x.to_value_info() for x in self.inputs],
                                 [y.to_value_info() for y in self.outputs])


class IRBuilder:
    """Abstracts IR construction details away from the converter."""

    def new_function(self, name):
        return Function(name)

    def add_stmt(self, fn, results, module, opname, args, attrs):
        fn.stmts.append(Stmt(results, module, opname, args, attrs))

    def add_input(self, fn, varname, type):
        fn.inputs.append(Var(varname, type))

    def add_attr(self, fn, varname, type):
        fn.attrs.append(Var(varname, type))

    def add_output(self, fn, varname, type):
        fn.outputs.append(Var(varname, type))

    def attr(self, attrname, attrval):
        if isinstance(attrval, Function):
            attrval = str(attrval)  # TODO
        return Attr(helper.make_attribute(attrname, attrval))

    def attr_ref(self, attrname, refname, pytype):
        """Build an attribute that references an enclosing function attribute."""
        a = onnx.AttributeProto()
        a.name = attrname
        a.ref_attr_name = refname
        a.type = ta.pytype_to_attrtype_map[pytype]
        return Attr(a)
    # TODO: attr_type?


# --- onnxscript/onnxscript/__init__.py (falls inside this diff range) ---

class ElementType:
    """Allowed tensor element types; mirrors TensorProto::DataType."""
    UNDEFINED = 0

    FLOAT = 1     # float
    UINT8 = 2     # uint8_t
    INT8 = 3      # int8_t
    UINT16 = 4    # uint16_t
    INT16 = 5     # int16_t
    INT32 = 6     # int32_t
    INT64 = 7     # int64_t
    STRING = 8    # string
    BOOL = 9      # bool

    # IEEE754 half-precision float (1 sign, 5 exponent, 10 mantissa bits).
    FLOAT16 = 10

    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14   # complex with float32 real and imaginary components
    COMPLEX128 = 15  # complex with float64 real and imaginary components

    # bfloat16: IEEE754 single truncated to 16 bits (1/8/7 bits).
    BFLOAT16 = 16
# --- onnxscript/onnxscript/types/__init__.py ---
"""Parametrized tensor-type annotations for onnxscript signatures.

Example annotations:
    x : FLOAT             (a scalar-tensor of rank 0)
    x : FLOAT[...]        (a tensor of unknown rank)
    x : FLOAT['M', 'N']   (rank 2, symbolic dimension names)
    x : FLOAT[128, 1024]  (rank 2, known dimensions)
"""

import onnx
import onnx.helper


class Tensor:
    """A concrete ONNX tensor type: element dtype plus optional shape.

    shape is None for "unknown rank", otherwise a sequence whose entries are
    int (known dim) or str (symbolic dim name).
    """

    def __init__(self, dtype=onnx.TensorProto.UNDEFINED, shape=None) -> None:
        self.dtype = dtype
        self.shape = shape

    def __str__(self) -> str:
        shapestr = str(self.shape) if self.shape else "[...]"
        return onnx.TensorProto.DataType.Name(self.dtype) + shapestr

    def to_type_proto(self):
        # TODO: handle shape=None (unknown rank)
        return onnx.helper.make_tensor_type_proto(self.dtype, self.shape)


class ParametricTensor:
    """Type-annotation factory: subscripting yields a concrete Tensor."""

    def __init__(self, dtype) -> None:
        self.dtype = dtype

    def __getitem__(self, shape):
        # DEAD-CODE FIX: an inner mk_dim() helper was defined here but never
        # called; dimension entries are passed through to Tensor unchanged.
        if isinstance(shape, tuple):
            dims = shape
        elif shape is Ellipsis:
            dims = None  # FLOAT[...]: unknown rank
        else:
            dims = [shape]
        return Tensor(self.dtype, dims)

    def to_type_proto(self):
        # A bare dtype annotation denotes a rank-0 tensor (a scalar).
        return onnx.helper.make_tensor_type_proto(self.dtype, ())

    def __str__(self) -> str:
        return onnx.TensorProto.DataType.Name(self.dtype)


FLOAT = ParametricTensor(onnx.TensorProto.FLOAT)
UINT8 = ParametricTensor(onnx.TensorProto.UINT8)
INT8 = ParametricTensor(onnx.TensorProto.INT8)
UINT16 = ParametricTensor(onnx.TensorProto.UINT16)
INT16 = ParametricTensor(onnx.TensorProto.INT16)
INT32 = ParametricTensor(onnx.TensorProto.INT32)
INT64 = ParametricTensor(onnx.TensorProto.INT64)
STRING = ParametricTensor(onnx.TensorProto.STRING)
BOOL = ParametricTensor(onnx.TensorProto.BOOL)
FLOAT16 = ParametricTensor(onnx.TensorProto.FLOAT16)
DOUBLE = ParametricTensor(onnx.TensorProto.DOUBLE)
UINT32 = ParametricTensor(onnx.TensorProto.UINT32)
UINT64 = ParametricTensor(onnx.TensorProto.UINT64)
COMPLEX64 = ParametricTensor(onnx.TensorProto.COMPLEX64)
COMPLEX128 = ParametricTensor(onnx.TensorProto.COMPLEX128)
BFLOAT16 = ParametricTensor(onnx.TensorProto.BFLOAT16)
opset_imports=[onnx.helper.make_opsetid("", 15)]) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model) + onnx.save(model, os.path.join(TEST_OUTPUT_DIR, f.name + ".onnx")) + + def test_source_input(self): + script = """ +def square(x): + return onnx.Mul(x, x) +""" + self._convert(script) + + def test_msdomain(self): + # Temporary patch to use com.microsoft domain + script = """ +def foo(x): + return msdomain.bar(x, x) +""" + self._convert(script) + + def test_onnxfns(self): + self._convert(os.path.join(CURRENT_DIR, "onnxfns.py")) + + def test_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "onnxmodels.py")) + + def test_if_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "if.py")) + + def test_loop_models(self): + self._convert_and_save(os.path.join(CURRENT_DIR, "loop.py")) + + +if __name__ == '__main__': + unittest.main() diff --git a/onnxscript/test/dropout.py b/onnxscript/test/dropout.py new file mode 100644 index 0000000000..76135d31fa --- /dev/null +++ b/onnxscript/test/dropout.py @@ -0,0 +1,10 @@ + +def Dropout(data, ratio, training_mode, seed : int): + if (training_mode): + mask = onnx.ConstantOfShape(onnx.Shape(data), value = True) + output = data + else: + rand = onnx.RandomUniformLike(data, dtype=1, seed=seed) + mask = (rand >= ratio) + output = onnx.Where(mask, data, 0) / (1.0 - ratio) + return (output, mask) \ No newline at end of file diff --git a/onnxscript/test/eg1.py b/onnxscript/test/eg1.py new file mode 100644 index 0000000000..eac4258048 --- /dev/null +++ b/onnxscript/test/eg1.py @@ -0,0 +1,27 @@ + +from onnxscript.types import FLOAT + +# tensor inputs can have ONNX-like type annotations +def gemm (A : FLOAT[2048,124], W : FLOAT[124,4096], Bias : FLOAT[4096]) -> FLOAT[2048,4096] : + return onnx.MatMul(A, W) + Bias + +# tensors and attributes distinguished by their types +def scale (A : FLOAT[...], alpha: float, beta: float) -> FLOAT[...] 
: + return alpha * A + beta + +# can return multiple-values +def prodsum (A: FLOAT['N'], B: FLOAT['N']) -> (FLOAT['N'], FLOAT['N']) : + prod = A * B + sum = A + B + return prod, sum + +# can call ops/functions that return multiple-values +def dropout_eg(A: FLOAT[...]) -> FLOAT[...]: + output, mask = onnx.Dropout(A, 0.7, True, seed=1729) + return output + +# will rename variable assigned multiple times +def renaming(A : FLOAT["N"]) -> FLOAT["N"] : + T = onnx.Abs(A) + T = onnx.Neg(T) + return T \ No newline at end of file diff --git a/onnxscript/test/if.py b/onnxscript/test/if.py new file mode 100644 index 0000000000..902e8df85a --- /dev/null +++ b/onnxscript/test/if.py @@ -0,0 +1,31 @@ +from onnxscript.types import FLOAT + +def maxsum (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + if (sum1 < sum2): + result = onnx.Identity(B) + else: + result = onnx.Identity(A) + return result + +# Test inference of inputs/outputs for then/else blocks: +def maxsum (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + if (sum1 < sum2): + temp = onnx.Identity(B) + result = onnx.Identity(temp) + else: + temp = onnx.Identity(A) + result = onnx.Identity(temp) + return result + +# test variables assigned only in one branch +def maxsum2 (A : FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"] : + sum1 = onnx.ReduceSum(A) + sum2 = onnx.ReduceSum(B) + result = onnx.Identity(A) + if (sum1 < sum2): + result = onnx.Identity(B) + return result diff --git a/onnxscript/test/loop.py b/onnxscript/test/loop.py new file mode 100644 index 0000000000..89a8e6f512 --- /dev/null +++ b/onnxscript/test/loop.py @@ -0,0 +1,19 @@ +from onnxscript.types import FLOAT, INT64 + +# simple loop +def sumprod (x : FLOAT['N'], N : INT64) -> (FLOAT['N'], FLOAT['N']) : + sum = onnx.Identity(x) + prod = onnx.Identity(x) + for i in range(N): + sum = sum + x + prod = prod * x + return sum, prod + +# loop: loop-bound as an 
# (continuation of onnxscript/test/loop.py; the comment above completes as
# "loop: loop-bound as an expression")
def sumprod2(x: FLOAT['N'], N: INT64) -> (FLOAT['N'], FLOAT['N']):
    sum = onnx.Identity(x)
    prod = onnx.Identity(x)
    for i in range(2 * N + 1):
        sum = sum + x
        prod = prod * x
    return sum, prod

# ===== onnxscript/test/multi.py =====

from onnxscript.types import FLOAT


def multi(A: FLOAT["N"]) -> FLOAT["N"]:
    x, y = onnx.Foo(A)
    x, y = onnx.Bar(A)
    return x + y

# ===== onnxscript/test/onnxfns.py =====
# Script versions of standard ONNX functions; the free name `onnx` is
# resolved by the converter.

def Relu(X):
    return onnx.Max(X, 0)


# TODO: default-values of attributes
def Selu(X, alpha: float = 1.67326319217681884765625,
         gamma: float = 1.05070102214813232421875):
    neg = gamma * (alpha * onnx.Exp(X) - alpha)
    pos = gamma * X
    return onnx.Where(X <= 0, neg, pos)


def Elu(X, alpha: float = 1.0):
    return onnx.Where(X < 0, alpha * (onnx.Exp(X) - 1.0), X)


def ThresholdedRelu(X, alpha: float = 1.0):
    return onnx.Where(X > alpha, X, 0)


def LeakyRelu(X, alpha: float = 0.01):
    return onnx.Where(X < 0, alpha * X, X)


def PRelu(X, slope):
    # future-work: capturing extra requirements such as:
    # slope must be unidirectionally broadcastable to X's shape.
    return onnx.Where(X < 0, slope * X, X)


def HardSigmoid(X, alpha: float = 0.2, beta: float = 0.5):
    return onnx.Max(0, onnx.Min(1, alpha * X + beta))


def Shrink(x, bias: float = 0.0, lambd: float = 0.5):
    # (fix) the nested call was spelled `Onnx.Where`; the op namespace is
    # consistently lowercase `onnx` everywhere else, so the capitalized form
    # would never resolve.
    return onnx.Where(x < -lambd, x + bias, onnx.Where(x > lambd, x - bias, 0))


def Softplus(X):
    return onnx.Log(onnx.Exp(X) + 1)


def Softsign(X):
    return X / (1 + onnx.Abs(X))


def Clip(input, min, max):
    # TODO: default values specified for min/max
    return onnx.Where(input < min, min, onnx.Where(input > max, max, input))

# ===== onnxscript/test/onnxmodels.py =====
from onnxscript.types import FLOAT

# type annotation tests:

def ta_test1(A: FLOAT[100], B: FLOAT[100]) -> FLOAT[100]:
    return A + B


def ta_test2(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:
    return A + B


# feature: constant promotion

def plus1(A: FLOAT["N"]) -> FLOAT["N"]:
    return A + 1.0


# example models

def eg1(A: FLOAT[100, 200], W: FLOAT[200, 96], B: FLOAT[96]) -> FLOAT[100, 96]:
    return onnx.MatMul(A, W) + B


def gemmgelu(A: FLOAT[2048, 16],
             W: FLOAT[16, 4096],
             Bias: FLOAT[4096]) -> FLOAT[2048, 4096]:
    # GELU(tanh approximation) applied to A @ W + Bias, spelled out op-by-op.
    a = onnx.Constant(value_float=0.5)
    b = onnx.Constant(value_float=0.797885)
    c = onnx.Constant(value_float=0.035677)
    one = onnx.Constant(value_float=1.0)
    P1 = onnx.MatMul(A, W)
    X = onnx.Add(P1, Bias)
    T1 = onnx.Mul(X, X)
    T2 = onnx.Mul(c, T1)
    T3 = onnx.Add(b, T2)
    T4 = onnx.Mul(X, T3)
    T5 = onnx.Tanh(T4)
    T6 = onnx.Add(one, T5)
    T7 = onnx.Mul(X, T6)
    Y = onnx.Mul(a, T7)
    return Y


def gemmgelu2(A: FLOAT[2048, 16],
              W: FLOAT[16, 4096],
              Bias: FLOAT[4096]) -> FLOAT[2048, 4096]:
    # Same computation using overloaded operators on explicit Constants.
    a = onnx.Constant(value_float=0.5)
    b = onnx.Constant(value_float=0.797885)
    c = onnx.Constant(value_float=0.035677)
    one = onnx.Constant(value_float=1.0)
    P1 = onnx.MatMul(A, W)
    X = onnx.MatMul(A, W) + Bias
    T4 = X * (b + c * X * X)
    Y = a * X * (onnx.Tanh(T4) + one)
    return Y


def gemmgelu3(A: FLOAT[2048, 16],
              W: FLOAT[16, 4096],
              Bias: FLOAT[4096]) -> FLOAT[2048, 4096]:
    # Same computation relying on constant promotion of python literals.
    a = 0.5
    b = 0.797885
    c = 0.035677
    P1 = onnx.MatMul(A, W)
    X = onnx.MatMul(A, W) + Bias
    T4 = X * (b + c * X * X)
    Y = a * X * (onnx.Tanh(T4) + 1.0)
    return Y

# ===== onnxscript/test/renaming.py =====
from onnxscript.types import FLOAT


# same variable assigned multiple times
def renaming(A: FLOAT["N"]) -> FLOAT["N"]:
    T = onnx.Abs(A)
    T = onnx.Neg(A)
    return T


# clash between generated-name and pre-existing name
def renaming2(A: FLOAT["N"]) -> FLOAT["N"]:
    T_0 = onnx.Relu(A)
    T = onnx.Abs(A)
    T = onnx.Neg(A)
    return T

# ===== onnxscript/type_annotation.py =====

import onnx
import onnxscript

# Maps a python annotation type to the corresponding ONNX attribute type.
pytype_to_attrtype_map = {
    float: onnx.AttributeProto.FLOAT,
    int: onnx.AttributeProto.INT,
    str: onnx.AttributeProto.STRING,
}


def is_attr(typeinfo):
    """True if the annotation denotes an attribute-kind parameter."""
    return typeinfo in {float, int, str}


def is_tensor(typeinfo):
    """True if the annotation denotes a tensor type (duck-typed on the
    to_type_proto method, so both Tensor and ParametricTensor qualify)."""
    return hasattr(typeinfo, "to_type_proto")


def is_valid(typeinfo):
    """True if the annotation is usable as either an attribute or a tensor."""
    return is_attr(typeinfo) or is_tensor(typeinfo)

# ===== onnxscript/utils.py =====
from typing import Any
import onnxruntime as ort

# TODO: enable invocation of ORT kernels
class Model:
    """Thin wrapper that runs an ONNX model through an onnxruntime session."""

    def __init__(self, onnxfile) -> None:
        self.session = ort.InferenceSession(onnxfile)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        # Positional args are matched to the session's inputs by position;
        # keyword args are passed through by input name.
        inputs = self.session.get_inputs()
        for i, arg in enumerate(args):
            kwds[inputs[i].name] = arg
        return self.session.run(None, kwds)


def make_ort_value(v):
    # TODO: convert a python value into an ORT-compatible value.
    pass


def make_attrs(**kwds):
    # TODO: convert keyword arguments into ONNX attributes.
    pass


def call_ort(opname, *args, **kwds):
    # TODO stub: build/run a single-op model through ORT.
    model = Model("todo")
    ort_args = [make_ort_value(x) for x in args]
    # (fix) Model.__call__ takes positional tensor args; the original passed
    # the whole list as one argument, binding it to the first input only.
    return model(*ort_args)

# ===== onnxscript/values.py =====
import onnx

# ONNX opsets (correspond to python modules in reference-mode)
# Have a domain-name, version, and a list of ops


class Opset:
    """An (domain, version) pair with schema lookup for its ops."""

    def __init__(self, domain, version) -> None:
        self.domain = domain
        self.version = version

    def __getitem__(self, opname):
        return onnx.defs.get_schema(opname, self.version, self.domain)

    def __contains__(self, opname):
        try:
            onnx.defs.get_schema(opname, self.version, self.domain)
            return True
        # (fix) was a bare `except:`, which would also swallow SystemExit /
        # KeyboardInterrupt; get_schema raises an ordinary error for
        # unknown ops, so Exception is wide enough.
        except Exception:
            return False

    def __str__(self) -> str:
        return self.domain


opset15 = Opset("", 15)

msdomain1 = Opset("com.microsoft", 1)

# ONNX ops


class Op:
    """A named op within an Opset."""

    def __init__(self, opset, opname) -> None:
        self.opset = opset
        self.opname = opname

    def get_schema(self):
        return self.opset[self.opname]

    def has_schema(self):
        return self.opname in self.opset


# Values fall into the following categories:
#    ConstValue: values known at translation-time, mapped to ONNX attributes
#    AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes
#    Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs


class Value:
    """Base class for translation-time values."""

    def __init__(self, val) -> None:
        self.value = val


class ConstValue(Value):
    def __init__(self, val) -> None:
        super().__init__(val)


class AttrRef(Value):
    def __init__(self, name: str, type) -> None:
        # `value` holds the *name* of the referenced attribute parameter.
        super().__init__(name)
        self.type = type


class Dynamic(Value):
    def __init__(self, val) -> None:
        super().__init__(val)