<a href="https://colab.research.google.com/github/hsudhakaran/test_jax/blob/main/Jax_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Tue Nov 25 20:30:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import jax.numpy as jnp
import jax

In [3]:
def f(x):
    return 4*x**3 + 3*x**2 + 2*x + 1
jax.make_jaxpr(f)(2.0)

{ [34;1mlambda [39;22m; a[35m:f32[][39m. [34;1mlet
    [39;22mb[35m:f32[][39m = integer_pow[y=3] a
    c[35m:f32[][39m = mul 4.0:f32[] b
    d[35m:f32[][39m = integer_pow[y=2] a
    e[35m:f32[][39m = mul 3.0:f32[] d
    f[35m:f32[][39m = add c e
    g[35m:f32[][39m = mul 2.0:f32[] a
    h[35m:f32[][39m = add f g
    i[35m:f32[][39m = add h 1.0:f32[]
  [34;1min [39;22m(i,) }

In [4]:
grad_f = jax.grad(f)
jax.make_jaxpr(grad_f)(2.0)

{ [34;1mlambda [39;22m; a[35m:f32[][39m. [34;1mlet
    [39;22mb[35m:f32[][39m = integer_pow[y=3] a
    c[35m:f32[][39m = integer_pow[y=2] a
    d[35m:f32[][39m = mul 3.0:f32[] c
    e[35m:f32[][39m = mul 4.0:f32[] b
    f[35m:f32[][39m = integer_pow[y=2] a
    g[35m:f32[][39m = integer_pow[y=1] a
    h[35m:f32[][39m = mul 2.0:f32[] g
    i[35m:f32[][39m = mul 3.0:f32[] f
    j[35m:f32[][39m = add e i
    k[35m:f32[][39m = mul 2.0:f32[] a
    l[35m:f32[][39m = add j k
    _[35m:f32[][39m = add l 1.0:f32[]
    m[35m:f32[][39m = mul 2.0:f32[] 1.0:f32[]
    n[35m:f32[][39m = mul 3.0:f32[] 1.0:f32[]
    o[35m:f32[][39m = mul n h
    p[35m:f32[][39m = add_any m o
    q[35m:f32[][39m = mul 4.0:f32[] 1.0:f32[]
    r[35m:f32[][39m = mul q d
    s[35m:f32[][39m = add_any p r
  [34;1min [39;22m(s,) }

In [5]:
jax.grad(f)(2.0)

Array(62., dtype=float32, weak_type=True)

In [6]:
def matrix_mul(a, b):
    return jnp.matmul(a, b)
key = jax.random.PRNGKey(42)
a = jax.random.normal(key, shape=(1000, 5000))
b = jax.random.normal(key, shape=(5000, 1000))
jax.make_jaxpr(matrix_mul)(a, b)

{ [34;1mlambda [39;22m; a[35m:f32[1000,5000][39m b[35m:f32[5000,1000][39m. [34;1mlet
    [39;22mc[35m:f32[1000,1000][39m = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a b
  [34;1min [39;22m(c,) }

In [7]:
# Normal computation
%timeit -n5 matrix_mul(a, b).block_until_ready()

The slowest run took 35.74 times longer than the fastest. This could mean that an intermediate result is being cached.
14.3 ms ± 28.1 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [8]:
jit_matrix_mul = jax.jit(matrix_mul)
jax.make_jaxpr(jit_matrix_mul)(a, b)

{ [34;1mlambda [39;22m; a[35m:f32[1000,5000][39m b[35m:f32[5000,1000][39m. [34;1mlet
    [39;22mc[35m:f32[1000,1000][39m = jit[
      name=matrix_mul
      jaxpr={ [34;1mlambda [39;22m; a[35m:f32[1000,5000][39m b[35m:f32[5000,1000][39m. [34;1mlet
          [39;22mc[35m:f32[1000,1000][39m = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] a b
        [34;1min [39;22m(c,) }
    ] a b
  [34;1min [39;22m(c,) }

In [9]:
# warmup
warmup_results = jit_matrix_mul(a, b)
# ⚡️ speed em up!
%timeit -n5 jit_matrix_mul(a, b).block_until_ready()

2.12 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [10]:
def f_def(x):
    return x*x

g_def = jax.vmap(f_def)
x_test = jnp.array([2,4,6])
%timeit -n5 g_def(a).block_until_ready()
jitted_g = jax.jit(g_def)
jitted_g(b)
%timeit -n5 jitted_g(a).block_until_ready()

The slowest run took 12.04 times longer than the fastest. This could mean that an intermediate result is being cached.
1.97 ms ± 2.86 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
The slowest run took 24.17 times longer than the fastest. This could mean that an intermediate result is being cached.
1.37 ms ± 2.56 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [11]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
jax.devices()

RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

In [12]:
from jax.experimental import mesh_utils
mesh = mesh_utils.create_device_mesh((2, 2))
mesh

ValueError: Number of devices 1 must equal the product of mesh_shape (2, 2)

In [13]:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 8))
x

Array([[ 1.6226422 ,  2.0252647 , -0.43359444, -0.07861735,  0.1760909 ,
        -0.97208923, -0.49529874,  0.4943786 ],
       [ 0.6643493 , -0.9501635 ,  2.1795304 , -1.9551506 ,  0.35857072,
         0.15779513,  1.2770847 ,  1.5104648 ],
       [ 0.970656  ,  0.59960806,  0.0247007 , -1.9164772 , -1.8593491 ,
         1.728144  ,  0.04719035,  0.814128  ],
       [ 0.13132767,  0.28284705,  1.2435943 ,  0.6902801 , -0.80073744,
        -0.74099   , -1.5388287 ,  0.30269185],
       [-0.02071605,  0.11328721, -0.2206547 ,  0.07052256,  0.8532958 ,
        -0.8217738 , -0.01461421, -0.15046217],
       [-0.9001352 , -0.7590727 ,  0.33309513,  0.80924904,  0.04269255,
        -0.57767123, -0.41439894, -1.9412533 ],
       [ 1.3161184 ,  0.7542728 ,  0.16170931, -0.03483307, -1.3306409 ,
         0.39362028,  0.48259583,  0.80382955],
       [-0.6337168 ,  1.038756  , -0.74159133, -0.4299588 , -0.22510043,
        -0.51966715, -1.6692165 ,  0.67535436]], dtype=float32)