# Tutorial: Exporting StableHLO from PyTorch

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][pytorch-tutorial-colab]
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][pytorch-tutorial-kaggle]

_Intro to the [`torch_xla.stablehlo`](https://github.com/pytorch/xla/blob/main/docs/stablehlo.md) module._

## Tutorial Setup

### Install required dependencies

We'll be using `torch` and `torchvision` to get a `resnet18` model, and `torch_xla` to export to StableHLO.

[pytorch-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb
[pytorch-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb

In [None]:
!pip install torch_xla torch torchvision

## Export PyTorch model to StableHLO

The general set of steps for exporting a PyTorch model to StableHLO is:
1. Export using PyTorch's `torch.export` API.
2. Convert the exported FX graph to StableHLO using the `torch_xla.stablehlo` APIs.

### Export model to FX graph using `torch.export`

This step uses only standard PyTorch APIs to export a `resnet18` model from `torchvision`. Graph tracing requires sample inputs; in this case we use a `tensor<4x3x224x224xf32>`.

In [7]:
import torch
import torchvision
from torch.export import export

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
sample_input = (torch.randn(4, 3, 224, 224), )
exported = export(resnet18, sample_input)

### Export FX Graph to StableHLO using TorchXLA

Once we have an exported FX graph, we can convert to StableHLO using the `torch_xla.stablehlo` module. In this case we'll use `exported_program_to_stablehlo`.

In [8]:
from torch_xla.stablehlo import exported_program_to_stablehlo

stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward')[0:4000],"\n...")

module @IrToHlo.508 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64
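
The printed module is long, so a quick operator histogram can be easier to read than the raw text. The sketch below is an illustration only: `count_stablehlo_ops` is a hypothetical helper, and `sample_text` is a hand-written stand-in, not real exporter output.

```python
import re
from collections import Counter


def count_stablehlo_ops(mlir_text):
    """Count occurrences of each `stablehlo.<op>` name in a block of MLIR text.

    The negative lookbehind skips attributes such as `#stablehlo.conv<...>`,
    which would otherwise be miscounted as operations.
    """
    return Counter(re.findall(r"(?<!#)\bstablehlo\.([a-z0-9_]+)", mlir_text))


# Hand-written stand-in for real exporter output, kept tiny for illustration.
sample_text = """
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.add %arg0, %arg0 : tensor<4xf32>
  %1 = stablehlo.multiply %0, %arg0 : tensor<4xf32>
  %2 = stablehlo.add %1, %arg0 : tensor<4xf32>
  return %2 : tensor<4xf32>
}
"""

for op, count in count_stablehlo_ops(sample_text).most_common():
    print(f"{op}: {count}")
```

With the objects from this tutorial, the equivalent call would be `count_stablehlo_ops(stablehlo_program.get_stablehlo_text('forward'))`.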

### Export with dynamic batch dimension

_This is a new feature. It requires PyTorch/XLA 2.3 or a `torch_xla` nightly build. Until PyTorch/XLA 2.3 is released, the code below is shown as a static example rather than a running cell, because using the nightly `torch` and `torch_xla` builds will likely lead to notebook failures._

Dynamic batch dimensions can be specified as part of the initial `torch.export` step. The FX graph's symbolic integer (symint) information is then used to export dynamic StableHLO.

In this example, we mark dim 0 of the sample input as dynamic, so the exported StableHLO carries the input shape as `tensor<?x3x224x224xf32>`.

```python
from torch.export import Dim

# Create a dynamic batch size, for the first dimension of the input
batch = Dim("batch", max=15)
dynamic_shapes = ({0: batch},)
dynamic_export = export(resnet18, sample_input, dynamic_shapes=dynamic_shapes)
dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)
print(dynamic_stablehlo.get_stablehlo_text('forward')[0:5000],"\n...")
```

### Saving and reloading StableHLO

`StableHLOGraphModule` provides `save` and `load` methods for StableHLO artifacts. The artifacts are stored as StableHLO portable bytecode, which has full backward compatibility guarantees.

In [9]:
from torch_xla.stablehlo import StableHLOGraphModule

# Save to tmp
stablehlo_program.save('/tmp/stablehlo_dir')
!ls /tmp/stablehlo_dir
!ls /tmp/stablehlo_dir/functions

# Reload and execute - Stable serialization, forward / backward compatible.
reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')
print(reloaded(sample_input[0]))

constants  data  functions
forward.bytecode  forward.meta	forward.mlir
tensor([[-0.7431, -2.5955, -0.0718,  ..., -1.4230,  1.5928, -0.7693],
        [ 0.6199, -1.5941, -0.9018,  ...,  0.2452,  0.6159,  2.4765],
        [-3.0291,  2.2174, -2.2809,  ..., -0.9081,  1.8253,  2.2141],
        [ 0.9318, -0.0566,  0.8561,  ..., -0.1650,  0.7882, -0.2697]],
       device='xla:0')
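
The `!ls` commands above only work inside a notebook. As a rough equivalent for a plain Python script, the following sketch walks the saved directory with `pathlib`; `list_artifacts` is a hypothetical helper, and the path assumes the `save` call above has already run.

```python
from pathlib import Path


def list_artifacts(root):
    """Return every file under root as a sorted list of POSIX-style relative paths.

    Directories themselves are not listed, only the files they contain.
    """
    root = Path(root)
    return sorted(
        p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()
    )


# Assumes stablehlo_program.save('/tmp/stablehlo_dir') has already been called.
save_dir = Path("/tmp/stablehlo_dir")
if save_dir.is_dir():
    for name in list_artifacts(save_dir):
        print(name)
```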


## Export to SavedModel

It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).

PyTorch/XLA makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.

### Install latest TF

The SavedModel format is defined in TensorFlow, so we need to install it as a dependency. We recommend using `tensorflow-cpu` or `tf-nightly`.

In [None]:
!pip install tensorflow-cpu

### Export to SavedModel using `torch_xla.tf_saved_model_integration`

PyTorch/XLA provides a simple API for exporting StableHLO in a SavedModel: `save_torch_module_as_tf_saved_model`. It uses the `torch.export` and `torch_xla.stablehlo.exported_program_to_stablehlo` functions under the hood.

The input to the API is a PyTorch model. We'll use the same `resnet18` from the previous examples.

In [10]:
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
import tensorflow as tf

save_torch_module_as_tf_saved_model(
    resnet18,         # original pytorch torch.nn.Module
    sample_input,     # sample inputs used to trace
    '/tmp/resnet_tf'  # directory for tf.saved_model
)

!ls /tmp/resnet_tf/

assets	fingerprint.pb	saved_model.pb	variables


### Reload and call the SavedModel

Now we can load the SavedModel and call it with the `sample_input` from the earlier example.

_Note: the restored model does *not* require PyTorch or PyTorch/XLA to run, just XLA._

In [12]:
loaded_m = tf.saved_model.load('/tmp/resnet_tf')
print(loaded_m.f(tf.constant(sample_input[0].numpy())))

[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[-0.74313045, -2.595488  , -0.0718156 , ..., -1.4230162 ,
         1.5928375 , -0.7693139 ],
       [ 0.61994374, -1.594082  , -0.901797  , ...,  0.2451565 ,
         0.6159245 ,  2.4764667 ],
       [-3.029084  ,  2.2174084 , -2.2808676 , ..., -0.90810233,
         1.8252805 ,  2.214109  ],
       [ 0.93176216, -0.0566061 ,  0.8560745 , ..., -0.16496754,
         0.7881946 , -0.26973075]], dtype=float32)>]
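
The values printed here line up with the output of the reloaded StableHLO program earlier in the tutorial. One way to make that check explicit is a small numerical comparison. The sketch below assumes `numpy` is available; `max_abs_diff` is a hypothetical helper, and the two rows are the first three values of sample 0 copied from the printed outputs above.

```python
import numpy as np


def max_abs_diff(a, b):
    """Return the largest element-wise absolute difference between two arrays."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.max(np.abs(a - b)))


# First three values of sample 0, copied from the printed outputs in this tutorial.
xla_row = [-0.7431, -2.5955, -0.0718]
tf_row = [-0.74313045, -2.595488, -0.0718156]

# The XLA printout is rounded to four decimals, so allow a small tolerance.
print(max_abs_diff(xla_row, tf_row) < 1e-4)
```

In a live session the full arrays could be passed instead, for example `loaded_m.f(...)[0].numpy()` on the TensorFlow side; the tolerance to use is a judgment call.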


## Common Troubleshooting

Most issues in the PyTorch-to-StableHLO path require filing a GitHub issue as the next step. The teams involved are generally quick to help resolve them.

- Issues in `torch.export`: These need to be resolved in upstream PyTorch.
- Issues in `torch_xla.stablehlo`: Open an issue on the pytorch/xla repository.
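
Issue reports are usually easier to act on when they include the installed package versions. The sketch below uses only the standard library; `installed_versions` is a hypothetical helper, and the package names listed are assumptions based on the `pip install` commands in this tutorial.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


# Package names taken from the pip commands used earlier in this tutorial.
report = installed_versions(["torch", "torch_xla", "torchvision", "tensorflow-cpu"])
for name, version in report.items():
    print(f"{name}: {version}")
```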