In [1]:
import jax.random as jr
import jax.numpy as jnp
import jax

In [2]:
from config import Config
from model import init_params


def params():
    """Build the model parameters from a fixed PRNG key.

    Returns:
        Whatever ``init_params`` produces for ``PRNGKey(42)``.
    """
    # Constant seed: every call hands the same key to init_params.
    return init_params(jr.PRNGKey(42))

def sample_batch():
    """Create a sample batch of data.

    Returns:
        A ``(ctx_bits, labels)`` tuple of integer arrays. ``ctx_bits`` has
        shape ``(4, Config.L)`` and ``labels`` has shape ``(4,)``; both are
        drawn by ``jr.randint`` with ``minval=0`` and
        ``maxval=Config.vocab_size``.
    """
    batch_size = 4
    # A JAX PRNG key must not be consumed twice: reusing it for both draws
    # makes the two outputs come from the same random stream. Split the seed
    # key so contexts and labels are sampled independently.
    ctx_key, label_key = jr.split(jr.PRNGKey(42))
    ctx_bits = jr.randint(ctx_key, (batch_size, Config.L), 0, Config.vocab_size)
    labels = jr.randint(label_key, (batch_size,), 0, Config.vocab_size)
    return ctx_bits, labels

In [3]:
from model import loss_fn


# Inputs for a single gradient evaluation: scalar force weight plus one batch.
force_weight = jnp.array(0.1)
ctx_bits, labels = sample_batch()

# jax.grad differentiates loss_fn with respect to its first argument,
# which here is the value returned by params().
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params(), ctx_bits, labels, force_weight)

# Sanity check: every gradient entry must be free of NaN/Inf values.
for key, grad in grads.items():
    is_finite = jnp.isfinite(grad).all()
    assert is_finite, f"Gradient for {key} contains non-finite values"



I0000 00:00:1765657031.411929 11183230 service.cc:145] XLA service 0x13a6bf020 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1765657031.411946 11183230 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1765657031.413550 11183230 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1765657031.413563 11183230 mps_client.cc:384] XLA backend will use up to 22906109952 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M1 Max

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [None]:
from utils import save_params


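# One-off export calls, kept commented out for reference. Each line writes the
# gradient dict currently in memory to a snapshot under data/; uncomment only
# the line that matches the gradients that were just computed.
# save_params(closure_grads, "data/closure_grads.npz")
# save_params(grads, "data/grads.npz")
# save_params(grads, "data/jitted_grads.npz")
# save_params(grads, "data/jitted_closure_grads.npz")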

In [5]:
from utils import load_params


# Bind each snapshot to the variable its file name describes: the
# "closure" file goes to closure_grads, the other one to grads.
# NOTE(review): this rebinds `grads`, replacing any value computed earlier
# in the session with the on-disk snapshot.
grads = load_params("data/jitted_grads.npz")
closure_grads = load_params("data/jitted_closure_grads.npz")


In [6]:
# Compare grads and closure_grads entry by entry.
#
# Results left in the namespace:
#   all_close    -- True only if both dicts have the same keys, matching
#                   shapes, and every pair passes jnp.allclose.
#   max_diff     -- largest absolute element-wise difference seen.
#   max_diff_key -- key where max_diff occurred (None if no difference > 0).
print("Comparing grads and closure_grads...")
print(f"Keys match: {set(grads.keys()) == set(closure_grads.keys())}")

all_close = True
max_diff = 0.0
max_diff_key = None

for key in grads.keys():
    if key not in closure_grads:
        print(f"Key {key} missing in closure_grads")
        all_close = False
        continue

    grad1 = grads[key]
    grad2 = closure_grads[key]

    # Check shapes match; subtracting mismatched shapes would either fail
    # or silently broadcast.
    if grad1.shape != grad2.shape:
        print(f"Shape mismatch for {key}: {grad1.shape} vs {grad2.shape}")
        all_close = False
        continue

    # Check if they're close. Convert to plain Python scalars so the branch
    # and the running maximum do not depend on array truthiness/formatting.
    close = bool(jnp.allclose(grad1, grad2, rtol=1e-5, atol=1e-7))
    max_abs_diff = float(jnp.max(jnp.abs(grad1 - grad2)))

    if not close:
        print(f"{key}: NOT close (max diff: {max_abs_diff:.2e})")
        all_close = False
    else:
        print(f"{key}: ✓ close (max diff: {max_abs_diff:.2e})")

    if max_abs_diff > max_diff:
        max_diff = max_abs_diff
        max_diff_key = key

# The loop above only walks grads; entries that exist solely in
# closure_grads must also count as a mismatch, otherwise the summary could
# report success while the key sets differ.
for key in closure_grads.keys():
    if key not in grads:
        print(f"Key {key} missing in grads")
        all_close = False

print(f"\nOverall: {'✓ All gradients match!' if all_close else '✗ Some gradients differ'}")
print(f"Maximum difference: {max_diff:.2e} in {max_diff_key}")

Comparing grads and closure_grads...
Keys match: True
W_dec: ✓ close (max diff: 0.00e+00)
a: ✓ close (max diff: 0.00e+00)
b: ✓ close (max diff: 0.00e+00)
b_dec: ✓ close (max diff: 0.00e+00)
c: ✓ close (max diff: 0.00e+00)
xi_attn_embed_raw: ✓ close (max diff: 3.26e-09)
xi_hopf_raw: ✓ close (max diff: 3.64e-12)

Overall: ✓ All gradients match!
Maximum difference: 3.26e-09 in xi_attn_embed_raw
