In [None]:
# GPU/CPU variant: try installing jax (may work on GPU on Colab)
!pip -q install --upgrade pip
!pip -q install jax jaxlib  # may install CPU-only; if GPU wheel needed, see JAX install docs


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m1.7/1.8 MB[0m [31m51.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m36.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import time

import jax
import jax.numpy as jnp

# Show which JAX build is installed and which devices it can see
# (e.g. whether a GPU backend is active).
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")


JAX version: 0.7.2
Devices: [CudaDevice(id=0)]


In [None]:
import numpy as np
# Notebook-wide RNG seed: every random draw in this notebook is derived from
# `key`, so changing SEED here re-draws all random data reproducibly.
SEED = 0
key = jax.random.PRNGKey(SEED)

def init_params(in_dim, out_dim, key, scale=0.02):
    """Initialise the parameters of a single dense (affine) layer.

    Args:
        in_dim: Number of input features (rows of ``w``).
        out_dim: Number of output features (columns of ``w``, length of ``b``).
        key: JAX PRNG key used to draw the weights.
        scale: Multiplier applied to standard-normal draws, i.e. the standard
            deviation of the weight initialisation. Defaults to 0.02.

    Returns:
        dict with ``'w'`` of shape ``(in_dim, out_dim)`` drawn from
        ``N(0, scale**2)`` and ``'b'``, a zero vector of shape ``(out_dim,)``.
    """
    return {
        'w': jax.random.normal(key, (in_dim, out_dim)) * scale,
        'b': jnp.zeros((out_dim,)),
    }

def forward(params, x):
    """Apply the affine layer ``x · w + b`` using ``params['w']`` and ``params['b']``."""
    w, b = params['w'], params['b']
    return jnp.dot(x, w) + b

@jax.jit
def step(params, x, y):
    """Compute the mean-squared-error loss of ``forward`` and its gradient.

    Despite the name, no parameter update is applied: the function only
    returns ``(loss, grads)``, where ``grads`` is the gradient with respect to
    ``params`` and has the same pytree structure as ``params``.
    """
    def mse(p):
        residual = forward(p, x) - y
        return jnp.mean(residual**2)

    return jax.value_and_grad(mse)(params)

# Data: draw x, y and the weights from independent subkeys. Reusing one key
# would correlate them; e.g. with the same key and the same (128, 10) shape,
# the weights drawn in init_params would be exactly 0.02 * y.
key_x, key_y, key_p = jax.random.split(key, 3)
x = jax.random.normal(key_x, (128, 128))
y = jax.random.normal(key_y, (128, 10))
params = init_params(128, 10, key_p)

# Warmup & measure.
# JAX dispatches device work asynchronously: a call can return before the
# computation has finished. Block on the results before reading the clock,
# otherwise only Python dispatch overhead is measured.
N_STEPS = 50
jax.block_until_ready(step(params, x, y))  # first call traces + compiles `step`
t0 = time.perf_counter()
for _ in range(N_STEPS):
    l, g = step(params, x, y)
jax.block_until_ready((l, g))  # wait for the last queued step to finish
t1 = time.perf_counter()
print("JIT time per step (ms):", (t1 - t0) / N_STEPS * 1000)


JIT time per step (ms): 0.2042865753173828


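Above is Project 4: a JAX `jit` micro-benchmark that times the loss-and-gradient computation of a single linear layer.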