# Performance of JAX vs PyTorch

Observations:
- JAX with `jit` appears to be faster than PyTorch with `torch.jit.script`.
- For my toy kernel using the `where()` method, JAX is 20-30x faster than PyTorch.
- JAX `jit` can be used for class methods as well (e.g. by marking `self` as a static argument).
- `torch.jit.script` works only for functions and subclasses of `nn.Module` (plus TorchScript-compatible classes).

A useful reference: https://www.kaggle.com/code/grez911/performance-of-jax-vs-pytorch

In [None]:
# Load the memory_profiler IPython extension; it provides the %memit magic used below.
%load_ext memory_profiler

In [None]:
N = 10_000  # length of the 1-D input vectors used in the benchmarks below

## JAX

In [None]:
# JAX reads these settings when it is first imported, so this cell must run
# before `import jax` (next cell).
# Enable 64-bit types (float64/int64) instead of JAX's 32-bit default.
%env JAX_ENABLE_X64=1
# Run JAX on the CPU backend.
# NOTE(review): newer JAX releases prefer JAX_PLATFORMS; confirm for the installed version.
%env JAX_PLATFORM_NAME=cpu

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
def jax_kernel(x):
    """Toy JAX kernel: keep elements above 0.5, square the others, and sum everything."""
    squared = x**2
    selected = jnp.where(x > 0.5, x, squared)
    return jnp.sum(selected)

In [None]:
# Fixed PRNG seed so every run benchmarks the same input; values are uniform in [0, 1).
key = random.PRNGKey(0)
x = random.uniform(key, shape=(N,))

In [None]:
%timeit jax_kernel(x)

In [None]:
# Compile the kernel with jit. The first call traces and compiles it; later
# calls with the same input shape/dtype reuse the compiled version.
jit_jax_kernel = jit(jax_kernel)

# block_until_ready() waits for the asynchronously dispatched result.
%timeit jit_jax_kernel(x).block_until_ready()

In [None]:
# Gradient of the scalar kernel w.r.t. x, run without jit for comparison.
derivative_jax_kernel = grad(jax_kernel)

%timeit derivative_jax_kernel(x).block_until_ready()

In [None]:
# The same gradient, compiled with jit.
derivative_jit_jax_kernel = jit(derivative_jax_kernel)

%timeit derivative_jit_jax_kernel(x).block_until_ready()

In [None]:
# Peak memory of one call to the jitted kernel.
%memit jit_jax_kernel(x)

In [None]:
# Peak memory of one call to the jitted gradient.
%memit derivative_jit_jax_kernel(x)

## Torch

In [None]:
import torch

In [None]:
def torch_kernel(x):
    """Toy PyTorch kernel: keep elements above 0.5, square the others, and sum everything."""
    squared = x**2
    selected = torch.where(x > 0.5, x, squared)
    return torch.sum(selected)

In [None]:
# Uniform samples in [0, 1); requires_grad=True so torch.autograd.grad can
# differentiate the kernels with respect to x below.
x = torch.rand(N, requires_grad=True)

In [None]:
# Eager PyTorch baseline.
%timeit torch_kernel(x)

In [None]:
# Compile the kernel with TorchScript, the PyTorch counterpart of jax.jit here.
jit_torch_kernel = torch.jit.script(torch_kernel)

%timeit jit_torch_kernel(x)

In [None]:
# Forward + backward of the eager kernel; returns a tuple with the gradient w.r.t. x.
%timeit torch.autograd.grad(torch_kernel(x), [x])

In [None]:
# Forward + backward of the scripted kernel.
%timeit torch.autograd.grad(jit_torch_kernel(x), [x])

In [None]:
# Peak memory of one call to the scripted kernel.
%memit jit_torch_kernel(x)

In [None]:
# Peak memory of forward + backward through the scripted kernel.
%memit torch.autograd.grad(jit_torch_kernel(x), [x])

### Further experiments

#### JAX JIT

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
# A 2-D sample for the experiments below, drawn with the PRNG key defined earlier.
x = random.uniform(key, shape=(1000, 3))
x.dtype  # shown as the cell output; depends on whether 64-bit mode is active

In [None]:
def F(x, y):
    """Return ``x.sum() + 2 * y.sum()``; its derivative w.r.t. each element of y is 2."""
    sum_x = x.sum()
    sum_y = y.sum()
    return sum_x + 2 * sum_y

In [None]:
# Gradient with respect to the second argument (y), compiled with jit.
dF = jit(grad(F, argnums=1))

In [None]:
# Expected: an array of 2s with the same shape as the second argument, since F depends on y via 2 * y.sum().
dF(x, x)

In [None]:
@jit
def F(x, y):
    """Return ``x.sum() + 2 * y.sum()``, instrumented with print() to show jit tracing.

    Under jit the Python body runs only while JAX traces the function, so the
    prints show abstract tracer values (no concrete data) and fire when a new
    trace is made (e.g. the first call for a given input shape/dtype), not on
    every call.
    """
    sum_x  = x.sum()
    print(f'{x=}')
    sum_y = 2*y.sum()
    print(f'{sum_y=}')
    # Shapes are static during tracing, so this prints a concrete tuple.
    print(f'{x.shape=}')
    return sum_x + sum_y

In [None]:
# The first call traces F, so its prints fire; calling again with the same shape/dtype reuses the compiled version silently.
F(x, x)

#### TorchScript-scripted function

In [4]:
import torch
from torch import Tensor 

@torch.jit.script
def cutoff_function(r: Tensor, r_cutoff: float) -> Tensor:
    """Tanh^3 cutoff: tanh(1 - r / r_cutoff)**3 where r < r_cutoff, 0 elsewhere."""
    inside = r < r_cutoff
    scaled = 1.0 - r / r_cutoff
    value = torch.tanh(scaled).pow(3)
    return torch.where(inside, value, torch.zeros_like(r))

r = torch.tensor(1.0)
r_cutoff = 2.0
# Evaluate the scripted cutoff at r = r_cutoff / 2; expected tanh(0.5)**3 ≈ 0.0987.
cutoff_function(r, r_cutoff)

tensor(0.0987)

In [2]:
# Uncomment to inspect the TorchScript IR of the scripted function
# (the scripted object in this notebook is `cutoff_function`).
# cutoff_function.graph

In [3]:
# Timing of the scripted cutoff. The recorded result is ~5x slower than the later
# re-run of this cell; presumably TorchScript warm-up/profiling effects — confirm.
%timeit cutoff_function(r, r_cutoff)

25.1 µs ± 2.63 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
# Repeated timing of the same scripted call.
%timeit cutoff_function(r, r_cutoff)

5.03 µs ± 1.25 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


#### Cutoff function

In [28]:
import math
from functools import partial

import jax.numpy as jnp
import numpy as onp
from jax import grad, jacfwd, jit

In [41]:
_LIB = jnp  # array backend; swap for torch to use PyTorch ops (the jit on __call__ is JAX-only)


class CutoffFunction:
    """Radial cutoff function that is nonzero only for ``r < r_cutoff``.

    Each cutoff type is written in terms of the scaled distance
    ``x = r / r_cutoff``; values at or beyond ``r_cutoff`` are exactly 0.

    Args:
        r_cutoff: Cutoff radius; must be nonzero (its inverse is precomputed).
        cutoff_type: One of ``"hard"``, ``"tanhu"``, ``"tanh"``, ``"cos"``,
            ``"exp"``, ``"poly1"``, ``"poly2"`` (case-insensitive).

    Raises:
        NotImplementedError: If ``cutoff_type`` is not one of the supported types.
    """

    # Names that may be selected via ``cutoff_type``; restricting the lookup
    # prevents picking arbitrary attributes such as ``__call__`` or ``__repr__``.
    _CUTOFF_TYPES = ("hard", "tanhu", "tanh", "cos", "exp", "poly1", "poly2")

    # coth(1)**3 == 1 / tanh(1)**3: scales tanh(1 - x)**3 so that f(0) == 1.
    _TANH_PRE = ((math.e + 1 / math.e) / (math.e - 1 / math.e)) ** 3

    def __init__(self, r_cutoff: float, cutoff_type: str = "tanh"):
        self.r_cutoff = r_cutoff
        self.cutoff_type = cutoff_type.lower()
        self.inv_r_cutoff = 1.0 / self.r_cutoff
        if self.cutoff_type not in self._CUTOFF_TYPES:
            raise NotImplementedError(
                f"'{self.__class__.__name__}' has no cutoff function '{self.cutoff_type}'"
            )
        # Bind the selected cutoff method once so __call__ needs no dispatch.
        self.cfn = getattr(self, self.cutoff_type)

    @partial(jit, static_argnums=(0,))
    def __call__(self, r: Tensor) -> Tensor:
        """Evaluate the cutoff function at distance(s) ``r``.

        ``self`` is a static jit argument and is hashed by identity (the class
        defines no ``__hash__``), so the compiled code assumes the instance's
        attributes are not mutated after the first call.
        """
        inside = r < self.r_cutoff
        # Evaluate the cutoff only at safe distances. Beyond the cutoff some
        # types blow up (e.g. "exp" divides by 1 - x**2), and although `where`
        # discards those values, their inf/NaN would still leak into gradients.
        # Substituting r = 0 outside keeps both branches finite.
        r_safe = _LIB.where(inside, r, _LIB.zeros_like(r))
        return _LIB.where(inside, self.cfn(r_safe), _LIB.zeros_like(r))

    def hard(self, r: Tensor) -> Tensor:
        """Step cutoff: 1 everywhere inside the cutoff."""
        return _LIB.ones_like(r)

    def tanhu(self, r: Tensor) -> Tensor:
        """Unnormalized tanh cutoff: tanh(1 - x)**3, so f(0) == tanh(1)**3."""
        return _LIB.tanh(1.0 - r * self.inv_r_cutoff) ** 3

    def tanh(self, r: Tensor) -> Tensor:
        """Normalized tanh cutoff: tanh(1 - x)**3 / tanh(1)**3, so f(0) == 1."""
        return self._TANH_PRE * _LIB.tanh(1.0 - r * self.inv_r_cutoff) ** 3

    def cos(self, r: Tensor) -> Tensor:
        """Cosine cutoff: 0.5 * (cos(pi * x) + 1)."""
        return 0.5 * (_LIB.cos(math.pi * r * self.inv_r_cutoff) + 1.0)

    def exp(self, r: Tensor) -> Tensor:
        """Exponential cutoff: exp(1 - 1 / (1 - x**2))."""
        return _LIB.exp(1.0 - 1.0 / (1.0 - (r * self.inv_r_cutoff) ** 2))

    def poly1(self, r: Tensor) -> Tensor:
        """Cubic polynomial cutoff: (2x - 3) * x**2 + 1, which is 1 at x = 0 and 0 at x = 1."""
        # Scale by r_cutoff like the other types; unscaled, the polynomial
        # only vanishes at r = 1 regardless of the cutoff radius.
        x = r * self.inv_r_cutoff
        return (2.0 * x - 3.0) * x**2 + 1.0

    def poly2(self, r: Tensor) -> Tensor:
        """Quintic polynomial cutoff: ((15 - 6x) * x - 10) * x**3 + 1, which is 1 at x = 0 and 0 at x = 1."""
        x = r * self.inv_r_cutoff
        return ((15.0 - 6.0 * x) * x - 10) * x**3 + 1.0

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(r_cutoff={self.r_cutoff}, cutoff_type='{self.cutoff_type}')"


In [42]:
r_cutoff = 2.0
r = jnp.array(1.0)  # use torch.tensor(1.0) together with the torch backend

cfn = CutoffFunction(r_cutoff)  # default cutoff_type is "tanh"
cfn  # displays the repr as the cell output

CutoffFunction(r_cutoff=2.0, cutoff_type='tanh')

In [39]:
# NOTE(review): JAX dispatches asynchronously; without .block_until_ready() this
# may measure only the dispatch of the jitted call, not the computation.
%timeit cfn(r)

31.9 µs ± 6.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [43]:
# Repeated timing of the jitted call.
# NOTE(review): append .block_until_ready() to time the computation itself.
%timeit cfn(r)

1.83 µs ± 55.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [33]:
# Normalized tanh cutoff at r = r_cutoff / 2.
cfn(r)

DeviceArray(0.2234012, dtype=float32, weak_type=True)

In [35]:
# Closer in (r = r_cutoff / 4): the normalized tanh cutoff grows toward 1 as r -> 0.
cfn(r*0.5)

DeviceArray(0.5800375, dtype=float32, weak_type=True)