In [None]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [None]:
!pip install --upgrade jax

Collecting jax
[?25l  Downloading https://files.pythonhosted.org/packages/50/f4/d90107c22334c267ccb64e0ea8039018a4740b5dfad1576dd868aac45254/jax-0.1.59.tar.gz (270kB)
[K     |█▏                              | 10kB 26.8MB/s eta 0:00:01[K     |██▍                             | 20kB 6.3MB/s eta 0:00:01[K     |███▋                            | 30kB 8.8MB/s eta 0:00:01[K     |████▉                           | 40kB 5.8MB/s eta 0:00:01[K     |██████                          | 51kB 4.8MB/s eta 0:00:01[K     |███████▎                        | 61kB 5.7MB/s eta 0:00:01[K     |████████▌                       | 71kB 6.2MB/s eta 0:00:01[K     |█████████▊                      | 81kB 7.0MB/s eta 0:00:01[K     |███████████                     | 92kB 7.8MB/s eta 0:00:01[K     |████████████▏                   | 102kB 7.7MB/s eta 0:00:01[K     |█████████████▎                  | 112kB 7.7MB/s eta 0:00:01[K     |██████████████▌                 | 122kB 7.7MB/s eta 0:00:01[K     |██

# JAX 1. Numpy Wrapper

In [None]:
import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)

%timeit z = np.sin(x) + np.cos(y)

1 loop, best of 3: 401 ms per loop


In [None]:
import jax.numpy as jnp
x = jnp.ones((5000, 5000))
y = jnp.arange(5000)

%timeit z = jnp.sin(x) + jnp.cos(y)

100 loops, best of 3: 2.15 ms per loop


# JAX 2. JIT Compiler

In [None]:
from jax import jit
import tensorflow as tf

def fn(x, y):
  z = np.sin(x)
  w = np.cos(y)
  return z + w

@jit
def fn_jit(x, y):
  z = jnp.sin(x)
  w = jnp.cos(y)
  return z + w

@tf.function
def fn_tf2(x, y):
  z = tf.sin(x)
  w = tf.cos(y)
  return z + w

In [None]:
x = np.ones((5000, 5000))
y = np.ones((5000, 5000))
%timeit fn(x, y)

1 loop, best of 3: 780 ms per loop


In [None]:
jx = jnp.ones((5000, 5000))
jy = jnp.ones((5000, 5000))
%timeit fn_jit(jx, jy)

100 loops, best of 3: 2.12 ms per loop


In [None]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn_tf2(tx, ty)

The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 3.36 ms per loop


# JAX 3. grad

In [None]:
from jax import grad

@jit
def simple_fun(x):
  return jnp.sin(x) / x

In [None]:
grad_simple_fun = grad(simple_fun)

In [None]:
%timeit grad_simple_fun(1.0)

1000 loops, best of 3: 1.22 ms per loop


In [None]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.30116874, dtype=float32),
 DeviceArray(-0.43539774, dtype=float32),
 DeviceArray(-0.3456775, dtype=float32),
 DeviceArray(-0.11611074, dtype=float32),
 DeviceArray(0.09508941, dtype=float32),
 DeviceArray(0.16778992, dtype=float32),
 DeviceArray(0.09429243, dtype=float32),
 DeviceArray(-0.03364623, dtype=float32),
 DeviceArray(-0.10632458, dtype=float32)]

In [None]:
grad_grad_simple_fun = grad(grad(simple_fun))

In [None]:
%timeit grad_grad_simple_fun(1.0)

The slowest run took 93.35 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 3.19 ms per loop


In [None]:
grad_grad_simple_fun(1.0)

DeviceArray(-0.23913354, dtype=float32)

In [None]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.23913354, dtype=float32),
 DeviceArray(-0.01925094, dtype=float32),
 DeviceArray(0.18341166, dtype=float32),
 DeviceArray(0.247256, dtype=float32),
 DeviceArray(0.1537491, dtype=float32),
 DeviceArray(-0.00936072, dtype=float32),
 DeviceArray(-0.12079593, dtype=float32),
 DeviceArray(-0.11525822, dtype=float32),
 DeviceArray(-0.02216326, dtype=float32)]

In [22]:
from einops import rearrange
import torch
import numpy as np
import torch.nn.functional as F

In [23]:
def segsum(x):
    """Naive segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
       which is equivalent to a scalar SSM."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum

def ssd(X, A, B, C, block_len=64, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag  = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state

In [32]:
batch_size = 2
len = 100
d_head = 10
d_state = 200
d_total = 20
num_heads = d_total // d_head
num_chunks = 5


B, C = [np.random.randn(batch_size, len, num_heads, d_state) for index in range(2)]
A = np.random.randn(batch_size, len, num_heads)
X = np.random.randn(batch_size, len, num_heads, d_head) # this code is invariant to d_head to some extent is the funny thing
X, A, B, C = list(map(torch.from_numpy, [X, A, B, C]))


In [33]:
y = ssd(X, A, B, C, block_len=20)
print(y[0].shape)

assert y[0].shape == (batch_size, len, num_heads, d_head)

torch.Size([2, 100, 2, 10])
