##### Copyright 2024 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png" height="20px"> Hugging Face to <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch to <img src="https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/ghost.svg" height="20px"> IREE

This notebook uses [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)
  model is showcased here as it is small enough to fit comfortably into a Colab
  notebook. Other pretrained models can be found at
  https://huggingface.co/docs/transformers/index.

## Setup

In [2]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
#   NOTE(review): presumably these preinstalled packages were built against the
#   environment's existing torch (2.2.1+cu121 per the next cell's install log)
#   and would report version conflicts once torch is replaced.
#   %%capture hides pip's output for this cell.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision

In [3]:
!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0

Looking in indexes: https://download.pytorch.org/whl/test/cpu
Collecting torch==2.3.0
  Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.4/190.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.2.1+cu121
    Uninstalling torch-2.2.1+cu121:
      Successfully uninstalled torch-2.2.1+cu121
Successfully installed torch-2.3.0+cpu


In [4]:
!python -m pip install iree-turbine

Collecting iree-turbine
  Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/150.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m143.4/150.4 kB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.4/150.4 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Collecting iree-compiler>=20240410.859 (from iree-turbine)
  Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)
  Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m26.4

In [5]:
#@title Report version information
# `pip show` prints several metadata lines; grep keeps only the "Version: ..." one.
!echo "Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)"

# `iree-compile --version` reports the IREE compiler build and its LLVM version.
!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

# torch is imported here to report its version; later cells also use it.
import torch
print("\nInstalled PyTorch, version:", torch.__version__)

Installed iree-turbine, Version: 2.3.0rc20240410

Installed IREE, compiler version information:
IREE (https://iree.dev):
  IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1
  LLVM version 19.0.0git
  Optimized build

Installed PyTorch, version: 2.3.0+cpu


## Load and run whisper-small

Load the pretrained model from https://huggingface.co/openai/whisper-small.

See also:

* Model card: https://huggingface.co/docs/transformers/model_doc/whisper
* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# https://huggingface.co/docs/transformers/model_doc/auto
# AutoModelForCausalLM -> WhisperForCausalLM
# AutoTokenizer        -> WhisperTokenizerFast
# NOTE(review): WhisperForCausalLM appears to wrap only Whisper's text decoder,
# so this demo feeds token ids rather than audio. Confirm against the HF docs.

modelname = "openai/whisper-small"
tokenizer = AutoTokenizer.from_pretrained(modelname)

# Some of the options here affect how the model is exported. See the test cases
# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models
# for other options that may be useful to set.
# - attn_implementation="eager" selects the plain PyTorch attention code path
#   (presumably chosen because it exports more reliably; verify).
# - torchscript=True: later cells index the model outputs with [0], i.e. they
#   expect tuple-style outputs.
model = AutoModelForCausalLM.from_pretrained(
    modelname,
    output_attentions=False,
    output_hidden_states=False,
    attn_implementation="eager",
    torchscript=True,
)

# This is just a simple demo to get some data flowing through the model.
# Depending on the model and what input it expects (text, image, audio, etc.)
# this might instead use a specific Processor class. For Whisper,
# WhisperProcessor runs audio input pre-processing and output post-processing.
example_prompt = "Hello world!"
example_encoding = tokenizer(example_prompt, return_tensors="pt")
# .cpu() keeps the input on the CPU, matching the CPU-only torch build above.
example_input = example_encoding["input_ids"].cpu()
# Positional example arguments shared by torch.export, aot.export and the
# runtime/reference calls in later cells.
example_args = (example_input,)

As a quick sanity check, first try exporting with [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). Turbine's `aot.export()` builds on `torch.export`, so if this step succeeds, the Turbine export below should work as well.

In [None]:
import torch

# Sanity check: make sure plain torch.export can trace the model with the
# example arguments before handing it to Turbine. The resulting
# ExportedProgram is kept only so it can be inspected.
exported_program = torch.export.export(mod=model, args=example_args)

Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine.

In [8]:
# Newer iree-turbine releases expose the AOT API as `iree.turbine.aot`; older
# releases (such as 2.3.0rc20240410 from the install log above) only provide
# the `shark_turbine` name. Prefer the new name and fall back to the old one.
try:
    import iree.turbine.aot as aot
except ImportError:
    import shark_turbine.aot as aot

# Note: aot.export() wants the example args to be unpacked.
whisper_compiled_module = aot.export(model, *example_args)

Compile the exported program with Turbine/IREE, then run it with the IREE runtime.

In [9]:
import iree.runtime as ireert

# Compile the Turbine export. With save_to=None the compiled module is returned
# in memory (mapped below) instead of being written to a file.
binary = whisper_compiled_module.compile(save_to=None)

# Load the compiled module with the "local-task" driver (IREE's CPU driver)
# and invoke its exported `main` function on the same example input.
config = ireert.Config("local-task")
wrapped_module = ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory())
vm_module = ireert.load_vm_module(wrapped_module, config)

iree_outputs = vm_module.main(example_args[0])
print(iree_outputs[0].to_host())

[[[  5.8126216   3.9667568   4.5749426 ...   2.7658575   2.6436937
     1.5479789]
  [  7.563438    6.0299625   5.1000338 ...   6.4327035   6.101554
     6.434801 ]
  [  0.9380368  -4.4696164  -4.012759  ...  -6.24863    -7.791795
    -6.84537  ]
  [  0.7450911  -3.7631674  -7.4870267 ...  -6.7348223  -6.966235
   -10.022385 ]
  [ -0.9628638  -3.5101964  -6.0158615 ...  -7.116393   -6.7086525
   -10.225711 ]
  [  3.3470955   2.4927258  -3.3042645 ...  -1.5709444  -1.8455245
    -2.9991858]]]


Run the same input through native PyTorch and compare its outputs with IREE's.

In [10]:
import numpy as np

# Run the same input eagerly through PyTorch. torch.no_grad() skips recording
# an autograd graph, which inference does not need.
with torch.no_grad():
    torch_outputs = model(example_args[0])
# First model output (the logits, for a causal-LM head).
torch_logits = torch_outputs[0].detach().numpy()
print(torch_logits)

# Compare numerically against the IREE result from the previous cell instead
# of only eyeballing the printouts. The printed values differ by roughly
# 1e-5..1e-4 (presumably float32 rounding differences between backends), so a
# 1e-3 tolerance leaves headroom while still catching real mismatches.
iree_logits = iree_outputs[0].to_host()
print("Max absolute difference (IREE vs. PyTorch):",
      np.max(np.abs(iree_logits - torch_logits)))
np.testing.assert_allclose(iree_logits, torch_logits, rtol=1e-3, atol=1e-3)

[[[  5.8126183    3.9667587    4.5749483  ...   2.7658575    2.643694
     1.5479784 ]
  [  7.563436     6.029952     5.100036   ...   6.4327083    6.101557
     6.4348083 ]
  [  0.93802685  -4.469646    -4.012787   ...  -6.2486415   -7.7918167
    -6.8453975 ]
  [  0.74507916  -3.763197    -7.487034   ...  -6.734877    -6.966276
   -10.022424  ]
  [ -0.96288276  -3.510221    -6.0158725  ...  -7.1164136   -6.708687
   -10.225745  ]
  [  3.3470666    2.492654    -3.304323   ...  -1.5709934   -1.8455791
    -2.9992423 ]]]
