In [1]:
!pip3 install -U jax

Collecting jax
  Downloading jax-0.6.2-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.6.2,>=0.6.2 (from jax)
  Downloading jaxlib-0.6.2-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.3 kB)
Collecting ml_dtypes>=0.5.0 (from jax)
  Downloading ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Downloading jax-0.6.2-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m36.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.6.2-cp311-cp311-manylinux2014_x86_64.whl (89.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml_dtypes, jaxlib, jax
  Attempting uni

In [7]:
import torch

## Basic array operations with jax.numpy

### JAX's jax.numpy is nearly identical to NumPy's but runs on accelerators

In [6]:
import jax.numpy as jnp

# creating arrays
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# basic operations
z = x + y
print(z)

[5. 7. 9.]


## Automatic Differentiation with jax.grad
### JAX can compute gradients of Python functions

In [10]:
from jax import grad

# defining a simple function: f(x) = x^2 + 2x + 1

def f(x):
  return x**2 + 2*x + 1

# compute the derivative: f(x) = 2x + 2
df_dx = grad(f)
print(df_dx(3.0)) # 2 * 3 + 2 = 8.0

8.0


## JIT Complilation with jax.jit
### Speed up functions by compiling them

In [11]:
from jax import jit

# JIT-compile the function
f_jit = jit(f)

print(f_jit(3.0))

16.0


## Vectorization with jax.vmap
### Apply a function over array elements without loops


In [12]:
from jax import vmap

x = jnp.array([1.0, 2.0, 3.0])
f_vectorized = vmap(f)
print(f_vectorized(x))


[ 4.  9. 16.]


## Random Numbers
### JAX requires explicit random numbers seeds for reproducibility

In [13]:
from jax import random

# generating a random key
key = random.PRNGKey(42)

# generating random numbers
x = random.normal(key, (3,))
print(x)

[-0.02830462  0.46713185  0.29570296]


## Linear Regressions with Gradient Descent


In [22]:
import jax.numpy as jnp
from jax import grad, jit, random

# Generate synthetic data
key = random.PRNGKey(0)
X = random.normal(key, (100, 1))
y = 2 * X[:, 0] + 1 + 0.1 * random.normal(key, (100,))

# Define model and loss
def model(params, X):
    return params[0] * X + params[1]  # Linear model: slope * X + intercept

def loss(params, X, y):
    predictions = model(params, X)
    return jnp.mean((predictions - y) ** 2)

# Gradient descent
@jit
def update(params, X, y, lr=0.2):  # Increased learning rate
    gradients = grad(loss)(params, X, y)
    return params - lr * gradients

# Initialize and train
params = jnp.array([0.0, 0.0])  # [slope, intercept]
for _ in range(1000):
    params = update(params, X, y)

print(f"Learned params: slope={params[0]:.2f}, intercept={params[1]:.2f}")

Learned params: slope=-0.00, intercept=1.23


In [21]:
print("X sample:", X[:5], "y sample:", y[:5])

X sample: [[ 1.6226422 ]
 [ 2.0252647 ]
 [-0.43359444]
 [-0.07861735]
 [ 0.1760909 ]] y sample: [4.4075484  5.253056   0.08945169 0.8349036  1.3697909 ]
