In [1]:
# standard imports
import torch
from torch_mlir.eager_mode import torch_mlir_tensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# eager mode imports
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend

The simplest way of using Eager Mode (through IREE) requires setting a "backend":

In [3]:
# Install the IREE eager-mode backend, targeting the CPU, as the module-level
# backend that TorchMLIRTensor operations in the following cells run on.
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("cpu")

and wrapping all your `torch.Tensor`s:

In [4]:
# Number of times the loop cells further down repeat each eager op.
NUM_ITERS = 10

# Two plain torch tensors of shape (10, 10): all ones, and all twos.
t = torch.ones((10, 10))
u = torch.full((10, 10), 2.0)

# Wrap both tensors; the repr shows the IREE DeviceArray and the active backend.
tt, uu = TorchMLIRTensor(t), TorchMLIRTensor(u)
print(tt, uu, sep="\n")

TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)


`TorchMLIRTensor` is a "tensor wrapper subclass" (more info [here](https://github.com/albanD/subclass_zoo)) that keeps the IREE `DeviceArray` in its `elem` field:

In [5]:
# Repeat the same two element-wise ops NUM_ITERS times. Each result is again a
# TorchMLIRTensor; `.elem.to_host()` is used to print its underlying values.
for i in range(NUM_ITERS):
    for yy in (tt + uu, tt * uu):  # sum first, then product
        print(type(yy))
        print(yy.elem.to_host())

<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 

If you have a GPU (with CUDA installed), that works too. You can verify that the GPU is actually being used by keeping `watch -n1 nvidia-smi` open in a terminal while running the next cell:

In [6]:
# Same demo as on the CPU, but with the IREE eager backend targeting the GPU.
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("gpu")

# Fresh (10, 10) tensors of ones and twos, wrapped for the GPU backend.
t = torch.ones((10, 10))
u = torch.full((10, 10), 2.0)

tt, uu = TorchMLIRTensor(t), TorchMLIRTensor(u)
print(tt, uu, sep="\n")

# Print the element-wise sum, then the element-wise product.
for yy in (tt + uu, tt * uu):
    print(yy.elem.to_host())

TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]


There is a convenience class `SharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:

In [7]:
# SharkEagerMode both installs the backend and wraps tensors for us.
from shark.shark_runner import SharkEagerMode

# While this instance is alive, freshly created torch tensors come out as
# TorchMLIRTensors, so no explicit TorchMLIRTensor(...) wrapping is needed.
shark_eager_mode = SharkEagerMode("cpu")

t, u = torch.ones((10, 10)), torch.ones((10, 10))
print(t, u, sep="\n")

for i in range(NUM_ITERS):
    for yy in (t + u, t * u):  # sum first, then product
        print(type(yy))
        print(yy.elem.to_host())

TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.

The `SharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization): it defines a "deleter" that runs when an instance of `SharkEagerMode` is garbage collected. The takeaway is that if you want to turn off `SharkEagerMode`, or switch backends, you need to `del` the instance first:

In [8]:
# Tear down the CPU instance *before* building the CUDA one. In a plain
# reassignment the right-hand side is evaluated first, so both instances would
# briefly be alive. Immediate teardown on `del` relies on CPython refcounting
# and on no other references to the old instance existing.
del shark_eager_mode
shark_eager_mode = SharkEagerMode("cuda")

t, u = torch.ones((10, 10)), torch.ones((10, 10))
print(t, u, sep="\n")

for yy in (t + u, t * u):  # sum first, then product
    print(type(yy))
    print(yy.elem.to_host())

TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.