In [8]:
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

# Toy inputs for the chunked computation. Shapes before chunking:
# X: (batch, length, n_heads, d_head)
# A: (batch, length, n_heads)
# B: (batch, length, n_heads, d_state)
# C: (batch, length, n_heads, d_state)
batch, length, n_heads, d_head = 3, 128, 4, 256
block_len = 64
d_state = 64
initial_states = None

# The chunking below splits the length axis into length // block_len blocks,
# so fail early with a readable message instead of an einops shape error.
assert length % block_len == 0, "length must be a multiple of block_len"

# Fix the RNG seed so that Restart & Run All reproduces the printed tensors.
torch.manual_seed(0)
X = torch.randn((batch, length, n_heads, d_head), dtype=torch.float32)
# NOTE: A is drawn from randn purely for illustration, so it is not constrained
# to be negative; exp() of its (segment) sums can therefore exceed 1.
A = torch.randn((batch, length, n_heads), dtype=torch.float32)
B = torch.randn((batch, length, n_heads, d_state), dtype=torch.float32)
C = torch.randn((batch, length, n_heads, d_state), dtype=torch.float32)

# Rearrange into blocks/chunks: the length axis becomes (chunk, position-in-chunk).
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
print("X shape: ", X.shape)
print("B shape: ", B.shape)
print("C shape: ", C.shape)

# Move heads in front of (chunk, position) so cumulative sums over the
# position axis can use dim=-1.
A = rearrange(A, "b c l h -> b h c l")
print("A shape: ", A.shape)

X shape:  torch.Size([3, 2, 64, 4, 256])
B shape:  torch.Size([3, 2, 64, 4, 64])
C shape:  torch.Size([3, 2, 64, 4, 64])
A shape:  torch.Size([3, 4, 2, 64])


In [9]:
# Running sum of A along the position-in-chunk axis (last dim); the layout
# (batch, n_heads, chunk, position) is unchanged.
A_cumsum = A.cumsum(dim=-1)
print("A_cumsum shape: ", A_cumsum.shape)
# Show both chunks of the first (batch, head) pair.
print(A_cumsum[0, 0])

A_cumsum shape:  torch.Size([3, 4, 2, 64])
tensor([[-0.2658,  0.2051, -0.4677,  0.1695,  1.1255,  0.6326,  2.0853,  1.0546,
         -0.0799, -1.0708, -1.2451, -1.1371, -1.5578, -3.3396, -3.1794, -2.2924,
         -1.1573, -0.0832,  0.2795, -0.0132, -0.5667, -0.5211, -0.5288, -2.2874,
         -1.0053, -0.3439,  1.1854,  0.6855, -0.0088, -1.6028, -2.5258, -2.5608,
         -2.8232, -3.5375, -3.0490, -2.8024, -1.8804, -1.6589, -0.3597, -3.0049,
         -3.0880, -1.7377, -0.7670, -1.5339, -2.3543, -2.1628, -1.9823, -0.9725,
         -0.2307, -0.5601, -3.2543, -0.9201,  0.1051, -0.1104, -0.4800, -0.4885,
          0.2217, -1.1623, -0.6456, -0.5058, -0.3892,  0.5867, -0.6275, -2.5660],
        [-0.2748,  0.2375,  0.3691,  1.5900,  1.8836,  1.0119, -1.2858, -4.3879,
         -3.6671, -2.6230, -0.4432,  1.3762,  1.2834,  1.2889,  1.9674,  2.2748,
          1.0832,  0.5214,  0.3608,  0.5663, -1.5273, -0.8515,  0.4281,  0.8482,
         -1.3273, -2.0056, -1.8443, -2.0875, -1.4093, -2.3229, -2

In [10]:
def segsum(x):
    """More stable segment sum calculation.

    For an input whose last axis has length T, returns a tensor with one extra
    trailing axis of length T such that

        out[..., i, j] = x[..., j + 1] + ... + x[..., i]   for i >= j
        out[..., i, j] = -inf                               for i <  j

    so the diagonal is 0 and ``torch.exp(out)`` is lower-triangular with ones
    on the diagonal. Each segment is summed directly rather than obtained as a
    difference of two cumulative sums.

    Args:
        x: tensor of shape (..., T).

    Returns:
        Tensor of shape (..., T, T) with the same dtype and device as ``x``.
    """
    T = x.size(-1)
    # Broadcast view with x_rep[..., i, j] == x[..., i]. `expand` allocates
    # nothing; the out-of-place masked_fill below makes the only real copy.
    x_rep = x.unsqueeze(-1).expand(*x.shape, T)
    # Build the triangular masks once: `lower` keeps i >= j, `strict_lower` keeps i > j.
    lower = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0)
    strict_lower = torch.tril(lower, diagonal=-1)
    # Zero every term with i <= j so the cumulative sum over i starts at j + 1.
    x_rep = x_rep.masked_fill(~strict_lower, 0)
    x_segsum = torch.cumsum(x_rep, dim=-2)
    # Entries above the diagonal are not valid segments; -inf makes exp() give 0.
    x_segsum = x_segsum.masked_fill(~lower, -torch.inf)
    return x_segsum

In [14]:
 # 1. Compute the output for each intra-chunk (diagonal blocks)
# L holds the within-chunk decay factors: exp of the segment sums of A, which
# is 0 above the diagonal and 1 on it (see the printed slice below).
L = segsum(A).exp()
# Contract over the state dim n and the source position s inside each chunk.
Y_diag = torch.einsum(
    "bclhn,bcshn,bhcls,bcshp->bclhp",
    C, B, L, X,
)  ## orange
print("L shape: ", L.shape)
print("Y_diag shape: ", Y_diag.shape)
print(L[0, 0, 0])

L shape:  torch.Size([3, 4, 2, 64, 64])
Y_diag shape:  torch.Size([3, 2, 64, 4, 256])
tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.6014, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.8172, 0.5103, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [2.3456, 1.4647, 2.8702,  ..., 1.0000, 0.0000, 0.0000],
        [0.6965, 0.4349, 0.8523,  ..., 0.2969, 1.0000, 0.0000],
        [0.1002, 0.0626, 0.1227,  ..., 0.0427, 0.1439, 1.0000]])


In [18]:
# 2. Per-chunk states: what each chunk contributes to the state at its own
#    right edge, with no incoming state taken into account yet.
# (right term of low-rank factorization of off-diagonal blocks; B terms)
# Decay from every position to the end of its chunk: exp(total - running sum).
decay_states = (A_cumsum[..., -1:] - A_cumsum).exp()
# Sum over the position axis l, keeping one (d_head x d_state) state per chunk and head.
states = torch.einsum(
    "bclhn,bhcl,bclhp->bchpn",
    B, decay_states, X,
)  ## green
print("decay_states shape: ", decay_states.shape)
print("states shape: ", states.shape)

decay_states shape:  torch.Size([3, 4, 2, 64])
states shape:  torch.Size([3, 2, 4, 256, 64])


In [19]:
# 3. Inter-chunk SSM recurrence: propagate the per-chunk states across chunk
#    boundaries so each chunk sees the correct incoming state.
# (middle term of factorization of off-diag blocks; A terms)
# NOTE: this cell rebinds `states` to a tensor of the same shape, so re-running
# it without first re-running the cell that builds `states` feeds the already
# propagated states back in without raising any error.
if initial_states is None:
    # Default to an all-zero state in front of the first chunk.
    initial_states = torch.zeros_like(states[:, :1])
states = torch.cat((initial_states, states), dim=1)
# Total decay of each chunk, left-padded with a 0 for the initial-state slot.
decay_chunk = segsum(F.pad(A_cumsum[..., -1], (1, 0))).exp()
new_states = torch.einsum(
    "bhzc,bchpn->bzhpn",
    decay_chunk, states,
)  ## yellow
# All but the last entry are the states entering each chunk; the last entry is
# the state after the final chunk.
states = new_states[:, :-1]
final_state = new_states[:, -1]
print("states shape: ", states.shape)
print("final_state shape: ", final_state.shape)

states shape:  torch.Size([3, 2, 4, 256, 64])
final_state shape:  torch.Size([3, 4, 256, 64])


In [20]:
# 4. State -> output conversion per chunk: read the incoming chunk state out
#    at every position, decayed from the chunk start up to that position.
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = A_cumsum.exp()   ## blue
Y_off = torch.einsum(
    "bclhn,bchpn,bhcl->bclhp",
    C, states, state_decay_out,
)

# Sum the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions,
# then merge (chunk, position) back into a single length axis.
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

print("state_decay_out shape: ", state_decay_out.shape)
print("Y_off shape: ", Y_off.shape)
print("Y Shape: ", Y.shape)

state_decay_out shape:  torch.Size([3, 4, 2, 64])
Y_off shape:  torch.Size([3, 2, 64, 4, 256])
Y Shape:  torch.Size([3, 128, 4, 256])
