In [49]:
import sys
# Make the sibling `berries` directory importable. The path is relative to the
# kernel's working directory, so this assumes the notebook runs from its own folder.
sys.path.append("../berries")

import nnn
import importlib
# Re-execute nnn so edits to its source are picked up without restarting the kernel.
importlib.reload(nnn)

# These `from` imports come after the reload so they bind to the freshly reloaded module.
from nnn import rmsnorm, EPS
from random_utils import infinite_safe_keys

from jax.numpy import ones, mean, square, exp
from jax.lax import rsqrt
# NOTE(review): `np` is an alias for jax.numpy here, not NumPy.
import jax.numpy as np
from jax.random import normal

# Seed for the project key generator; later cells draw keys with `next(key_gen)`.
seed = 0
key_gen = infinite_safe_keys(seed)

In [50]:
def rmsnorm_o(w, x, eps=EPS):
    """Reference RMSNorm: divide ``x`` by its root-mean-square, then scale by ``w``.

    Computes ``x * rsqrt(mean(x**2) + eps) * w``. The mean square is taken over
    *all* elements of ``x`` (no axis argument), and ``eps`` sits inside the
    reciprocal square root so an all-zero ``x`` does not divide by zero.

    Args:
        w: Gain multiplied element-wise into the normalised input.
        x: Input array.
        eps: Stabiliser added to the mean square. Defaults to ``EPS`` imported
            from ``nnn`` so this reference matches the library's constant.

    Returns:
        The normalised and scaled array, same shape as ``x * w``.
    """
    x_norm = x * rsqrt(mean(square(x)) + eps)
    return x_norm * w

n_dim = 10
z = normal(next(key_gen).get(), shape=(n_dim,))
w = ones(n_dim)
print(z.sum())

# Element-wise outputs, one per line: library `rmsnorm` first, local `rmsnorm_o` second.
print(*(impl(w, z) for impl in (rmsnorm, rmsnorm_o)), sep="\n")

# Sum of squares of each output. With w = 1, rmsnorm_o rescales z to a mean
# square of mean(z**2) / (mean(z**2) + eps) ~= 1, so its sum should be ~= n_dim.
# NOTE(review): the recorded output shows 0.1 for the library version vs 10.0 for
# rmsnorm_o (each element 10x smaller, 10 == n_dim here) -- check nnn.rmsnorm.
print(*((impl(w, z) ** 2).sum() for impl in (rmsnorm, rmsnorm_o)), sep="\n")


-6.314905
[-0.03467153 -0.00400841 -0.18247792  0.00708598  0.02979168  0.07110694
 -0.12874305 -0.15123652 -0.03338213 -0.13757929]
[-0.34671533 -0.0400841  -1.8247792   0.0708598   0.2979168   0.7110694
 -1.2874305  -1.5123651  -0.33382133 -1.375793  ]
0.10000001
10.000001


In [57]:
import jax
import jax.numpy as jnp
from jax import random, grad
import time
linalg = jnp.linalg

def rmsnorm_original(x, weight, eps=1e-8):
    """RMSNorm via ``rsqrt``: ``x * weight / sqrt(mean(x**2) + eps)``.

    The mean square is taken over every element of ``x`` (no axis argument),
    and ``eps`` is added inside the square root.
    """
    inv_rms = jax.lax.rsqrt(jnp.mean(x ** 2) + eps)
    return x * weight * inv_rms

def rmsnorm_linalg(x, weight, eps=1e-8):
    """RMSNorm via the L2 norm: ``weight * x / (||x|| / sqrt(x.size) + eps)``.

    Unlike ``rmsnorm_original``, ``eps`` is added to the RMS itself rather than
    inside the square root, so the two differ slightly for near-zero inputs.

    NOTE(review): the gradient of ``jnp.linalg.norm`` at an all-zero ``x`` is
    likely NaN -- confirm before relying on this variant for zero inputs.
    """
    rms = linalg.norm(x) / jnp.sqrt(x.size)
    scaled = weight * x
    return scaled / (rms + eps)

# Benchmark function
def benchmark(func, x, weight, n_runs=100000):
    """Return the mean seconds per call of ``jax.jit(func)(x, weight)``.

    The function is compiled and executed once, and waited on, before timing.
    That keeps both compilation and the warm-up execution out of the measurement.
    Every timed call is synchronised with ``block_until_ready`` so the figure
    includes device execution, not just asynchronous dispatch.

    Args:
        func: Callable ``func(x, weight)`` returning a JAX array.
        x, weight: Arguments passed positionally to ``func`` on every call.
        n_runs: Number of timed calls to average over.

    Returns:
        Average elapsed seconds per call, as a float.
    """
    run = jax.jit(func)

    # Warm-up: triggers tracing/compilation. Block so the (async) execution of
    # this call is not charged to the first timed iteration.
    run(x, weight).block_until_ready()

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time, which can jump and is coarse on some platforms.
    start = time.perf_counter()
    for _ in range(n_runs):
        run(x, weight).block_until_ready()
    elapsed = time.perf_counter() - start
    return elapsed / n_runs

# Stability test function
def stability_test(func, x, weight, n_runs=10000, scale=1e6):
    """Standard deviation of input-gradients of ``func`` over many large random inputs.

    For each ``i`` in ``range(n_runs)`` it does the following:
      * draws a fresh input as ``random.normal(random.PRNGKey(i), x.shape) * scale``;
      * computes the gradient of ``sum(func(input, weight))`` with respect to that input.
    The result is the standard deviation over every entry of every gradient.

    Only ``x.shape`` is used; the values in ``x`` are ignored. The run index is
    the seed, so the result is deterministic for fixed arguments.

    NOTE(review): this measures the spread of gradients across *different*
    random inputs, not floating-point error for a fixed input, so it is at best
    an indirect stability signal.

    Args:
        func: Callable ``func(x, weight)`` returning a JAX array.
        x: Template array; only its shape is used.
        weight: Passed unchanged to ``func``.
        n_runs: Number of random inputs; must be >= 1.
        scale: Multiplier on the standard-normal samples (default 1e6, i.e.
            deliberately large inputs).

    Returns:
        Scalar JAX array: std of all stacked gradients.

    Raises:
        ValueError: if ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # grad differentiates w.r.t. the first argument: the input, not the weight.
    grad_func = jax.jit(grad(lambda inp, w: jnp.sum(func(inp, w))))
    shape = x.shape
    grads = [
        grad_func(random.normal(random.PRNGKey(seed), shape) * scale, weight)
        for seed in range(n_runs)
    ]
    return jnp.std(jnp.stack(grads))

# Run tests
d_model = 512
# Use two independent keys: sampling twice from the same key returns the same
# values, so the product would just be an element-wise square (all entries >= 0)
# rather than a product of two independent normals.
key_a, key_b = random.split(random.PRNGKey(0))
x = random.normal(key_a, (d_model,)) * random.normal(key_b, (d_model,))
weight = jnp.ones((d_model,))

print("Efficiency (average runtime in seconds):")
print(f"Original: {benchmark(rmsnorm_original, x, weight):.6f}")
print(f"Linalg:   {benchmark(rmsnorm_linalg, x, weight):.6f}")

print("\nNumerical Stability (standard deviation of gradients):")
print(f"Original: {stability_test(rmsnorm_original, x, weight):.6e}")
print(f"Linalg:   {stability_test(rmsnorm_linalg, x, weight):.6e}")

Efficiency (average runtime in seconds):
Original: 0.000052
Linalg:   0.000054

Numerical Stability (standard deviation of gradients):
Original: 5.436819e-08
Linalg:   5.436819e-08
