In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_parallel import vlstm_fw_torch
from vlstm_recurrent import vlstm_recurrent_sequence_stabilized
from vlstm_chunkwise_parallel import vlstm_chunkwise_parallel
from einops import rearrange
from torch.nn import functional as F


# Match vLSTM chunkwise parallel to parallel


In [2]:
# Global tensor settings shared by every cell below.
DTYPE = torch.float32
# Prefer the first CUDA device (as before); fall back to the CPU so the notebook still
# runs end-to-end on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Problem sizes for the comparison: batch size, number of heads, sequence length, head dim.
B, NH, S, DH = 1, 1, 12, 2
# Epsilon forwarded to the parallel reference implementation.
EPS = 0.0

In [4]:
# Seed the RNG, then draw the input-gate and forget-gate pre-activations (in that order),
# each of shape (B, NH, S, 1). The draw order matters: later random tensors depend on the
# RNG state left behind by this cell.
torch.manual_seed(5)
igate_preacts, fgate_preacts = (
    torch.randn(B, NH, S, 1, dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [5]:
# Replace the random gate pre-activations from the previous cell with deterministic ramps.
# Input gate: 5 * [0, 1, ..., B*NH*S - 1] / 10000, shaped (B, NH, S, 1).
igate_preacts = (
    torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE)
    .view(B, NH, S, 1)
    .mul(5)
    .div(10000)
)
# Forget gate: [1, 2, ..., B*NH*S], shaped (B, NH, S, 1) (an optional "/ 100" scaling is left out).
fgate_preacts = (
    torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE)
    .view(B, NH, S, 1)
    .add(1)
)

In [6]:
# igate_preacts = 5 * torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) / 10000
# fgate_preacts = torch.arange(B * NH * S, dtype=DTYPE, device=DEVICE).reshape(B, NH, S, 1) / 100

In [7]:
# fgate_preacts, igate_preacts

In [8]:
# Draw random queries, keys and values (in that order), each of shape (B, NH, S, DH).
qs, ks, vs = (
    torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE) for _ in range(3)
)
# Display the shape as a quick sanity check.
vs.shape

torch.Size([1, 1, 12, 2])

In [9]:
# Reference output from the fully parallel implementation; the chunkwise result is compared
# against this tensor further down.
y_p = vlstm_fw_torch(qs, ks, vs, igate_preacts, fgate_preacts, eps=EPS)

In [10]:
# y_r = vlstm_recurrent_sequence_stabilized(qs, ks, vs, igate_preacts, fgate_preacts, normalization_mode="max_abs_sum_C_1", eps=EPS)
# y_r, torch.allclose(y_p, y_r, atol=1e-5)

In [11]:
# Display the parallel reference output.
y_p

tensor([[[[-0.1036, -0.1770],
          [ 0.3631, -0.0846],
          [ 0.0722,  0.5899],
          [ 0.2939, -0.5616],
          [-0.9256, -0.9527],
          [-0.1738,  2.1624],
          [ 0.0913,  1.3176],
          [-0.4555, -0.8872],
          [ 0.9848,  1.8037],
          [ 0.5587,  2.6535],
          [ 0.2654, -1.3075],
          [ 0.2538,  4.0591]]]], device='cuda:0')

In [12]:
# Chunkwise-parallel output for the same inputs, using chunks of 4 timesteps.
# NOTE(review): unlike the vlstm_fw_torch call, no `eps` is passed here -- confirm the
# function's default matches EPS. The debug prints in the saved output presumably come
# from inside vlstm_chunkwise_parallel.
y_cp = vlstm_chunkwise_parallel(qs, ks, vs, igate_preacts, fgate_preacts, chunk_size=4)
y_cp

log_fgates: torch.Size([1, 1, 3, 4])
tensor([[[[-3.1326e-01, -1.2693e-01, -4.8587e-02, -1.8150e-02],
          [-6.7153e-03, -2.4757e-03, -9.1147e-04, -3.3541e-04],
          [-1.2340e-04, -4.5399e-05, -1.6702e-05, -6.1442e-06]]]],
       device='cuda:0')
p_vec_f: torch.Size([1, 1, 3, 4])
tensor([[[[-3.1326e-01, -4.4019e-01, -4.8878e-01, -5.0693e-01],
          [-6.7153e-03, -9.1910e-03, -1.0103e-02, -1.0438e-02],
          [-1.2340e-04, -1.6880e-04, -1.8550e-04, -1.9165e-04]]]],
       device='cuda:0')
q_vec_f: torch.Size([1, 1, 3, 4])
tensor([[[[-1.9367e-01, -6.6737e-02, -1.8150e-02,  0.0000e+00],
          [-3.7226e-03, -1.2469e-03, -3.3541e-04,  0.0000e+00],
          [-6.8245e-05, -2.2846e-05, -6.1442e-06,  0.0000e+00]]]],
       device='cuda:0')
g_vec: torch.Size([1, 1, 3])
tensor([[[-5.0693e-01, -1.0438e-02, -1.9165e-04]]], device='cuda:0')
log_fg_k_matrix: torch.Size([1, 1, 4, 4])
tensor([[[[ 0.0000,    -inf,    -inf,    -inf],
          [-0.1269,  0.0000,    -inf,    -inf],
  

  from .autonotebook import tqdm as notebook_tqdm


tensor([[[[[-0.1036, -0.1770],
           [ 0.3631, -0.0846],
           [ 0.0722,  0.5899],
           [ 0.2939, -0.5616]],

          [[-0.9256, -0.9527],
           [-0.1738,  2.1624],
           [ 0.0913,  1.3176],
           [-0.4555, -0.8872]],

          [[ 0.9848,  1.8037],
           [ 0.5587,  2.6535],
           [ 0.2654, -1.3075],
           [ 0.2538,  4.0591]]]]], device='cuda:0')

In [13]:
# y_cp
# tensor([[[[[-0.0192, -0.0328,  0.0145],
#            [-0.0261,  0.8259, -0.9299],
#            [ 0.1474, -0.2691, -0.9078],
#            [-0.9536, -0.7056, -0.2249]],

#           [[-0.1601,  0.1721,  0.6033],
#            [-0.4301,  0.2871, -0.6484],
#            [-0.5925, -0.2081,  0.4734],
#            [ 0.0110,  0.5813, -0.0628]],

#           [[ 3.6917,  0.4356,  1.7668],
#            [-5.2968, -1.1374, -2.3225],
#            [ 1.3773, -0.6453,  0.6941],
#            [-2.3798, -2.3140, -3.1298]]]]], device='cuda:0')

In [14]:
# Split the sequence axis of the parallel output into (num_chunks, 4) so it can be read
# side by side with the chunked y_cp displayed above.
rearrange(y_p, "b nh (nc l) dh -> b nh nc l dh", l=4)

tensor([[[[[-0.1036, -0.1770],
           [ 0.3631, -0.0846],
           [ 0.0722,  0.5899],
           [ 0.2939, -0.5616]],

          [[-0.9256, -0.9527],
           [-0.1738,  2.1624],
           [ 0.0913,  1.3176],
           [-0.4555, -0.8872]],

          [[ 0.9848,  1.8037],
           [ 0.5587,  2.6535],
           [ 0.2654, -1.3075],
           [ 0.2538,  4.0591]]]]], device='cuda:0')

In [15]:
# Compare the parallel reference with the chunkwise-parallel output.
# The check is made explicit (instead of eyeballing the difference) so that a regression
# fails loudly on "Run All"; the element-wise difference is still displayed as before.
y_p_chunked = rearrange(y_p, "b nh (nc l) dh -> b nh nc l dh", l=4)
assert torch.allclose(y_p_chunked, y_cp, atol=1e-5), (
    "chunkwise-parallel output deviates from the parallel output: "
    f"max abs diff = {(y_p_chunked - y_cp).abs().max().item():.3e}"
)
y_p_chunked - y_cp

tensor([[[[[ 0.0000e+00,  0.0000e+00],
           [ 2.9802e-08, -2.2352e-08],
           [-2.2352e-08,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00]],

          [[ 5.9605e-08,  1.1921e-07],
           [-4.4703e-08, -4.7684e-07],
           [ 1.4901e-08, -1.1921e-07],
           [-5.9605e-08, -1.7881e-07]],

          [[ 1.1921e-07, -2.3842e-07],
           [-1.7881e-07,  0.0000e+00],
           [ 0.0000e+00,  2.3842e-07],
           [-1.4901e-07,  4.7684e-07]]]]], device='cuda:0')

In [16]:
# v1:
# tensor([[[[[-5.5879e-09, -7.4506e-09,  3.7253e-09],
#            [-2.4214e-08,  5.9605e-08, -1.1921e-07],
#            [-1.4901e-08, -2.9802e-08,  0.0000e+00],
#            [ 1.1921e-07,  1.7881e-07,  7.4506e-08]],

#           [[ 2.7397e-04,  1.7940e-03, -2.2005e-03],
#            [-3.2905e-04, -1.9018e-03,  5.8073e-04],
#            [ 2.1350e-04, -3.5465e-04, -1.7783e-04],
#            [ 2.2422e-04,  7.9042e-04,  5.8800e-05]],

#           [[ 1.2011e-02, -1.8224e-04,  5.3725e-03],
#            [-1.3653e-02,  1.6749e-04, -6.2060e-03],
#            [ 3.4611e-03, -5.6410e-04,  8.6242e-04],
#            [-4.5993e-03,  1.3285e-03, -3.2961e-03]]]]], device='cuda:0')
# v2: 
# tensor([[[[[-5.5879e-09, -7.4506e-09,  3.7253e-09],
#            [-2.4214e-08,  5.9605e-08, -1.1921e-07],
#            [-1.4901e-08, -2.9802e-08,  0.0000e+00],
#            [ 1.1921e-07,  1.7881e-07,  7.4506e-08]],

#           [[ 3.5303e-04,  2.3117e-03, -2.8356e-03],
#            [-8.3545e-04, -4.8289e-03,  1.4748e-03],
#            [-3.3039e-04,  5.4850e-04,  2.7516e-04],
#            [-6.4742e-05, -2.2787e-04, -1.7010e-05]],

#           [[-4.3964e-04,  6.8247e-06, -1.9681e-04],
#            [ 1.7977e-04, -2.2650e-06,  8.1778e-05],
#            [-1.6689e-05,  2.7418e-06, -4.1723e-06],
#            [ 8.3447e-06, -2.3842e-06,  5.7220e-06]]]]], device='cuda:0')

NameError: name 'tensor' is not defined

In [None]:
# Bind the argument names used by the scratch re-implementation in the next cell.
# NOTE: the next cell rebinds qs/ks/vs to chunked tensors, so re-running this cell after it
# would alias those chunked tensors instead of the original (B, NH, S, DH) inputs.
queries, keys, values = qs, ks, vs
igate_preact, fgate_preact = igate_preacts, fgate_preacts
chunk_size = 4

In [None]:
# Scratch, step-by-step version of a chunkwise-parallel vLSTM forward pass.
#
# Inputs (bound in the previous cell): queries / keys / values of shape (B, NH, S, DH),
# igate_preact / fgate_preact of shape (B, NH, S, 1), and chunk_size.
# Result: H_states of shape (B, NH, NC, L, DH) with NC chunks of length L = chunk_size.
#
# All per-chunk scalars are kept with an explicit trailing axis, i.e. shape (B, NH, 1), so
# that broadcasting against (B, NH, L), (B, NH, DH) and (B, NH, DH, DH) tensors is correct
# for any B and NH. (Mixing (B, NH) and (B, NH, 1) tensors only broadcasts correctly when
# B == NH == 1.)
B, NH, S, DH = queries.shape
_dtype, _device = queries.dtype, queries.device

# Split the sequence axis into chunks. NOTE: this rebinds qs/ks/vs (queries are additionally
# scaled by DH**-0.5), so they no longer refer to the original (B, NH, S, DH) tensors.
qs = rearrange(queries, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size) * (DH**-0.5)
ks = rearrange(keys, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size)
vs = rearrange(values, "b nh (nc l) dh -> b nh nc l dh", l=chunk_size)
_, _, NC, L, _ = qs.shape
igs = rearrange(igate_preact, "b nh (nc l) 1 -> b nh nc l", l=chunk_size)
fgs = rearrange(fgate_preact, "b nh (nc l) 1 -> b nh nc l", l=chunk_size)

# compute the gates, the g and the p and q vectors
log_fgates = F.logsigmoid(fgs)  # (B, NH, NC, L)

# p_vec_f[..., t] = sum of log forget gates at positions 0..t-1 of the chunk (exclusive cumsum).
# NOTE(review): the p_vec_f printed by vlstm_chunkwise_parallel in the saved output above is an
# inclusive cumsum (its first entry equals log_fgates[..., 0]) -- confirm which one is intended.
p_vec_f = torch.cat(
    [
        torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device),
        log_fgates[:, :, :, :-1].cumsum(-1),
    ],
    dim=-1,
)

# q_vec_f[..., t] = sum of log forget gates at positions t+1..L-1 of the chunk (0 for the last).
q_vec_f_raw = torch.cat(
    [
        torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device),
        log_fgates[:, :, :, 1:].cumsum(-1),
    ],
    dim=-1,
)
q_vec_f = log_fgates[:, :, :, 1:].sum(-1, keepdim=True) - q_vec_f_raw

p_vec = p_vec_f + igs  # (B, NH, NC, L)
q_vec = q_vec_f + igs  # (B, NH, NC, L)
g_vec = log_fgates.sum(-1)  # (B, NH, NC): total log forget gate per chunk

# get the maximum values per chunk for p and q
p_vec_max = p_vec.max(-1).values  # (B, NH, NC)
q_vec_max = q_vec.max(-1).values  # (B, NH, NC)


# loop 1: materialize the  C_k, n_k, m_k
# Index k of the *_states tensors holds the state fed into chunk k, i.e. the state after
# chunks 0..k-1; index 0 therefore keeps its zero initialisation.
C_states = torch.zeros((B, NH, NC, DH, DH), dtype=_dtype, device=_device)
n_states = torch.zeros((B, NH, NC, DH), dtype=_dtype, device=_device)
m_states = torch.zeros((B, NH, NC, 1), dtype=_dtype, device=_device)

m_prev_k = torch.zeros((B, NH, 1), dtype=_dtype, device=_device)
C_prev_k = torch.zeros((B, NH, DH, DH), dtype=_dtype, device=_device)
# n_k is pre-bound so the name exists even when NC == 1 and loop 1 never runs.
n_k = torch.zeros((B, NH, DH), dtype=_dtype, device=_device)
n_prev_k = torch.zeros((B, NH, DH), dtype=_dtype, device=_device)
for k in range(1, NC):
    i = k - 1
    # m_k: running max stabilizer after absorbing chunk i
    m_q_k = q_vec_max[:, :, i].unsqueeze(-1)  # (B, NH, 1)
    g_k = g_vec[:, :, i].unsqueeze(-1)  # (B, NH, 1)
    m_k = torch.max(g_k + m_prev_k, m_q_k)  # (B, NH, 1)
    m_states[:, :, k] = m_k

    # C_k
    k_chunk = ks[:, :, i, :, :].clone()  # (B, NH, L, DH)
    v_chunk = vs[:, :, i, :, :].clone()  # (B, NH, L, DH)
    q_k = q_vec[:, :, i, :].clone()  # (B, NH, L)
    k_chunk_gated = k_chunk * torch.exp(q_k - m_k).unsqueeze(-1)  # (B, NH, L, DH)

    # Stabilized decay applied to the previous state; shared by C_k and n_k.
    decay_k = torch.exp(g_k + m_prev_k - m_k)  # (B, NH, 1)

    C_k = (
        decay_k.unsqueeze(-1) * C_prev_k
        + k_chunk_gated.transpose(-2, -1) @ v_chunk
    )  # (B, NH, DH, DH)
    C_states[:, :, k] = C_k

    # n_k
    n_k = decay_k * n_prev_k + k_chunk_gated.transpose(-2, -1).sum(-1)  # (B, NH, DH)
    n_states[:, :, k] = n_k

    # move to the next iteration
    m_prev_k = m_k
    C_prev_k = C_k
    n_prev_k = n_k

# loop 2: compute the H_states
H_states = torch.zeros((B, NH, NC, L, DH), dtype=_dtype, device=_device)
# Lower-triangular (causal) mask; identical for every chunk, so it is built once.
ltr = torch.tril(
    torch.ones(
        (L, L),
        dtype=torch.bool,
        device=_device,
    )
)
for k in range(1, NC + 1):
    i = k - 1

    # load C_k, n_k, m_k
    C_k = C_states[:, :, i]  # (B, NH, DH, DH)
    n_k_inter = n_states[:, :, i]  # (B, NH, DH)
    m_k = m_states[:, :, i]  # (B, NH, 1)
    # load q, k, v chunks
    q_chunk = qs[:, :, i, :, :].clone()
    k_chunk = ks[:, :, i, :, :].clone()
    v_chunk = vs[:, :, i, :, :].clone()

    # ? Compute inter chunk contribution: H_inter
    p_k = p_vec[:, :, i, :].clone()  # (B, NH, L)

    m_p_k = p_vec_max[:, :, i].unsqueeze(-1)  # (B, NH, 1)
    m_H = torch.max(m_p_k, m_k)  # (B, NH, 1)
    q_chunk_gated = q_chunk * torch.exp(p_k - m_H).unsqueeze(-1)  # (B, NH, L, DH)

    denom_k_inter = torch.max(
        torch.abs(q_chunk_gated @ n_k_inter.unsqueeze(-1)),  # (B, NH, L, 1)
        torch.exp(-m_k - m_H).unsqueeze(-1),  # (B, NH, 1, 1)
    )

    H_inter = q_chunk_gated @ C_k / denom_k_inter  # (B, NH, L, DH)

    # ? Compute intra chunk contribution: H_intra
    # this is similar to the parallel version, but only for the current chunk
    log_fg_k = log_fgates[:, :, i].unsqueeze(-1)  # (B, NH, L, 1)
    log_ig_k = igs[:, :, i].unsqueeze(-1)  # (B, NH, L, 1)
    log_fg_k_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fg_k, dim=-2),
        ],
        dim=-2,
    )  # (B, NH, L+1, 1)
    # For each batch/head this is an (L+1, L+1) matrix whose entry (r, c) is the cumulative
    # log forget gate up to row r: every column is a copy of the cumsum vector, and the first
    # row is zero.
    rep_log_fg_k_cumsum = log_fg_k_cumsum.repeat(1, 1, 1, L + 1)  # (B, NH, L+1, L+1)
    # Entry (r, c) becomes cumsum[r] - cumsum[c], i.e. the summed log forget gates of the
    # timesteps after c up to and including r (only meaningful where c <= r).
    _log_fg_k_matrix = rep_log_fg_k_cumsum - rep_log_fg_k_cumsum.transpose(
        -2, -1
    )  # (B, NH, L+1, L+1)
    # Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied
    # to the input at timestep t
    log_fg_k_matrix = torch.where(
        ltr, _log_fg_k_matrix[:, :, 1:, 1:], -float("inf")
    )  # (B, NH, L, L)

    log_D_k = log_fg_k_matrix + log_ig_k.transpose(-2, -1)  # (B, NH, L, L)

    # compute the max state (for now isolated for intra chunk contribution)
    m_log_D_k = torch.max(log_D_k, dim=-1, keepdim=True).values  # (B, NH, L, 1)

    log_D_k_stabilized = log_D_k - m_log_D_k
    D_k = torch.exp(log_D_k_stabilized)
    qk_k_matrix = q_chunk @ k_chunk.transpose(-2, -1)  # (B, NH, L, L)
    C_k_matrix = qk_k_matrix * D_k

    denom_k_intra = torch.maximum(
        C_k_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-m_log_D_k)
    )  # (B, NH, L, 1)
    C_k_matrix_normalized = C_k_matrix / denom_k_intra  # TODO add eps

    H_intra = C_k_matrix_normalized @ v_chunk  # (B, NH, L, DH)
    # Rescaling H_inter by denom_k_inter / denom_k_intra amounts to dividing
    # q_chunk_gated @ C_k by the intra-chunk denominator.
    # NOTE(review): the inter part is stabilized with m_H and the intra part with m_log_D_k;
    # confirm that combining them this way is the intended normalization.
    H_states[:, :, i, :, :] = (denom_k_inter / denom_k_intra) * H_inter + H_intra

# H_y = rearrange(H_states, "b nh nc l dh -> b nh (nc l) dh")

# we do not need the first forget gate as this is applied to the first element
# log_fgates_cumsum = log_fgates[:, :, 1:].cumsum(-1)
# d_vec = torch.cat(
#     [torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device), log_fgates_cumsum],
#     dim=-2,
# )

In [None]:
# Display the per-chunk outputs written by the scratch loops in the previous cell.
# NOTE(review): the saved output has a last dimension of 3 while DH is set to 2 at the top
# of the notebook, so it looks stale -- re-run before relying on it.
H_states

tensor([[[[[-0.9131, -0.6950,  0.3553],
           [-0.9939,  0.3550,  1.3450],
           [-1.2052,  0.5400,  1.5244],
           [ 0.5647, -0.1064, -0.7442]],

          [[ 3.4731, -0.7237, -2.6495],
           [ 0.2019,  0.2291,  0.4814],
           [ 0.6854,  0.2205, -0.3705],
           [ 0.8782,  0.2471, -2.1387]],

          [[-0.2243, -0.0334,  0.2088],
           [ 0.6945,  0.0576, -0.6253],
           [-0.6054, -0.6223,  0.4894],
           [-0.6631,  3.1439, -1.1159]]]]], device='cuda:0')

In [None]:
# Inspect the log forget gates and the chunked input-gate pre-activations of the scratch cell.
# NOTE(review): the saved values do not look like the arange-based ramps assigned near the
# top of the notebook, so this output is presumably from an older run -- confirm.
log_fgates, igs

(tensor([[[[-1.0596, -1.0655, -1.0648, -1.6584],
           [-0.4953, -1.0422, -1.2319, -1.2757],
           [-0.6950, -0.9492, -0.5191, -0.2592]]]], device='cuda:0'),
 tensor([[[[-0.2963,  2.6764, -0.1408, -0.8441],
           [ 0.2905, -0.2838, -1.4535,  2.3737],
           [-0.0177, -2.7884, -0.3788,  0.7046]]]], device='cuda:0'))

In [None]:
# Inspect the inter-chunk matrix states. Entry 0 along the chunk axis keeps its zero
# initialisation because the first loop only writes indices 1..NC-1.
C_states

tensor([[[[[ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]],

          [[ 0.7340, -0.6587, -1.6607],
           [ 0.0045,  0.4336,  0.6466],
           [-0.9611,  0.2796,  1.4387]],

          [[ 0.5968,  1.9466, -1.5605],
           [ 0.1144,  0.3455, -0.2782],
           [ 0.2439,  0.8581, -0.6792]]]]], device='cuda:0')

In [None]:
# NOTE(review): duplicate of the earlier `H_states` display cell; one of the two can be removed.
H_states

tensor([[[[[-0.9131, -0.6950,  0.3553],
           [-0.9939,  0.3550,  1.3450],
           [-1.2052,  0.5400,  1.5244],
           [ 0.5647, -0.1064, -0.7442]],

          [[ 3.4731, -0.7237, -2.6495],
           [ 0.2019,  0.2291,  0.4814],
           [ 0.6854,  0.2205, -0.3705],
           [ 0.8782,  0.2471, -2.1387]],

          [[-0.2243, -0.0334,  0.2088],
           [ 0.6945,  0.0576, -0.6253],
           [-0.6054, -0.6223,  0.4894],
           [-0.6631,  3.1439, -1.1159]]]]], device='cuda:0')

In [None]:
# Shape check: the last chunk's n_states entry as a column vector next to the query chunk
# left over from the final iteration of the second loop.
n_states[:, :, -1].unsqueeze(-1).shape, q_chunk.shape

(torch.Size([1, 1, 3, 1]), torch.Size([1, 1, 4, 3]))

In [None]:
# Shape check: q_chunk @ n (one value per timestep of the chunk) next to the stabilizer m_k.
(q_chunk @ n_states[:, :, -1].unsqueeze(-1)).shape, m_k.shape

(torch.Size([1, 1, 4, 1]), torch.Size([1, 1, 1]))

In [None]:
# Shape check: element-wise max of |q_chunk @ n| and m_k, to see what the broadcast yields.
torch.max(torch.abs(q_chunk @ n_states[:, :, -1].unsqueeze(-1)), m_k).shape

torch.Size([1, 1, 4, 1])

In [None]:
# Shape check: per-chunk gate sums, the chunked keys, and a single key chunk.
g_vec.shape, ks.shape, ks[:, :, 1, :, :].shape

(torch.Size([1, 1, 3]), torch.Size([1, 1, 3, 4, 3]), torch.Size([1, 1, 4, 3]))

In [None]:
# NOTE(review): `kv_gated` is not assigned anywhere in the visible notebook, so this cell
# raises NameError (see saved output); remove it or restore the definition.
kv_gated.shape

NameError: name 'kv_gated' is not defined

In [None]:
# Shape check for gating one key chunk with d_vec.
# NOTE(review): `d_vec` is only assigned in commented-out code in the visible notebook, so
# this cell depends on leftover kernel state and will not survive Restart & Run All.
(ks[:, :, 1, :, :] * d_vec[:, :, 1, :]).shape

torch.Size([1, 1, 4, 5])

In [None]:
# Shape of one chunk of d_vec (indexed as a 5-D tensor here).
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
d_vec[:, :, 1, :, :].shape

torch.Size([1, 1, 4, 1])

In [None]:
# Display d_vec together with its shape.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
d_vec, d_vec.shape

(tensor([[[[[ 0.7589],
            [-0.7418],
            [-2.5793],
            [-1.5436]],
 
           [[-1.1483],
            [-1.1596],
            [-4.4460],
            [-3.5700]],
 
           [[ 0.5241],
            [-1.4934],
            [-2.0288],
            [-1.9730]]]]], device='cuda:0'),
 torch.Size([1, 1, 3, 4, 1]))

In [None]:
# Maximum of d_vec over its second-to-last axis.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
d_vec.max(dim=-2).values

tensor([[[[ 0.7589],
          [-1.1483],
          [ 0.5241]]]], device='cuda:0')

In [None]:
# Same max as a (values, indices) pair, then the shape of the values and of one chunk's slice.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
vals, idxs = d_vec.max(dim=-2)
vals.shape, vals[:, :, 0].shape

(torch.Size([1, 1, 3, 1]), torch.Size([1, 1, 1]))

In [None]:
# Shape check: per-chunk K^T V product over the chunked keys and values.
# NOTE(review): the saved output shows a head dimension of 5, which matches neither DH = 2
# nor the other saved outputs -- presumably from an older run.
(ks.transpose(-1, -2) @ vs).shape

torch.Size([1, 1, 3, 5, 5])

In [None]:
# Shapes of the chunked keys and of d_vec, side by side.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
ks.shape, d_vec.shape

(torch.Size([1, 1, 3, 4, 5]), torch.Size([1, 1, 3, 4, 1]))

In [None]:
# Shape check: broadcasting d_vec against the chunked keys.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
(d_vec * ks).shape

torch.Size([1, 1, 3, 4, 5])

In [None]:
# Shape check: cumulative sum of the log forget gates with the first timestep of each chunk dropped.
log_fgates[:, :, :, 1:].cumsum(-1).shape

torch.Size([1, 1, 3, 3])

In [None]:
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook; the saved
# shape here is 4-D whereas other saved outputs show it as 5-D, so the outputs come from
# different runs.
d_vec.shape

torch.Size([1, 1, 3, 4])

In [None]:
# Shape of d_vec after appending a trailing singleton axis.
# NOTE(review): `d_vec` is not assigned by any live cell visible in this notebook.
d_vec.unsqueeze(-1).shape

torch.Size([1, 1, 3, 4, 1])