Skip to content

Commit

Permalink
Add scalar instantiation and casting (like int(val) or np.float16(val…
Browse files Browse the repository at this point in the history
…)) (#315)

* Support instantiation of abstract scalar types.
* Replace abstract types with int32 placeholder in relay backend.
* Move apply inference into a new macro `call_object`.
  • Loading branch information
notoraptor authored and breuleux committed Dec 9, 2019
1 parent 465ed3f commit a2535fc
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 25 deletions.
24 changes: 2 additions & 22 deletions myia/abstract/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@
from .amerge import amerge, bind
from .data import (
ANYTHING,
DATA,
TYPE,
VALUE,
AbstractClassBase,
AbstractError,
AbstractFunction,
AbstractJTagged,
AbstractKeywordArgument,
AbstractScalar,
AbstractTuple,
AbstractType,
AbstractValue,
DummyFunction,
Function,
Expand Down Expand Up @@ -295,27 +291,11 @@ async def infer_apply(self, ref):
fn = await fn_ref.get()
argrefs = [self.ref(node, ctx) for node in n_args]

if isinstance(fn, AbstractType):
if not isinstance(fn, AbstractFunction):
g = ref.node.graph
newfn = g.apply(P.partial, P.make_record, fn.xvalue())
newcall = g.apply(newfn, *n_args)
newcall = g.apply(operations.call_object, n_fn, *n_args)
return await self.reroute(ref, self.ref(newcall, ctx))

elif isinstance(fn, AbstractError):
raise MyiaTypeError(
f'Trying to call a function with type '
f'{fn.xvalue()} {fn.values[DATA] or ""}.'
)

elif isinstance(fn, AbstractClassBase):
g = ref.node.graph
newfn = g.apply(operations.getattr, fn_ref.node, '__call__')
newcall = g.apply(newfn, *n_args)
return await self.reroute(ref, self.ref(newcall, ctx))

elif not isinstance(fn, AbstractFunction):
raise MyiaTypeError(f'Myia does not know how to call {fn}')

infs = [self.get_inferrer_for(poss)
for poss in await fn.get()]

Expand Down
9 changes: 7 additions & 2 deletions myia/compile/backends/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,15 @@ def from_backend_value(self, v, t):

def to_backend_value(self, v, t):
"""Convert an intermediate value to a backend value."""
if (isinstance(t, (abstract.AbstractError, abstract.AbstractType))
if (isinstance(t, abstract.AbstractError)
or v is abstract.DEAD):
return None
if isinstance(t, abstract.AbstractArray):
elif isinstance(t, abstract.AbstractType):
# Handle abstract types.
# Return None if type does not match any torch type.
myia_type = t.xvalue().xtype()
return _type_map.get(myia_type, None)
elif isinstance(t, abstract.AbstractArray):
return self.from_numpy(v)
elif isinstance(t, abstract.AbstractScalar):
if issubclass(t.values[abstract.TYPE],
Expand Down
7 changes: 7 additions & 0 deletions myia/compile/backends/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from itertools import accumulate

import numpy as np
import tvm
from tvm import relay
from tvm.relay import adt
Expand Down Expand Up @@ -707,6 +708,12 @@ def convert_tagged(self, v, t):
cst = self.cst_conv.convert_tagged(v, t)
return self.intrp.evaluate(cst)

def convert_type(self, v, t):
# abstract type will be replaced with an integer type as placeholder
# (see to_relay_type(AbstractType), so we must return an integer
# of same type here.
return np.int32(0)


class RelayOutputConverter(Converter):
"""Convert values from Relay."""
Expand Down
10 changes: 10 additions & 0 deletions myia/compile/backends/relay_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AbstractScalar,
AbstractTaggedUnion,
AbstractTuple,
AbstractType,
TypedPrimitive,
VirtualFunction,
broaden,
Expand Down Expand Up @@ -184,6 +185,15 @@ def to_relay_type(self, a: AbstractScalar):
return relay.ty.scalar_type(type_to_np_dtype(tp))


@overload # noqa: F811
def to_relay_type(self, a: AbstractType):
# Abstract types are not currently used in the graph,
# they are replaced with other calls,
# and appear here just as unused graph parameters.
# So, let's just replace them with an integer type as placeholder.
return relay.ty.scalar_type('int32')


@overload # noqa: F811
def to_relay_type(self, a: AbstractTuple):
return relay.ty.TupleType([self(e) for e in a.elements])
Expand Down
5 changes: 5 additions & 0 deletions myia/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@
defaults='myia.operations.prim_broadcast_shape'
)

call_object = Operation(
name='call_object',
defaults='myia.operations.macro_call_object'
)

casttag = Operation(
name='casttag',
defaults='myia.operations.prim_casttag'
Expand Down
69 changes: 69 additions & 0 deletions myia/operations/macro_call_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Implementation of macro `call_object(obj, *args)`.
Receive and call syntax `obj(*args)` and replace it with
relevant myia operations.
"""
from myia.utils.errors import MyiaTypeError

from .. import operations, xtype
from ..lib import (
DATA,
AbstractClassBase,
AbstractError,
AbstractScalar,
AbstractType,
macro,
type_to_abstract,
)
from ..operations import primitives as P


@macro
async def call_object(info, fn, *n_args):
"""Replace call syntax `fn(*n_args)` with relevant myia operations."""
fn_node = info.argrefs[0].node
arg_nodes = (ref.node for ref in info.argrefs[1:])
g = info.graph
fn = await fn.get()

if isinstance(fn, AbstractType):
# Replace abstract type instantiation with
# either a cast for abstract scalars, or
# a make_record for all other cases.
val = fn.xvalue()
cls = type_to_abstract(val)
if isinstance(cls, AbstractScalar):
typ = cls.xtype()
if issubclass(typ, xtype.Number):
newcall = g.apply(P.scalar_cast, *arg_nodes, cls)
elif typ is xtype.Bool:
newcall = g.apply(operations.bool, *arg_nodes)
else:
raise MyiaTypeError(f'Cannot compile typecast to {typ}')
else:
newfn = g.apply(P.partial, P.make_record, val)
newcall = g.apply(newfn, *arg_nodes)
return newcall

elif isinstance(fn, AbstractError):
raise MyiaTypeError(
f'Trying to call a function with type '
f'{fn.xvalue()} {fn.values[DATA] or ""}.'
)

elif isinstance(fn, AbstractClassBase):
newfn = g.apply(operations.getattr, fn_node, '__call__')
newcall = g.apply(newfn, *arg_nodes)
return newcall

else:
raise MyiaTypeError(f'Myia does not know how to call {fn}')


__operation_defaults__ = {
'name': 'call_object',
'registered_name': 'call_object',
'mapping': call_object,
'python_implementation': None,
}
14 changes: 14 additions & 0 deletions myia/xtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ class UniverseType(Object):
register_serialize(f64, 'f64')
register_serialize(UniverseType, 'universe_type')

register_serialize(numpy.bool, 'numpy_bool')
register_serialize(numpy.int, 'numpy_int')
register_serialize(numpy.float, 'numpy_float')
register_serialize(numpy.int8, 'numpy_int8')
register_serialize(numpy.int16, 'numpy_int16')
register_serialize(numpy.int32, 'numpy_int32')
register_serialize(numpy.int64, 'numpy_int64')
register_serialize(numpy.uint8, 'numpy_uint8')
register_serialize(numpy.uint16, 'numpy_uint16')
register_serialize(numpy.uint32, 'numpy_uint32')
register_serialize(numpy.uint64, 'numpy_uint64')
register_serialize(numpy.float16, 'numpy_float16')
register_serialize(numpy.float32, 'numpy_float32')
register_serialize(numpy.float64, 'numpy_float64')

DTYPE_TO_MTYPE = dict(
int8=Int[8],
Expand Down
75 changes: 74 additions & 1 deletion tests/operations/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from myia.utils.errors import MyiaTypeError
from myia.xtype import f16, f32, f64, i16, i32, i64, u32
from myia.xtype import Bool, Nil, f16, f32, f64, i8, i16, i32, i64, u32, u64

from ..common import (
Ty,
Expand Down Expand Up @@ -117,3 +117,76 @@ def test_full(shape, value, dtype):
)
def test_infer_full(shape, value, dtype):
return np.full(shape, value, dtype)


@mt(
# we could not cast to a Nil,
infer(Ty(Nil), i64, result=MyiaTypeError),
# We could cast to a Bool,
infer(Ty(Bool), i64, result=Bool),
# we could create an int8 from any floating.
infer(Ty(np.int8), f16, result=i8),
infer(Ty(np.int8), f32, result=i8),
infer(Ty(np.int8), f64, result=i8),
# we could cast an int64 to any lower precision integer.
infer(Ty(np.int8), i64, result=i8),
infer(Ty(np.int16), i64, result=i16),
infer(Ty(np.int32), i64, result=i32),
infer(Ty(np.int64), i64, result=i64),
# we could instantiate an uint from an int, and vice versa
infer(Ty(np.int64), u64, result=i64),
infer(Ty(np.uint64), i64, result=u64),
)
def test_infer_scalar_cast(dtype, value):
return dtype(value)


@mt(
# test each scalar type
run(np.uint8, 0, result=0),
run(np.uint16, 0, result=0),
run(np.uint32, 0, result=0),
run(np.uint64, 0, result=0),
run(np.int8, 0, result=0),
run(np.int16, 0, result=0),
run(np.int32, 0, result=0),
run(np.int64, 0, result=0),
run(np.float16, 0, result=0),
run(np.float32, 0, result=0),
run(np.float64, 0, result=0),
run(np.bool, 0, result=0),
run(np.int, 0, result=0),
run(np.float, 0, result=0),
run(np.double, 0, result=0),
run(bool, 0, result=0),
run(int, 0, result=0),
run(float, 0, result=0),
# test bool
run(np.bool, 0.0, result=False),
run(np.bool, 1, result=True),
run(np.bool, 1, result=1),
run(np.bool, -1, result=1),
run(np.bool, -1.23456, result=1),
run(np.bool, -1.23456, result=1),
# test uint8
run(np.uint8, 0, result=0),
run(np.uint8, 255, result=255),
run(np.uint8, 256, result=0),
run(np.uint8, 257, result=1),
run(np.uint8, -1, result=255),
run(np.uint8, -1.5, result=255), # -1.5 => -1 => forced to 255
run(np.uint8, 255.123456789, result=255),
# test int8
run(np.int8, -128, result=-128),
run(np.int8, 127, result=127),
run(np.int8, 128, result=-128),
run(np.int8, 129, result=-127),
run(np.int8, -129, result=127),
# test float16
run(np.float16, 1, result=1),
run(np.float16, -1.1, result=np.float16(-1.1)),
run(np.float16, -1.23456789, result=-1.234375),
broad_specs=(False, False)
)
def test_scalar_cast(dtype, value):
return dtype(value)

0 comments on commit a2535fc

Please sign in to comment.