<a href="https://colab.research.google.com/github/google/jaxonnxruntime/blob/call_torch/docs/experimental_call_torch_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline


# jaxonnxruntime call_torch Tutorial
**Author:** John Zhang


This tutorial introduces the `call_torch` API, which translates PyTorch models into JAX functions. This connects PyTorch models to the broader JAX ecosystem and lets them run on any hardware supported by XLA and OpenXLA (TPU, GPU, and CPU). The result is easier cross-framework collaboration and potential performance gains.




In [None]:
!pip install git+https://github.com/google/jaxonnxruntime.git


In [None]:
!pip install onnx torch torchvision

## Basic Usage

All models are described in one common format. Model parameters and model inputs are both represented as JAX PyTrees, so each can be an arbitrarily nested structure (tuples, lists, dicts) of arrays. `call_torch.call_torch(torch_module, torch_inputs)` returns a pair `(jax_fn, jax_params)`, and the converted model is called as `jax_fn(jax_params, jax_inputs)`.
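To make the PyTree convention concrete, here is a small stand-alone sketch in plain JAX. It does not use `call_torch`. The nested `params` dict, its key names, and the toy `apply_fn` are hypothetical and only mimic the `(params, inputs)` calling convention that `jax_fn` uses later in this tutorial. The real structure of `jax_params` may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical parameter PyTree (nested dicts of arrays). The actual
# structure and key names of `jax_params` from call_torch may differ.
params = {
    "lin": {
        "weight": np.ones((2, 3), dtype=np.float32),
        "bias": np.zeros((2,), dtype=np.float32),
    }
}
# Model inputs form a PyTree as well: here, a tuple holding one array.
inputs = (np.arange(3, dtype=np.float32),)


def apply_fn(params, inputs):
    """Toy model with the same (params, inputs) calling convention as `jax_fn`."""
    (x,) = inputs
    lin = params["lin"]
    return jnp.dot(x, lin["weight"].T) + lin["bias"]


# tree_map visits every leaf and rebuilds the same container structure.
shapes = jax.tree_util.tree_map(lambda a: a.shape, params)
num_leaves = len(jax.tree_util.tree_leaves(params))

# jax.jit accepts PyTree arguments directly.
out = jax.jit(apply_fn)(params, inputs)
print(shapes, num_leaves, out)
```

Because both arguments are PyTrees, utilities such as `jax.tree_util.tree_map` work on them regardless of nesting. This is how the tutorial converts `torch_inputs` to NumPy leaf by leaf.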


In [None]:
import torch
import jax
from jaxonnxruntime.experimental import call_torch

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

torch_inputs = (torch.randn(10, 10), torch.randn(10, 10))
torch_module = torch.jit.trace(foo, torch_inputs)

print("torch_output: ", torch_module(*torch_inputs))


In [None]:
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
# Convert each torch tensor in the input PyTree to a NumPy array.
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
print("jax_output:", jax_fn(jax_params, jax_inputs))
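One way to sanity-check the conversion is to compare against a direct JAX rewrite of `foo`. The sketch below is independent of `call_torch`. `foo_jax` is a hypothetical hand-written counterpart, checked here against a plain NumPy reference. Feeding the same NumPy arrays to `jax_fn` should agree with it up to floating-point tolerance.

```python
import jax
import jax.numpy as jnp
import numpy as np


def foo_jax(x, y):
    # Hand-written JAX counterpart of the traced PyTorch `foo`.
    return jnp.sin(x) + jnp.cos(y)


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 10)).astype(np.float32)
y = rng.standard_normal((10, 10)).astype(np.float32)

reference = np.sin(x) + np.cos(y)  # plain NumPy reference
result = np.asarray(jax.jit(foo_jax)(x, y))
np.testing.assert_allclose(result, reference, rtol=1e-5, atol=1e-5)
print("max abs difference:", np.max(np.abs(result - reference)))
```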

`call_torch` can also take a `torch.nn.Module` instance directly, without tracing it with `torch.jit.trace` first.



In [None]:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

torch_module = MyModule()
torch_inputs = (torch.randn(10, 100), )

In [None]:
torch_module.eval()
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
print("jax_output:", jax_fn(jax_params, jax_inputs))
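For reference, `MyModule` computes `relu(x @ W.T + b)`, where `torch.nn.Linear(100, 10)` stores `W` with shape `(out_features, in_features) = (10, 100)` and `b` with shape `(10,)`. The sketch below writes that computation by hand in JAX. The flat `{"weight", "bias"}` parameter dict is a hypothetical layout, not necessarily the one `call_torch` produces.

```python
import jax
import numpy as np


def my_module_jax(params, inputs):
    # relu(x @ W^T + b), the computation MyModule.forward performs.
    (x,) = inputs
    return jax.nn.relu(x @ params["weight"].T + params["bias"])


rng = np.random.default_rng(0)
# Hypothetical parameter PyTree. torch.nn.Linear(100, 10) stores its weight
# as (out_features, in_features) = (10, 100) and its bias as (10,).
params = {
    "weight": rng.standard_normal((10, 100)).astype(np.float32),
    "bias": rng.standard_normal((10,)).astype(np.float32),
}
inputs = (rng.standard_normal((10, 100)).astype(np.float32),)

out = np.asarray(jax.jit(my_module_jax)(params, inputs))
reference = np.maximum(inputs[0] @ params["weight"].T + params["bias"], 0.0)
print(out.shape)
```

Note the transpose: PyTorch keeps the weight in `(out, in)` order, so a batch of shape `(10, 100)` maps to an output of shape `(10, 10)`.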

## A real model: ResNet-50



In [None]:
import torch
from torchvision.models import resnet50
# Generates a random input batch for the model, where `b` is the
# batch size (NCHW layout: b x 3 x 128 x 128).

def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32),
    )

torch_inputs = generate_data(1)
torch_module = resnet50()
torch_module.eval()
torch_module = torch.jit.trace(torch_module, torch_inputs)
torch_outputs = [torch_module(*torch_inputs)]



In [None]:
from jaxonnxruntime.experimental import call_torch
import jax
jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)
jax_fn = jax.jit(jax_fn)
jax_inputs = jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)
jax_outputs = jax_fn(jax_params, jax_inputs)
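`jax.jit` traces `jax_fn` on its first call and compiles it for the given input shapes and dtypes. Later calls with matching shapes reuse the compiled program, while a new batch size (for example, inputs from `generate_data(8)`) triggers another trace and compilation. The toy function `f` below is a stand-in, not the converted ResNet-50. It records each trace so this caching behavior is visible.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # filled in only while JAX is tracing


@jax.jit
def f(params, x):
    # Python side effects run during tracing, not on cached calls.
    traced_shapes.append(x.shape)
    return jnp.tanh(x @ params["w"])


params = {"w": jnp.ones((4, 2))}
f(params, jnp.ones((1, 4)))  # first call: trace + compile for shape (1, 4)
f(params, jnp.ones((1, 4)))  # same shapes/dtypes: cached executable reused
f(params, jnp.ones((8, 4)))  # new batch size: traced and compiled again
print(traced_shapes)  # [(1, 4), (8, 4)]
```

The first `jax_fn(jax_params, jax_inputs)` call in the cell above therefore already includes compilation, so later timings measure only execution.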

In [None]:
from jaxonnxruntime.experimental.call_torch import CallTorchTestCase
test_case = CallTorchTestCase()
test_case.assert_allclose(jax.tree_util.tree_map(call_torch.torch_tensor_to_np_array, torch_outputs), jax_outputs, rtol=1e-07, atol=1e-03)
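The tolerances follow NumPy's `assert_allclose` rule: each element must satisfy `|actual - expected| <= atol + rtol * |expected|`. With `atol=1e-03`, the absolute term dominates for typical logit magnitudes. The sketch below shows a hypothetical stand-alone helper, `assert_tree_allclose`, that applies this rule leaf by leaf across two PyTrees. It is not the implementation of `CallTorchTestCase`.

```python
import jax
import numpy as np


def assert_tree_allclose(expected, actual, rtol=1e-7, atol=1e-3):
    """Hypothetical helper: leaf-by-leaf allclose check over two PyTrees.

    Per element, numpy requires |actual - expected| <= atol + rtol * |expected|.
    Only the leaves are compared, so a list of outputs can be checked
    against a tuple of outputs.
    """
    expected_leaves = jax.tree_util.tree_leaves(expected)
    actual_leaves = jax.tree_util.tree_leaves(actual)
    assert len(expected_leaves) == len(actual_leaves), "different number of leaves"
    for e, a in zip(expected_leaves, actual_leaves):
        np.testing.assert_allclose(np.asarray(a), np.asarray(e), rtol=rtol, atol=atol)


# Passes: the 5e-4 difference is within atol=1e-3.
assert_tree_allclose([np.array([1.0, 2.0])], (np.array([1.0005, 2.0]),))
```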


In [None]:
%timeit _ = torch_module(*torch_inputs)

In [None]:
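# JAX dispatches work asynchronously; wait for the result so the timing covers the computation.
%timeit jax.block_until_ready(jax_fn(jax_params, jax_inputs))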
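JAX dispatches work asynchronously. A jitted call can return as soon as the computation is enqueued, so timing the call alone may report only dispatch overhead, especially on GPU or TPU. The sketch below is a hypothetical `benchmark` helper, not part of jaxonnxruntime. It warms up once to exclude compilation, then waits on every result with `jax.block_until_ready`. The workload `g` is a stand-in; in this notebook one would pass `jax_fn, jax_params, jax_inputs` instead.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_iters=20):
    """Mean wall-clock seconds per call, waiting for device results each time."""
    jax.block_until_ready(fn(*args))  # warm-up: excludes compilation from the timing
    start = time.perf_counter()
    for _ in range(n_iters):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iters


# Stand-in workload; replace with (jax_fn, jax_params, jax_inputs) in this notebook.
g = jax.jit(lambda x: jnp.sin(x) @ jnp.cos(x).T)
x = jnp.ones((256, 256))
seconds = benchmark(g, x)
print(f"{seconds * 1e3:.3f} ms per call")
```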
