-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add scalar instantiation and casting (like int(val) or np.float16(val…
…)) (#315) * Support instantiation of abstract scalar types. * Replace abstract types with int32 placeholder in relay backend. * Move apply inference into a new macro `call_object`.
- Loading branch information
1 parent
465ed3f
commit a2535fc
Showing
8 changed files
with
188 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
Implementation of macro `call_object(obj, *args)`. | ||
Receive and call syntax `obj(*args)` and replace it with | ||
relevant myia operations. | ||
""" | ||
from myia.utils.errors import MyiaTypeError | ||
|
||
from .. import operations, xtype | ||
from ..lib import ( | ||
DATA, | ||
AbstractClassBase, | ||
AbstractError, | ||
AbstractScalar, | ||
AbstractType, | ||
macro, | ||
type_to_abstract, | ||
) | ||
from ..operations import primitives as P | ||
|
||
|
||
@macro | ||
async def call_object(info, fn, *n_args): | ||
"""Replace call syntax `fn(*n_args)` with relevant myia operations.""" | ||
fn_node = info.argrefs[0].node | ||
arg_nodes = (ref.node for ref in info.argrefs[1:]) | ||
g = info.graph | ||
fn = await fn.get() | ||
|
||
if isinstance(fn, AbstractType): | ||
# Replace abstract type instantiation with | ||
# either a cast for abstract scalars, or | ||
# a make_record for all other cases. | ||
val = fn.xvalue() | ||
cls = type_to_abstract(val) | ||
if isinstance(cls, AbstractScalar): | ||
typ = cls.xtype() | ||
if issubclass(typ, xtype.Number): | ||
newcall = g.apply(P.scalar_cast, *arg_nodes, cls) | ||
elif typ is xtype.Bool: | ||
newcall = g.apply(operations.bool, *arg_nodes) | ||
else: | ||
raise MyiaTypeError(f'Cannot compile typecast to {typ}') | ||
else: | ||
newfn = g.apply(P.partial, P.make_record, val) | ||
newcall = g.apply(newfn, *arg_nodes) | ||
return newcall | ||
|
||
elif isinstance(fn, AbstractError): | ||
raise MyiaTypeError( | ||
f'Trying to call a function with type ' | ||
f'{fn.xvalue()} {fn.values[DATA] or ""}.' | ||
) | ||
|
||
elif isinstance(fn, AbstractClassBase): | ||
newfn = g.apply(operations.getattr, fn_node, '__call__') | ||
newcall = g.apply(newfn, *arg_nodes) | ||
return newcall | ||
|
||
else: | ||
raise MyiaTypeError(f'Myia does not know how to call {fn}') | ||
|
||
|
||
__operation_defaults__ = { | ||
'name': 'call_object', | ||
'registered_name': 'call_object', | ||
'mapping': call_object, | ||
'python_implementation': None, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters