# `metadsl` + `uarray`

First, let's install the dependencies, including the latest `uarray`:

In [1]:
# Install the notebook's dependencies.
# NOTE: the torch wheel URL below targets CPython 3.7 on Linux x86_64 (per its
# filename); on any other platform, install torch 1.1.0 another way.
!pip install numpy
!pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
!pip install torchvision
# Latest uarray, installed straight from its GitHub repository.
!pip install -U git+https://github.com/Quansight-Labs/uarray.git

Collecting git+https://github.com/Quansight-Labs/uarray.git
  Cloning https://github.com/Quansight-Labs/uarray.git to /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-req-build-vv251_40
Building wheels for collected packages: uarray
  Building wheel for uarray (setup.py) ... done
  Stored in directory: /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-ephem-wheel-cache-6m7r4pvj/wheels/3d/1a/bf/60f787ed8f0ac071de28d869eb644793856923ac4688c14544
Successfully built uarray
Installing collected packages: uarray
  Found existing installation: uarray 0.4+168.g345664c
    Uninstalling uarray-0.4+168.g345664c:
      Successfully uninstalled uarray-0.4+168.g345664c
Successfully installed uarray-0.4+228.g09d6ad2


`uarray` and `unumpy` provide ways to execute on different backends, such as Torch, NumPy, and XND.

`metadsl` provides a way to optimize your computation before executing it.

Here, we will show how we can integrate them to allow users to build up an expression with `unumpy`, optimize it with `metadsl`, and then execute it with `unumpy`.

In [1]:
import uarray
import unumpy
import typing
import unumpy.multimethods

In [2]:
import metadsl

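Let's define `metadsl` functions that mirror some `unumpy` functions, and register them as the implementations of those functions on a new `uarray` backend: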

In [3]:
# A fresh uarray backend. The metadsl functions defined in this cell are
# registered as its implementations of the matching unumpy functions, so
# unumpy calls made under this backend build metadsl expressions.
metadsl_backend = uarray.Backend()
uarray.register_backend(metadsl_backend)

class Array(metadsl.Instance):
    """Result type of the array expressions built by the metadsl functions.

    Every metadsl function in this notebook declares ``Array`` as its result
    type, and ``execute`` relies on ``isinstance(..., Array)`` to tell
    expression nodes apart from plain argument values such as ints.
    """


@metadsl.call(lambda start, stop, stride: Array)
def arange(start: int, stop: int, stride: int) -> Array:
    """Symbolic counterpart of ``unumpy.arange(start, stop, stride)``.

    The body is intentionally empty: ``metadsl.call`` handles the call, and
    its lambda gives the result type (``Array``). Under the metadsl backend,
    ``unumpy.arange(0, 10, 2)`` prints as ``'arange(0, 10, 2)'``, i.e. an
    expression node rather than computed data.
    """

@metadsl.call(lambda shape: Array)
def zeros(shape: typing.Tuple[int]) -> Array:
    """Symbolic counterpart of ``unumpy.zeros(shape)``; yields an ``Array`` node.

    NOTE(review): annotated as ``Tuple[int]`` (a 1-tuple), but this notebook
    calls it with a bare int (``zeros(10)``). Confirm the intended type.
    """

@metadsl.call(lambda left, right: Array)
def add(left: Array, right: Array) -> Array:
    """Symbolic counterpart of ``unumpy.add(left, right)``; yields an ``Array`` node."""

@metadsl.call(lambda a: Array)
def sum(a: Array) -> Array:
    """Symbolic counterpart of ``unumpy.sum(a)``; yields an ``Array`` node.

    NOTE: this name shadows the builtin ``sum`` for the rest of the notebook.
    It is kept because the simplification rule pattern ``sum(zeros(shape))``
    refers to it.
    """

@metadsl.call(lambda: Array)
def zero() -> Array:
    """Symbolic zero, the target of the ``sum(zeros(...))`` simplification.

    Unlike the other metadsl functions, it is not listed in
    ``METADSL_TO_UNUMPY``, so it has no default unumpy implementation.
    """
    
# Each metadsl function paired with the unumpy function it stands in for.
# ``zero`` has no entry here.
METADSL_TO_UNUMPY = {
    arange: unumpy.arange,
    zeros: unumpy.zeros,
    add: unumpy.add,
    sum: unumpy.sum,
}

# Make each metadsl function the metadsl backend's implementation of its
# unumpy counterpart: unumpy calls under this backend then build expressions.
for metadsl_fn, unumpy_fn in METADSL_TO_UNUMPY.items():
    register = uarray.register_implementation(unumpy_fn, metadsl_backend)
    register(metadsl_fn)

# Identity convertor: values are handed to the metadsl backend unchanged.
unumpy.ndarray.register_convertor(metadsl_backend, lambda i: i)



Now we can use them to build up some expressions in `metadsl`:

In [4]:
# Under the metadsl backend these unumpy calls compute nothing; they return
# metadsl expressions (their string forms are shown in the next cells).
with uarray.set_backend(metadsl_backend):
    left = unumpy.arange(0, 10, 2)
    right = unumpy.sum(unumpy.zeros(10))


In [5]:
# String form of the expression built for ``left``.
str(left)

'arange(0, 10, 2)'

In [6]:
# String form of the expression built for ``right``.
str(right)

'sum(zeros(10))'

Now let's add a rewrite rule that replaces a sum of zeros with a single zero:

In [7]:
# Collection of rewrite rules; ``simplify`` applies them to an expression.
# NOTE(review): the class name suggests rules are re-applied until none
# match — confirm against metadsl's ``RulesRepeatFold``.
simplifications = metadsl.RulesRepeatFold()
simplify = metadsl.RuleApplier(simplifications)

# Rewrite ``sum(zeros(shape))`` into ``zero()`` for any ``shape``.
# The function appears to return a (pattern, replacement) pair, with ``shape``
# acting as the placeholder matched inside the pattern (presumably bound by
# ``pure_rule``). ``sum`` here is the metadsl function, not the builtin.
# NOTE: ``zero`` has no entry in ``METADSL_TO_UNUMPY``, so by default the
# simplified result cannot be passed to ``execute``.
@simplifications.append
@metadsl.pure_rule(None)
def _sum_zeros_zero(shape: typing.Tuple[int]):
    return (
        sum(zeros(shape)),
        zero(),
    )


In [31]:
# Show the expression before and after applying the rewrite rules;
# ``print`` already uses each object's ``str`` form.
print(right)
print(simplify(right))

sum(zeros(10))
zero()


In [27]:
def execute(a, mapping=None):
    """Evaluate a metadsl expression with whichever uarray backend is active.

    Walks ``a`` recursively. Each ``Array`` node is replaced by a call to the
    callable its metadsl function maps to, applied to its already-executed
    arguments. Any other value (e.g. an ``int`` argument) is returned as-is.

    Parameters
    ----------
    a:
        A metadsl ``Array`` expression, or a plain value to pass through.
    mapping:
        Optional dict from metadsl function to the callable implementing it.
        Defaults to ``METADSL_TO_UNUMPY``. Pass an extended mapping to run
        nodes with no unumpy counterpart, such as the ``zero()`` produced
        by ``simplify``.

    Raises
    ------
    KeyError
        If a node's metadsl function has no entry in ``mapping``.
    """
    # Resolved at call time so later changes to the global mapping are seen.
    if mapping is None:
        mapping = METADSL_TO_UNUMPY
    if not isinstance(a, Array):
        return a
    # NOTE(review): relies on metadsl's private ``_call`` attribute.
    call = a._call
    try:
        implementation = mapping[call.function]
    except KeyError:
        # Same exception type as before, but say what is missing.
        raise KeyError(
            f"no implementation registered for metadsl function {call.function!r}"
        ) from None
    return implementation(*(execute(arg, mapping) for arg in call.args))

In [29]:
from unumpy.numpy_backend import NumpyBackend

# Evaluate the (unsimplified) expression ``right`` with NumPy as the active
# backend: each metadsl node is dispatched to the matching unumpy function.
with uarray.set_backend(NumpyBackend):
    print(execute(right))

0.0


In [30]:
from unumpy.torch_backend import TorchBackend

# The same expression evaluated with PyTorch as the active backend; the
# printed result is a torch tensor rather than a NumPy value.
with uarray.set_backend(TorchBackend):
    print(execute(right))

tensor(0.)
