##### Copyright 2023 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch Ahead-of-time (AOT) export workflows using <img src="https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg" height="20px"> IREE

This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for export from a PyTorch session to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

iree-turbine contains both a "simple" AOT exporter and an underlying "advanced"
API that handles more complicated models and exposes the full feature set. This
notebook shows some of the features available in the "advanced" toolkit.

## Setup

In [2]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
#   The `%%capture` cell magic on the first line keeps this cell's output out
#   of the notebook.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision

In [3]:
#@title Install PyTorch 2.3.0 (for CPU)
# Pins an exact CPU build of torch so the rest of the notebook runs against a
# known PyTorch version.
# NOTE(review): the index URL points at the `whl/test/cpu` channel - confirm
# whether the stable `whl/cpu` index should be used instead.
!python -m pip install --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0

Looking in indexes: https://download.pytorch.org/whl/test/cpu
Collecting torch==2.3.0
  Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.4/190.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.5.0+cu121
    Uninstalling torch-2.5.0+cu121:
      Successfully uninstalled torch-2.5.0+cu121
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
timm 1.0.11 requires torchvision, which is not installed.[0m[31m
[0mSuccessfully installed torch-2.3.0+cpu


In [4]:
#@title Install iree-turbine
# NOTE(review): this install is unpinned, so it picks up whatever release is
# current. The outputs recorded in this notebook came from iree-turbine 2.5.0,
# which pulled in iree-compiler / iree-runtime 20241104.1068.
!python -m pip install iree-turbine

Collecting iree-turbine
  Downloading iree_turbine-2.5.0-py3-none-any.whl.metadata (5.7 kB)
Collecting iree-compiler (from iree-turbine)
  Downloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (615 bytes)
Collecting iree-runtime (from iree-turbine)
  Downloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (798 bytes)
Downloading iree_turbine-2.5.0-py3-none-any.whl (271 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m271.3/271.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (70.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.7/70.7 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl (8.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.0/8.0 MB[0m [31m59.7 MB/s

In [5]:
#@title Report version information
# Printing the installed versions makes the recorded outputs below attributable
# to a specific iree-turbine / IREE compiler / PyTorch combination.
!echo "Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)"

!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

# This is the notebook's only `import torch`; the cells below use `torch`
# without importing it again, so this cell must run first.
import torch
print("\nInstalled PyTorch, version:", torch.__version__)

Installed iree-turbine, Version: 2.5.0

Installed IREE, compiler version information:
IREE (https://iree.dev):
  IREE compiler version 20241104.1068 @ 9c85e30df30d6efcf68a7a1b594e89322bd6085d
  LLVM version 20.0.0git
  Optimized build

Installed PyTorch, version: 2.3.0+cpu


## Advanced AOT toolkit examples

1. Define a PyTorch program using `torch.nn.Module`
2. Define the API and properties of that program by using `aot.CompiledModule`
3. Export the program using `aot.export()`
4. Compile to a deployable artifact
    * a: By staying within a Python session
    * b: By outputting MLIR and continuing using native tools

Useful documentation:

* [IREE PyTorch guide](https://iree.dev/guides/ml-frameworks/pytorch/)
* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation
* IREE compiler and runtime [Python bindings](https://iree.dev/reference/bindings/python/)

In [6]:
#@title 1. Define a program using `torch.nn.Module`
# Seed the global RNG so the randomly initialized parameters are reproducible.
torch.manual_seed(0)

class LinearModule(torch.nn.Module):
  """Affine map `input @ weight + bias` with randomly initialized parameters."""

  def __init__(self, in_features, out_features):
    """Creates a weight of shape (in_features, out_features) and a bias of
    shape (out_features,), both drawn from `torch.randn`."""
    super().__init__()
    # Draw the weight before the bias so the values obtained for a given seed
    # stay the same.
    initial_weight = torch.randn(in_features, out_features)
    initial_bias = torch.randn(out_features)
    self.weight = torch.nn.Parameter(initial_weight)
    self.bias = torch.nn.Parameter(initial_bias)

  def forward(self, input):
    """Returns `input` multiplied by the weight matrix, plus the bias."""
    projected = torch.matmul(input, self.weight)
    return projected + self.bias

linear_module = LinearModule(in_features=4, out_features=3)

In [8]:
#@title 2. Define the API and properties of that program by using `aot.CompiledModule`

import iree.turbine.aot as aot

# Example tensors that are only passed to `aot.abstractify()` in the setter
# signatures below; they are not assigned into `linear_module`.
example_weight = torch.randn(4, 3)
example_bias = torch.randn(3)

# Declares what the exported program exposes: the parameters of
# `linear_module`, a `main` entry point that runs its forward pass, and
# getter/setter functions for the weight and bias.
class CompiledLinearModule(aot.CompiledModule):
  # Export the module's parameters; `mutable=True` is what allows the
  # `set_*` functions below to assign to them.
  params = aot.export_parameters(linear_module, mutable=True)
  # Wrap the PyTorch forward function so the exported functions can call it.
  compute = aot.jittable(linear_module.forward)

  def main(self, x=aot.AbstractTensor(4)):
    """Runs the forward pass on `x`, declared as a 4-element tensor."""
    return self.compute(x)

  def get_weight(self):
    """Returns the current value of the exported `weight` parameter."""
    return self.params["weight"]

  def set_weight(self, weight=aot.abstractify(example_weight)):
    """Replaces the exported `weight` parameter with `weight`."""
    self.params["weight"] = weight

  def get_bias(self):
    """Returns the current value of the exported `bias` parameter."""
    return self.params["bias"]

  def set_bias(self, bias=aot.abstractify(example_bias)):
    """Replaces the exported `bias` parameter with `bias`."""
    self.params["bias"] = bias

In [9]:
#@title 3. Export the program using `aot.export()`
# `export_output` is reused by both step 4a (compiling inside this Python
# session) and step 4b (printing and saving the MLIR).
export_output = aot.export(CompiledLinearModule)

In [10]:
#@title 4a. Compile fully to a deployable artifact, in our existing Python session

# Compiling without leaving Python gives the API a chance to reuse memory,
# which improves performance for large programs.
compiled_binary = export_output.compile(save_to=None)

# Test the compiled program through the IREE runtime API.
import numpy as np
import iree.runtime as ireert

# Load the compiled program into a runtime configured for "local-task".
config = ireert.Config("local-task")
vmfb_contents = compiled_binary.map_memory()
wrapped_module = ireert.VmModule.wrap_buffer(config.vm_instance, vmfb_contents)
vm_module = ireert.load_vm_module(wrapped_module, config)

# Call the exported `main` function and print the result.
input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
result = vm_module.main(input)
print(result.to_host())

[ 1.4178505 -1.2343317 -7.4767942]


In [None]:
#@title 4b. Output MLIR then continue from Python or native tools later

# Leaving Python allows for file system checkpointing and grants access to
# native development workflows.

# Output locations for the exported program and the compiled module.
# NOTE(review): the `.mlirbc` extension suggests MLIR bytecode - confirm what
# `save_mlir()` writes for this path.
mlir_file_path = "/tmp/linear_module_pytorch.mlirbc"
vmfb_file_path = "/tmp/linear_module_pytorch_llvmcpu.vmfb"

# Print the exported program as text, then save it to `mlir_file_path`.
export_output.print_readable()
export_output.save_mlir(mlir_file_path)

# The shell commands below use `{name}` to substitute the Python variables
# defined above. The first compiles the saved MLIR for the llvm-cpu backend,
# targeting the host CPU; the second runs the compiled module's `main`
# function with the same input values as step 4a.
!iree-compile --iree-input-type=torch --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-cpu=host {mlir_file_path} -o {vmfb_file_path}
!iree-run-module --module={vmfb_file_path} --device=local-task --function=main --input="4xf32=[1.0, 2.0, 3.0, 4.0]"

module @compiled_linear {
  util.global private mutable @_params.weight = dense_resource<_params.weight> : tensor<4x3xf32>
  util.global private mutable @_params.bias = dense_resource<_params.bias> : tensor<3xf32>
  func.func @main(%arg0: tensor<4xf32>) -> tensor<3xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<4xf32> -> !torch.vtensor<[4],f32>
    %1 = call @forward(%0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32>
    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
    return %2 : tensor