In [13]:
import jax
import jax.numpy as jnp
import jax.random as jr
import opt_einsum as oe

from einshape import jax_einshape as einshape

# Short alias for the JAX array type, used in the function signatures' type hints.
Array = jax.Array

In [14]:
def block(x):
    """Wait for every array in pytree ``x`` to finish computing, then return ``x``.

    JAX dispatches work asynchronously, so timings (e.g. ``%timeit``) must
    block on the results. ``jax.block_until_ready`` walks the whole pytree,
    and leaves without a ``block_until_ready`` method pass through unchanged.
    """
    return jax.block_until_ready(x)

In [15]:
# Problem sizes: L holds b blocks of shape (m, m); R holds m blocks of shape (b, b).
b = m = 256
# Q/K rows are split as "(mb)d -> mbd" in monarch_linear, so n must be m * b
# (writing b * b only happens to work while b == m).
n = m * b
d = 64

L = jr.normal(jr.key(0), (b, m, m))  # indexed "jlk" in the contractions
R = jr.normal(jr.key(1), (m, b, b))  # indexed "kji" in the contractions

Q = jr.normal(jr.key(2), (n, d))
K = jr.normal(jr.key(3), (n, d))

In [16]:
def monarch_quadratic_2(L: Array, R: Array):
    """Unscaled Monarch quadratic, computed via per-block Gram matrices.

    Returns ``sum_{j,k} (L_j^T L_j)[k, k] * (R_k R_k^T)[j, j]``. This is the
    same sum as ``monarch_quadratic`` but without its 1/2 factor, so the result
    is twice that function's value. It appears to serve as a cross-check.
    The full Gram matrices are formed only to read their diagonals.
    """
    batched_diag = jax.vmap(jnp.diag)
    gram_diag_L = batched_diag(L.mT @ L)  # for 3-D L: [j, k] = sum_l L[j, l, k]**2
    gram_diag_R = batched_diag(R @ R.mT)  # for 3-D R: [k, j] = sum_i R[k, j, i]**2
    return (gram_diag_L * gram_diag_R.T).sum()

def monarch_quadratic(L: Array, R: Array):
    """Half the Monarch quadratic term: ``1/2 * sum_{j,k} Lhat[j, k] * Rhat[k, j]``.

    ``Lhat`` holds the squared column norms of each block of ``L`` (summed over
    axis -2). ``Rhat`` holds the squared row norms of each block of ``R``
    (summed over axis -1). The 1/2 factor makes the gradient (see
    ``grad_monarch_quadratic``) free of a factor of 2.
    """
    sq_cols_L = jnp.square(L).sum(axis=-2)
    sq_rows_R = jnp.square(R).sum(axis=-1)
    return 0.5 * (sq_cols_L * sq_rows_R.T).sum()

def grad_monarch_quadratic(L: Array, R: Array):
    """Hand-derived gradient of ``monarch_quadratic`` with respect to ``(L, R)``.

    With ``Lhat[j, k] = sum_l L[j, l, k]**2`` and ``Rhat[k, j] = sum_i R[k, j, i]**2``,
    the objective is ``1/2 * sum_{j,k} Lhat[j, k] * Rhat[k, j]``. Differentiating
    the squares cancels the 1/2, giving::

        dL[j, l, k] = Rhat[k, j] * L[j, l, k]
        dR[k, j, i] = Lhat[j, k] * R[k, j, i]

    Returns ``(dL, dR)`` with the same shapes as ``(L, R)``.
    """
    sq_cols_L = jnp.square(L).sum(axis=-2)
    sq_rows_R = jnp.square(R).sum(axis=-1)
    grads = (
        oe.contract("kj,jlk->jlk", sq_rows_R, L),
        oe.contract("jk,kji->kji", sq_cols_L, R),
    )
    return grads

# monarch_quadratic_2 omits the 1/2 factor, so it should be twice
# monarch_quadratic (up to float32 rounding from the different summation order).
quad = monarch_quadratic(L, R)
quad_2 = monarch_quadratic_2(L, R)
print(quad)
print(quad_2)
assert jnp.allclose(quad_2, 2 * quad, rtol=1e-3)

# Check the hand-written gradient against autodiff. The error is measured
# relative to the largest gradient entry, and the tolerance is deliberately loose
# because the two paths round differently. A wrong index would give O(1) error.
quad_grad_auto = jax.jit(jax.grad(monarch_quadratic, argnums=(0, 1)))
quad_grad_manual = jax.jit(grad_monarch_quadratic)
g_auto = block(quad_grad_auto(L, R))
g_manual = block(quad_grad_manual(L, R))
rel_err = jax.tree.map(
    lambda a, c: jnp.max(jnp.abs(a - c)) / jnp.max(jnp.abs(a)), g_auto, g_manual
)
assert all(float(e) < 1e-2 for e in jax.tree.leaves(rel_err)), rel_err
# To benchmark: %timeit block(quad_grad_auto(L, R)); %timeit block(quad_grad_manual(L, R))

2148796000.0
4297592000.0


In [9]:
def monarch_linear(L: Array, R: Array, Q: Array, K: Array):
    """Negative bilinear term of the Monarch objective.

    Computes ``-sum Q2d[l, j, a] * L[j, l, k] * R[k, j, i] * K2d[k, i, a]``
    over all indices. ``Q2d`` and ``K2d`` are ``Q`` and ``K`` with their rows
    split as ``(m b) d -> m b d``.

    Args:
        L: left factor, shape (b, m, m), indexed ``jlk``.
        R: right factor, shape (m, b, b), indexed ``kji``.
        Q, K: shape (m * b, d).

    Returns:
        A scalar array.
    """
    # Take the block count from L instead of the notebook-global ``b``.
    # The contraction already requires L.shape[0] to match the split size.
    b = L.shape[0]
    Q2d = einshape("(mb)d->mbd", Q, b=b)
    K2d = einshape("(mb)d->mbd", K, b=b)
    return -oe.contract("jlk,kji,lja,kia->", L, R, Q2d, K2d)

def grad_monarch_linear(L: Array, R: Array, Q: Array, K: Array):
    """Hand-derived gradient of ``monarch_linear`` with respect to ``(L, R)``.

    ``monarch_linear`` is linear in each factor, so each gradient is the
    contraction with that factor removed::

        dL[j, l, k] = -sum_{i,a} R[k, j, i] * Q2d[l, j, a] * K2d[k, i, a]
        dR[k, j, i] = -sum_{l,a} L[j, l, k] * Q2d[l, j, a] * K2d[k, i, a]

    Shapes follow ``monarch_linear``: L (b, m, m), R (m, b, b), Q/K (m * b, d).
    Returns ``(dL, dR)`` with the same shapes as ``(L, R)``.
    """
    # Take the block count from L instead of the notebook-global ``b``.
    b = L.shape[0]
    Q2d = einshape("(mb)d->mbd", Q, b=b)
    K2d = einshape("(mb)d->mbd", K, b=b)
    dL = -oe.contract("kji,lja,kia->jlk", R, Q2d, K2d)
    dR = -oe.contract("jlk,lja,kia->kji", L, Q2d, K2d)
    return dL, dR

f1 = jax.jit(jax.grad(monarch_linear, argnums=(0, 1)))
f2 = jax.jit(grad_monarch_linear)

f1(L, R, Q, K), f2(L, R, Q, K)

%timeit block(f1(L, R, Q, K))
%timeit block(f2(L, R, Q, K))

66.2 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.1 ms ± 1.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# NOTE(review): `monarch_linear_original` is not defined in any cell visible here.
# It presumably comes from an earlier or since-deleted cell; the stored output
# below predates this cell ever running (In [None]). On a fresh kernel this
# raises NameError. TODO: define it (or import it) before this cell, or remove the cell.
monarch_linear_original(L, R, Q, K)

Array(420566.6, dtype=float32)