##### Copyright 2023 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch Just-in-time (JIT) workflows using <img src="https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg" height="20px"> IREE

This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) to just-in-time (JIT) compile code from an eager PyTorch session, using [IREE](https://github.com/openxla/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

## Setup

In [2]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
# `%%capture` (kept as the first line of the cell) hides this cell's pip output.
# NOTE(review): these packages are presumably preinstalled in the runtime and
# tied to a different torch build than the one SHARK-Turbine's requirements
# install in the next cell (torch==2.1.0 in the recorded run) — confirm the
# list is still needed when bumping versions.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext

In [3]:
#@title Install SHARK-Turbine following [the instructions](https://github.com/nod-ai/SHARK-Turbine)

# Limit cell height.
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Note: the install steps are tweaked here to work in Colab.
# The repository is cloned once (re-runs reuse the existing checkout), and the
# package is installed from that same checkout rather than from a second git
# fetch. That way the commit reported by the next cell (`git rev-parse HEAD`
# inside SHARK-Turbine/) is the code that was actually installed.
# `%pip` installs into the environment of the running kernel.
![[ -d SHARK-Turbine ]] || git clone https://github.com/nod-ai/SHARK-Turbine.git
%pip install --upgrade -r SHARK-Turbine/requirements.txt
%pip install --upgrade ./SHARK-Turbine

<IPython.core.display.Javascript object>

Cloning into 'SHARK-Turbine'...
remote: Enumerating objects: 1093, done.[K
remote: Counting objects: 100% (517/517), done.[K
remote: Compressing objects: 100% (248/248), done.[K
remote: Total 1093 (delta 294), reused 387 (delta 226), pack-reused 576[K
Receiving objects: 100% (1093/1093), 353.87 KiB | 3.69 MiB/s, done.
Resolving deltas: 100% (497/497), done.
Looking in links: https://openxla.github.io/iree/pip-release-links.html
Collecting torch==2.1.0 (from -r SHARK-Turbine/pytorch-cpu-requirements.txt (line 2))
  Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting torchvision (from -r SHARK-Turbine/torchvision-requirements.txt (line 2))
  Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl (6.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m84.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
#@title Test installs and report version information
# Importing from shark_turbine raises ImportError if the package did not
# install; FxImporter is not otherwise used in this cell.
from shark_turbine.dynamo.importer import FxImporter  # Test for install success
# Reports the commit checked out in the local SHARK-Turbine/ clone. This matches
# the installed package only if pip installed from that clone.
!cd SHARK-Turbine && echo "Installed SHARK-Turbine from git at https://github.com/nod-ai/SHARK-Turbine/commit/$(git rev-parse HEAD)"

# NOTE(review): `iree-compile` is presumably provided by the IREE compiler
# package pulled in through SHARK-Turbine's requirements — confirm.
!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

import torch
print("\nInstalled PyTorch, version:", torch.__version__)

Installed SHARK-Turbine from git at https://github.com/nod-ai/SHARK-Turbine/commit/d53f14ba3677c1cb6c25033ebf5629bb0dcdc9de

Installed IREE, compiler version information:
IREE (https://openxla.github.io/iree):
  IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447
  LLVM version 18.0.0git
  Optimized build

Installed PyTorch, version: 2.1.0+cu121


## Sample JIT workflow

1. Define a program using `torch.nn.Module`
2. Run `torch.compile(module, backend="turbine_cpu")`
3. Use the resulting `OptimizedModule` as you would a regular `nn.Module`; compilation is deferred until the module is first called

Useful documentation:

* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation
* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo, which optimizes code using backend compilers like Turbine

In [5]:
# Seed PyTorch's global RNG so the randomly initialized parameters below (and
# therefore every printed value later in the notebook) are reproducible.
torch.manual_seed(0)

class LinearModule(torch.nn.Module):
  """Minimal affine layer computing ``input @ weight + bias``.

  Args:
    in_features: size of the last dimension of ``input``.
    out_features: size of the last dimension of the result.

  ``weight`` has shape ``(in_features, out_features)`` and ``bias`` has shape
  ``(out_features,)``; both are initialized with ``torch.randn``.
  """

  def __init__(self, in_features, out_features):
    super().__init__()
    # Draw weight before bias: the order determines which seeded values each gets.
    weight_init = torch.randn(in_features, out_features)
    bias_init = torch.randn(out_features)
    self.weight = torch.nn.Parameter(weight_init)
    self.bias = torch.nn.Parameter(bias_init)

  def forward(self, input):
    """Returns ``input @ weight + bias``.

    A 1-D ``input`` of length ``in_features`` yields a result of shape
    ``(out_features,)``.
    """
    projected = torch.matmul(input, self.weight)
    return projected + self.bias

linear_module = LinearModule(4, 3)

In [6]:
# torch.compile returns an OptimizedModule wrapper immediately. TorchDynamo
# tracing and compilation by the "turbine_cpu" backend are deferred until the
# wrapped module is first called, so no compilation happens in this cell.
opt_linear_module = torch.compile(linear_module, backend="turbine_cpu")
print("Wrapped module for JIT compilation with Turbine (compiled on first call). New module type is", type(opt_linear_module))

Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>


In [7]:
# Call the compiled module like a regular nn.Module. In this notebook this is
# its first call, which is what triggers Turbine compilation.
args = torch.randn(4)
turbine_output = opt_linear_module(args)

print("Weight:", linear_module.weight)
print("Bias:", linear_module.bias)
print("Args:", args)
print("Output:", turbine_output)

# Compare against eager PyTorch on the same input so a numerical mismatch in
# the compiled path fails loudly instead of going unnoticed.
eager_output = linear_module(args)
torch.testing.assert_close(turbine_output, eager_output)

module {
  func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[4],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {
    %int0 = torch.constant.int 0
    %0 = torch.aten.unsqueeze %arg2, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
    %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>
    %int0_0 = torch.constant.int 0
    %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    %int1 = torch.constant.int 1
    %3 = torch.aten.add.Tensor %2, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>
  }
}

#map = affine_map<(d0) -> (d0)>
module {
  func.func @main(%arg0: tensor<4x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<4xf32>) -> (tensor<3xf32>, tensor<1x4xf32>)

Weight: Parameter containing:
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]], requires_grad=True)
Bias: Parameter containing:
tensor([-0.8567,  1.1006, -1.0712], requires_grad=True)
Args: tensor([ 0.1227, -0.5663,  0.3731, -0.8920])
Output: tensor([-0.4792,  2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)


