<a href="https://colab.research.google.com/github/google/jaxonnxruntime/blob/call_torch/docs/experimental_call_torch_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
# NOTE(review): none of the cells shown in this tutorial draw a plot, so this
# magic looks unnecessary (recent IPython renders matplotlib inline by default).
%matplotlib inline


# jaxonnxruntime call_torch Tutorial
**Author:** John Zhang


Here we introduce the `call_torch` API, which seamlessly translates PyTorch models into JAX functions. This integration connects PyTorch with the extensive JAX software ecosystem and lets PyTorch models be compiled by XLA (OpenXLA) for TPU, GPU, and CPU, enhancing cross-framework collaboration and performance potential.




In [2]:
!pip install git+https://github.com/google/jaxonnxruntime.git


Collecting git+https://github.com/google/jaxonnxruntime.git
  Cloning https://github.com/google/jaxonnxruntime.git to /tmp/pip-req-build-yj76dq13
  Running command git clone --filter=blob:none --quiet https://github.com/google/jaxonnxruntime.git /tmp/pip-req-build-yj76dq13
  Resolved https://github.com/google/jaxonnxruntime.git to commit 2b95bac67150f865222a4bda87b66c099b64ae59
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: jaxonnxruntime
  Building wheel for jaxonnxruntime (pyproject.toml) ... [?25l[?25hdone
  Created wheel for jaxonnxruntime: filename=jaxonnxruntime-0.3.0-py3-none-any.whl size=177972 sha256=caba7590e9938cbca5365c19348155ab3d42c130fc1b19c45953963089c7066f
  Stored in directory: /tmp/pip-ephem-wheel-cache-jf3xlcq6/wheels/43/d3/a6/40189cf2b24db631a69157

In [3]:
!pip install onnx torch

Collecting onnx
  Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.14.0


## Basic Usage

Broadly, `call_torch` describes every model in a standardized format: model parameters and model inputs of any type are represented as JAX PyTrees. `call_torch.call_torch(module, inputs)` returns a JAX function `jax_fn` together with its parameters `jax_params`, and the function is then called as `jax_fn(jax_params, jax_inputs)`.


In [4]:
import torch
import jax
from jaxonnxruntime.experimental import call_torch

def foo(x, y):
    """Return ``sin(x) + cos(y)`` computed on two tensors (the toy function to trace)."""
    return torch.sin(x) + torch.cos(y)

# Seed the RNG so that re-running the notebook reproduces the same inputs.
torch.manual_seed(0)
torch_inputs = (torch.randn(10, 10), torch.randn(10, 10))
# Trace `foo` into a TorchScript module; this is what is passed to call_torch below.
torch_module = torch.jit.trace(foo, torch_inputs)

print("torch_output: ", torch_module(*torch_inputs))


torch_output:  tensor([[ 9.2530e-01,  4.5873e-01,  6.4589e-01,  5.6415e-01,  1.6490e+00,
          8.2248e-01,  8.9589e-01, -5.7184e-01,  1.0232e+00,  5.1742e-01],
        [ 9.5014e-02,  1.3660e+00,  1.3621e+00,  9.0590e-01,  8.4838e-02,
          1.2079e-01,  1.7469e+00,  1.1094e+00,  1.1676e+00,  1.6560e+00],
        [ 1.2260e+00,  1.7275e+00,  1.2192e+00, -9.5381e-01,  1.1586e+00,
          1.3536e-01, -8.1443e-01,  3.7343e-01,  1.5365e+00,  1.2673e+00],
        [ 1.9366e+00,  4.1828e-01,  7.5243e-01, -2.6371e-01, -1.1587e-03,
          1.8683e+00,  7.5635e-01, -6.5726e-01,  1.7267e+00,  1.4934e+00],
        [ 1.7448e-01,  1.2264e+00,  1.5650e+00, -1.1248e-01,  8.0965e-01,
          6.4813e-01,  2.7031e-01, -2.6631e-01,  4.7319e-02,  3.2769e-02],
        [ 4.5132e-01,  1.9266e+00,  1.6291e+00,  4.6194e-01, -1.2171e-01,
          1.8986e+00,  1.3777e-01,  4.9093e-01, -3.5940e-01,  8.6310e-01],
        [ 4.1073e-01,  7.4641e-01,  8.6454e-02,  8.2659e-01,  1.5467e+00,
          9.5625e

In [5]:
# Convert the traced PyTorch module into a JAX function plus its parameters.
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
# `jax.tree_util.tree_map` replaces the deprecated (and since removed) `jax.tree_map`.
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
print("jax_output:", jax_fn(jax_params, jax_inputs))



verbose: False, log level: Level.ERROR

jax_output: [DeviceArray([[ 9.2529660e-01,  4.5873073e-01,  6.4588535e-01,
               5.6415093e-01,  1.6489711e+00,  8.2248122e-01,
               8.9588922e-01, -5.7184368e-01,  1.0232372e+00,
               5.1741624e-01],
             [ 9.5014393e-02,  1.3659939e+00,  1.3621219e+00,
               9.0589929e-01,  8.4837854e-02,  1.2079075e-01,
               1.7469316e+00,  1.1094404e+00,  1.1675845e+00,
               1.6560377e+00],
             [ 1.2259831e+00,  1.7274783e+00,  1.2191784e+00,
              -9.5381111e-01,  1.1585734e+00,  1.3535953e-01,
              -8.1442708e-01,  3.7343296e-01,  1.5364575e+00,
               1.2673373e+00],
             [ 1.9366202e+00,  4.1827971e-01,  7.5243413e-01,
              -2.6371196e-01, -1.1587143e-03,  1.8683007e+00,
               7.5634706e-01, -6.5725970e-01,  1.7267224e+00,
               1.4934300e+00],
             [ 1.7448190e-01,  1.2263532e+00,  1.5649725e+00,
              -1.

`call_torch` can also take a `torch.nn.Module` directly, without tracing it first with `torch.jit.trace`.



In [6]:
class MyModule(torch.nn.Module):
    """Tiny example network: a 100 -> 10 linear layer followed by a ReLU."""

    def __init__(self):
        super(MyModule, self).__init__()
        self.lin = torch.nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        hidden = self.lin(x)
        return torch.relu(hidden)

# Seed so the randomly initialised weights and the inputs are reproducible.
torch.manual_seed(0)
torch_module = MyModule()
torch_inputs = (torch.randn(10, 100),)

In [7]:
torch_module.eval()
# The nn.Module is passed to call_torch directly, without torch.jit.trace.
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
# `jax.tree_util.tree_map` replaces the deprecated (and since removed) `jax.tree_map`.
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
print("jax_output:", jax_fn(jax_params, jax_inputs))

verbose: False, log level: Level.ERROR

jax_output: [DeviceArray([[0.12982486, 0.481035  , 0.3622095 , 0.7990376 , 0.607528  ,
              0.        , 0.        , 0.2921514 , 0.56446004, 0.        ],
             [0.        , 0.57425183, 0.        , 0.9224024 , 0.        ,
              0.39311224, 0.11618385, 0.        , 0.6319629 , 0.29966408],
             [0.49294975, 0.36862767, 0.2809724 , 0.        , 0.        ,
              0.33000746, 0.4940635 , 0.02182353, 0.        , 0.        ],
             [0.        , 0.3759548 , 0.        , 0.23886062, 0.        ,
              0.        , 0.35155773, 0.        , 0.24100977, 0.15047333],
             [0.40354684, 0.        , 0.12827398, 0.        , 0.07742476,
              0.6966675 , 0.        , 0.00607699, 0.        , 0.59710497],
             [0.        , 0.5936287 , 0.        , 0.        , 0.        ,
              0.5815142 , 0.2761725 , 0.47168115, 0.        , 0.26667207],
             [0.06918865, 0.        , 0.43948644, 0. 

## A real test model: ResNet-50



In [8]:
import torch
from torchvision.models import resnet50
def generate_data(b, image_size=128):
    """Generate a random input batch for the image model.

    Args:
      b: Batch size.
      image_size: Height and width of each square image (default 128).

    Returns:
      A 1-tuple ``(images,)``, where ``images`` is a float32 tensor of shape
      ``(b, 3, image_size, image_size)`` drawn from a standard normal
      distribution. Only model inputs are produced; no targets.
    """
    return (
        torch.randn(b, 3, image_size, image_size).to(torch.float32),
    )

# Seed so the input batch and the randomly initialised ResNet-50 weights are reproducible.
torch.manual_seed(0)
torch_inputs = generate_data(1)
torch_module = resnet50()
torch_module.eval()
torch_module = torch.jit.trace(torch_module, torch_inputs)
# Only the forward pass is needed for the reference output, so skip autograd bookkeeping.
with torch.no_grad():
    torch_outputs = [torch_module(*torch_inputs)]



In [9]:
from jaxonnxruntime.experimental import call_torch
import jax
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
# Compile with XLA; the first call below triggers the compilation.
jax_fn = jax.jit(jax_fn)
# `jax.tree_util.tree_map` replaces the deprecated (and since removed) `jax.tree_map`.
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
jax_outputs = jax_fn(jax_params, jax_inputs)



verbose: False, log level: Level.ERROR



In [10]:
from jaxonnxruntime.experimental.call_torch import CallTorchTestCase
test_case = CallTorchTestCase()
# Check that the PyTorch reference outputs and the JAX outputs agree within tolerance.
test_case.assert_allclose(
    jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_outputs),
    jax_outputs,
    rtol=1e-07,
    atol=1e-03,
)


In [11]:
%timeit _ = torch_module(*torch_inputs)

93.7 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
%timeit _ = jax_fn(jax_params, jax_inputs)


258 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
