In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_parallel import vlstm_fw_torch
from vlstm_recurrent import vlstm_recurrent_sequence_stabilized
from vlstm_chunkwise_parallel import vlstm_chunkwise_parallel_3
from einops import rearrange
from torch.nn import functional as F


# Match vLSTM chunkwise parallel to parallel


In [2]:
DTYPE = torch.float32
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 12
NH = 1
DH = 3
EPS = 0.0

In [13]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [14]:
# igate_preacts = 5 * torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) / 10000
# fgate_preacts = torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) +1 # / 100

In [15]:
# igate_preacts = 5 * torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) / 10000
# fgate_preacts = torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) / 100

In [16]:
# fgate_preacts, igate_preacts

In [17]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([1, 1, 12, 3])

In [18]:
y_p = vlstm_fw_torch(qs, ks, vs, igate_preacts, fgate_preacts, eps=EPS)

In [19]:
# y_r = vlstm_recurrent_sequence_stabilized(qs, ks, vs, igate_preacts, fgate_preacts, normalization_mode="max_abs_sum_C_1", eps=EPS)
# y_r, torch.allclose(y_p, y_r, atol=1e-5)

In [20]:
y_p

tensor([[[[-0.4766, -0.5826, -0.9475],
          [ 0.3189, -0.5805, -0.3735],
          [ 0.0083,  1.1414,  1.1899],
          [-0.4672, -0.2603, -0.0403],
          [-0.1332,  0.4042,  0.0817],
          [ 0.6912, -0.1775, -0.1799],
          [-1.6277, -1.1145,  0.6377],
          [-0.4791, -0.3256,  0.1318],
          [-0.6228,  0.2321,  0.4485],
          [-0.1759,  0.1434,  0.1479],
          [ 0.1849, -0.0911, -0.0084],
          [-0.7768, -0.2171,  0.1132]]]], device='cuda:0')

In [21]:
y_cp = vlstm_chunkwise_parallel_3(qs, ks, vs, igate_preacts, fgate_preacts, chunk_size=4)
y_cp

log_fgates: torch.Size([1, 1, 3, 4])
tensor([[[[-0.5762, -1.3072, -0.1690, -0.9215],
          [-0.9936, -0.2652, -0.8875, -0.7972],
          [-0.5570, -0.9301, -0.8282, -1.4776]]]], device='cuda:0')
p_vec_f: torch.Size([1, 1, 3, 4])
tensor([[[[-0.5762, -1.8834, -2.0525, -2.9740],
          [-0.9936, -1.2587, -2.1462, -2.9435],
          [-0.5570, -1.4871, -2.3153, -3.7928]]]], device='cuda:0')
q_vec_f: torch.Size([1, 1, 3, 4])
tensor([[[[-2.9740, -2.3977, -1.0905, -0.9215],
          [-2.9435, -1.9499, -1.6847, -0.7972],
          [-3.7928, -3.2359, -2.3057, -1.4776]]]], device='cuda:0')
g_vec: torch.Size([1, 1, 3])
tensor([[[-2.9740, -2.9435, -3.7928]]], device='cuda:0')


tensor([[[[[-0.4766, -0.5826, -0.9475],
           [ 0.3189, -0.5805, -0.3735],
           [ 0.0083,  1.1414,  1.1899],
           [-0.4672, -0.2603, -0.0403]],

          [[-0.1608,  0.6145,  0.2053],
           [ 0.7484, -0.2315, -0.2424],
           [-1.6329, -1.1134,  0.6515],
           [-0.4936, -0.3449,  0.1181]],

          [[-0.6963,  0.2131,  0.4381],
           [-0.2021,  0.1161,  0.1755],
           [ 0.2011, -0.0778, -0.0165],
           [-0.7652, -0.2093,  0.1081]]]]], device='cuda:0')

In [13]:
queries = qs
keys = ks
values = vs
igate_preact = igate_preacts
fgate_preact = fgate_preacts
chunk_size = 4

In [22]:
B, NH, S, DH = queries.shape
_dtype, _device = queries.dtype, queries.device
qs = rearrange(queries, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size) * (DH**-0.5)
ks = rearrange(keys, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size)
vs = rearrange(values, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size)
_, _, NC, L, _ = qs.shape
igs = rearrange(igate_preact, "b nh (nc l) 1 -> b nh nc l", l=chunk_size)
fgs = rearrange(fgate_preact, "b nh (nc l) 1 -> b nh nc l", l=chunk_size)

# compute the gates, the g and the p and q vectors
log_fgates = F.logsigmoid(fgs)

p_vec_f = torch.cat(
    [
        torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device),
        log_fgates[:, :, :, :-1].cumsum(-1),
    ],
    dim=-1,
)

q_vec_f_raw = torch.cat(
    [
        torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device),
        log_fgates[:, :, :, 1:].cumsum(-1),
    ],
    dim=-1,
)
q_vec_f = log_fgates[:, :, :, 1:].sum(-1, keepdim=True) - q_vec_f_raw

p_vec = p_vec_f + igs
q_vec = q_vec_f + igs
g_vec = log_fgates.sum(-1)

# get the maximum values per chunk for p and q
p_vec_max = p_vec.max(-1).values
q_vec_max = q_vec.max(-1).values


# loop 1: materialize the  C_k, n_k, m_k
C_states = torch.zeros((B, NH, NC, DH, DH), dtype=_dtype, device=_device)
n_states = torch.zeros((B, NH, NC, DH), dtype=_dtype, device=_device)
m_states = torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device)

m_k = torch.zeros((B, NH, 1), dtype=_dtype, device=_device)
m_prev_k = torch.zeros((B, NH, 1), dtype=_dtype, device=_device)
C_k = torch.zeros((B, NH, DH, DH), dtype=_dtype, device=_device)
C_prev_k = torch.zeros((B, NH, DH, DH), dtype=_dtype, device=_device)
n_k = torch.zeros((B, NH, DH), dtype=_dtype, device=_device)
n_prev_k = torch.zeros((B, NH, DH), dtype=_dtype, device=_device)
for k in range(1, NC):
    i = k - 1
    # m_k
    m_q_k = q_vec_max[:, :, i]
    g_k = g_vec[:, :, i]
    m_k = torch.max(g_k + m_prev_k, m_q_k)
    m_states[:, :, k] = m_k

    # C_k
    k_chunk = ks[:, :, i, :, :].clone()
    v_chunk = vs[:, :, i, :, :].clone()
    q_k = q_vec[:, :, i, :].clone()
    k_chunk_gated = k_chunk * torch.exp(q_k - m_k).unsqueeze(-1)

    C_k = (
        torch.exp(g_k + m_prev_k - m_k) * C_prev_k
        + k_chunk_gated.transpose(-2, -1) @ v_chunk
    )
    C_states[:, :, k] = C_k

    # n_k
    n_k = torch.exp(g_k + m_prev_k - m_k) * n_prev_k + k_chunk_gated.transpose(
        -2, -1
    ).sum(-1)
    n_states[:, :, k] = n_k

    # move to the next iteration
    m_prev_k = m_k
    C_prev_k = C_k
    n_prev_k = n_k

# loop 2: compute the H_states
H_states = torch.zeros((B, NH, NC, L, DH), dtype=_dtype, device=_device)
for k in range(1, NC + 1):
    i = k - 1

    # load C_k, n_k, m_k
    C_k = C_states[:, :, i]
    n_k_inter = n_states[:, :, i]
    m_k = m_states[:, :, i]
    # load q, k, v chunks
    q_chunk = qs[:, :, i, :, :].clone()
    k_chunk = ks[:, :, i, :, :].clone()
    v_chunk = vs[:, :, i, :, :].clone()

    # ? Compute inter chunk contribution: H_inter
    p_k = p_vec[:, :, i, :].clone()

    m_p_k = p_vec_max[:, :, i]
    m_H = torch.max(m_p_k, m_k)
    q_chunk_gated = q_chunk * torch.exp(p_k - m_H).unsqueeze(-1)

    denom_k_inter = torch.max(
        torch.abs(q_chunk_gated @ n_k_inter.unsqueeze(-1)), torch.exp(-m_k - m_H)
    )

    H_inter = q_chunk_gated @ C_k / denom_k_inter

    # ? Compute intra chunk contribution: H_intra
    # this is similar to the parallel version, but only for the current chunk
    log_fg_k = log_fgates[:, :, i].unsqueeze(-1)  # (B, NH, L, 1)
    log_ig_k = igs[:, :, i].unsqueeze(-1)  # (B, NH, L, 1)
    ltr = torch.tril(
        torch.ones(
            (L, L),
            dtype=torch.bool,
            device=_device,
        )
    )
    log_fg_k_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fg_k, dim=-2),
        ],
        dim=-2,
    )  # (B, NH, L+1, 1)
    # for each batch/head this is a matrix of shape (L+1, L+1) containing the cumsum of the log forget gate values
    # in the second dimension (colum dimension). Each row has the same is a copy of the first row.
    # First entry of each row is zero.
    rep_log_fg_k_cumsum = log_fg_k_cumsum.repeat(1, 1, 1, L + 1)  # (B, NH, L+1, L+1)
    # Now in each row cut off / subtract the forgetgate values of the later timesteps
    # where col j > row i
    _log_fg_k_matrix = rep_log_fg_k_cumsum - rep_log_fg_k_cumsum.transpose(
        -2, -1
    )  # (B, NH, L+1, L+1)
    # Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied
    # to the input at timestep t
    log_fg_k_matrix = torch.where(
        ltr, _log_fg_k_matrix[:, :, 1:, 1:], -float("inf")
    )  # (B, NH, L, L)

    log_D_k = log_fg_k_matrix + log_ig_k.transpose(-2, -1)  # (B, NH, L, L)

    # compute the max state (for now isolated for intra chunk contribution)
    m_log_D_k = torch.max(log_D_k, dim=-1, keepdim=True).values

    log_D_k_stabilized = log_D_k - m_log_D_k
    D_k = torch.exp(log_D_k_stabilized)
    qk_k_matrix = q_chunk @ k_chunk.transpose(-2, -1)
    C_k_matrix = qk_k_matrix * D_k

    denom_k_intra = torch.maximum(
        C_k_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-m_log_D_k)
    )
    C_k_matrix_normalized = C_k_matrix / denom_k_intra  # TODO add eps

    H_intra = C_k_matrix_normalized @ v_chunk  # (B, NH, L, DH)
    H_states[:, :, i, :, :] = (denom_k_inter / denom_k_intra) * H_inter + H_intra

# H_y = rearrange(H_states, "b nh nc l dh -> b nh (nc l) dh")

# we do not need the first forget gate as this is applied to the first element
# log_fgates_cumsum = log_fgates[:, :, 1:].cumsum(-1)
# d_vec = torch.cat(
#     [torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device), log_fgates_cumsum],
#     dim=-2,
# )

NameError: name 'queries' is not defined

In [27]:
H_states

tensor([[[[[ 0.1950, -0.2674, -0.1591, -0.0408,  0.0139],
           [-0.0457, -0.7729,  0.3219,  0.6750,  0.3412],
           [-0.0534, -1.1351,  0.4686,  0.9789,  0.5149],
           [-0.1144,  0.1129, -0.0369,  0.0488, -0.1858]],

          [[-0.2556,  1.1231, -0.4512, -0.4634, -1.2010],
           [-0.6378,  1.5420,  0.0900, -1.2322,  3.0609],
           [ 0.9447, -0.5071,  0.8299,  1.8393, -1.2458],
           [ 0.1122,  0.4685, -0.7218, -0.7698,  0.0353]],

          [[-2.5104,  1.3194, -2.5216, -2.7714,  3.5603],
           [ 0.7144, -0.3377, -0.2862, -0.4849, -1.0592],
           [-1.4214,  0.7989, -0.4866, -0.0828,  1.6035],
           [ 1.3098, -0.7683,  0.2306, -0.1338, -1.5235]]]]], device='cuda:0')

In [42]:
log_fgates, igs

(tensor([[[[-0.2472, -0.3480, -1.2512, -0.3286],
           [-1.0188, -1.0823, -0.8327, -0.8236],
           [-1.1829, -0.3297, -1.0177, -0.5052]]]], device='cuda:0'),
 tensor([[[[-0.3091, -0.5703,  1.8633, -0.4843],
           [ 0.4213, -0.3814,  0.3600,  0.0021],
           [-1.1131, -1.2531,  0.0982, -0.8470]]]], device='cuda:0'))

In [38]:
C_states

tensor([[[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

          [[-2.7854e-02, -5.3366e-04, -4.8855e-02,  2.1908e-02,  1.1594e-01],
           [ 7.2663e-01,  5.8536e-01,  1.6246e+00, -8.4919e-01,  9.5934e-01],
           [-1.5995e-01, -7.0658e-02, -3.0576e-01,  1.0001e-01, -1.4419e-01],
           [-8.1330e-01, -6.3639e-01, -1.7333e+00,  9.3910e-01, -9.0490e-01],
           [-3.8330e-01, -2.8181e-01, -7.6245e-01,  4.2472e-01, -3.2418e-01]],

          [[-2.4721e-01,  1.1252e+00,  1.0871e+00,  2.5634e+00,  9.1802e-02],
           [-2.1048e-01, -4.7870e-01,  3.1842e-01, -1.0948e+00,  4.0503e-02],
           [ 2.7336e-01, -9.8781e-02, -2.2610e-01, -8.4500e-

In [36]:
H_states

tensor([[[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

          [[-0.6741, -0.5397, -1.4885,  0.7784, -0.9014],
           [ 0.2604,  0.1975,  0.5540, -0.2985,  0.2428],
           [ 0.0798,  0.0606,  0.1691, -0.0888,  0.0913],
           [-0.0019, -0.0016, -0.0038,  0.0015, -0.0075]],

          [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]]], device='cuda:0')

In [16]:
n_states[:, :, -1].unsqueeze(-1).shape, q_chunk.shape

(torch.Size([1, 1, 5, 1]), torch.Size([1, 1, 4, 5]))

In [17]:
(q_chunk @ n_states[:, :, -1].unsqueeze(-1)).shape, m_k.shape

(torch.Size([1, 1, 4, 1]), torch.Size([1, 1, 1]))

In [18]:
torch.max(torch.abs(q_chunk @ n_states[:, :, -1].unsqueeze(-1)), m_k).shape

torch.Size([1, 1, 4, 1])

In [19]:
g_vec.shape, ks.shape, ks[:, :, 1, :, :].shape

IndexError: index 1 is out of bounds for dimension 2 with size 1

In [None]:
kv_gated.shape

torch.Size([1, 1, 3, 5, 5])

In [None]:
(ks[:, :, 1, :, :] * d_vec[:, :, 1, :]).shape

torch.Size([1, 1, 4, 5])

In [None]:
d_vec[:, :, 1, :, :].shape

torch.Size([1, 1, 4, 1])

In [None]:
d_vec, d_vec.shape

(tensor([[[[[ 0.7589],
            [-0.7418],
            [-2.5793],
            [-1.5436]],
 
           [[-1.1483],
            [-1.1596],
            [-4.4460],
            [-3.5700]],
 
           [[ 0.5241],
            [-1.4934],
            [-2.0288],
            [-1.9730]]]]], device='cuda:0'),
 torch.Size([1, 1, 3, 4, 1]))

In [None]:
d_vec.max(dim=-2).values

tensor([[[[ 0.7589],
          [-1.1483],
          [ 0.5241]]]], device='cuda:0')

In [None]:
vals, idxs = d_vec.max(dim=-2)
vals.shape, vals[:, :, 0].shape

(torch.Size([1, 1, 3, 1]), torch.Size([1, 1, 1]))

In [None]:
(ks.transpose(-1, -2) @ vs).shape

torch.Size([1, 1, 3, 5, 5])

In [None]:
ks.shape, d_vec.shape

(torch.Size([1, 1, 3, 4, 5]), torch.Size([1, 1, 3, 4, 1]))

In [None]:
(d_vec * ks).shape

torch.Size([1, 1, 3, 4, 5])

In [None]:
log_fgates[:, :, :, 1:].cumsum(-1).shape

torch.Size([1, 1, 3, 3])

In [None]:
d_vec.shape

torch.Size([1, 1, 3, 4])

In [None]:
d_vec.unsqueeze(-1).shape

torch.Size([1, 1, 3, 4, 1])