In [1]:
import torch

## Forget Gate Matrix Computation Variants

In [2]:
# Toy integer "gate" values 0..9, used below to show the pairwise-difference pattern.
fgs = torch.arange(0, 10)
# True division by a float promotes the integer range to floating point (0.00, 0.01, ..., 0.09).
igs = torch.arange(0, 10) / 100.

In [3]:
fgs

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [4]:
# Broadcasting a column view against a row view yields the (10, 10) matrix with entry [i, j] = fgs[i] - fgs[j].
fgs[:,None] - fgs[None,:]

tensor([[ 0, -1, -2, -3, -4, -5, -6, -7, -8, -9],
        [ 1,  0, -1, -2, -3, -4, -5, -6, -7, -8],
        [ 2,  1,  0, -1, -2, -3, -4, -5, -6, -7],
        [ 3,  2,  1,  0, -1, -2, -3, -4, -5, -6],
        [ 4,  3,  2,  1,  0, -1, -2, -3, -4, -5],
        [ 5,  4,  3,  2,  1,  0, -1, -2, -3, -4],
        [ 6,  5,  4,  3,  2,  1,  0, -1, -2, -3],
        [ 7,  6,  5,  4,  3,  2,  1,  0, -1, -2],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0, -1],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

In [5]:
def construct_log_gate_matrix_paper(fgs: torch.Tensor, igs: torch.Tensor) -> torch.Tensor:
    """Build the causally masked log gate (decay) matrix from per-timestep gate values.

    Entry ``[i, j]`` of the result is ``sum(fgs[j+1 .. i]) + igs[j]`` for ``j <= i``
    and ``-inf`` for ``j > i``. The forget gate of timestep ``j`` itself is therefore
    not applied to the input at timestep ``j`` (the diagonal holds only ``igs[j]``).

    Args:
        fgs: log forget gate values, shape (B, NH, S, 1).
        igs: input gate values in log space, shape (B, NH, S, 1). The last two dims are
            transposed so that ``igs[j]`` is added to column ``j``.

    Returns:
        Tensor of shape (B, NH, S, S).

    Raises:
        ValueError: if ``fgs`` is not 4-dimensional or its last dimension is not 1.
    """
    _, _, S, last_dim = fgs.shape
    if last_dim != 1:
        raise ValueError(f"fgs must have shape (B, NH, S, 1), got {tuple(fgs.shape)}")

    # Running sum of the log forget gates along the sequence axis.
    log_fgates_cumsum = torch.cumsum(fgs, dim=-2)  # (B, NH, S, 1)

    # Broadcasting the column vector against its transpose gives
    # cumsum[i] - cumsum[j] = sum(fgs[j+1 .. i]) without materialising repeated copies
    # of the cumsum (and without the zero-padding row/column, which was sliced away anyway).
    _log_fg_matrix = log_fgates_cumsum - log_fgates_cumsum.transpose(-2, -1)  # (B, NH, S, S)

    # Causal masking: positions with column j > row i are set to -inf.
    ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=fgs.device))
    log_fg_matrix = torch.where(ltr, _log_fg_matrix, -float("inf"))  # (B, NH, S, S)

    # gate decay matrix D (combination of forget gate and input gate)
    log_D_matrix = log_fg_matrix + igs.transpose(-2, -1)  # (B, NH, S, S)
    return log_D_matrix

In [6]:
# Sizes for the small example: B and NH are the two leading dims of the gate tensors, S is the number of timesteps.
B = 1
NH = 1
S = 10

In [7]:
# Rebind fgs / igs (previously 1-D toy ranges) to the 4-D layout (B, NH, S, 1):
# all-ones forget-gate values and all-zero input-gate values.
fgs = torch.ones(B, NH, S, 1)
igs = torch.zeros(B, NH, S, 1)

In [8]:
# Reference (B, NH, S, S) log gate matrix for the toy inputs.
matD_paper = construct_log_gate_matrix_paper(fgs, igs)

In [9]:
matD_paper

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [1., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [2., 1., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [3., 2., 1., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [4., 3., 2., 1., 0., -inf, -inf, -inf, -inf, -inf],
          [5., 4., 3., 2., 1., 0., -inf, -inf, -inf, -inf],
          [6., 5., 4., 3., 2., 1., 0., -inf, -inf, -inf],
          [7., 6., 5., 4., 3., 2., 1., 0., -inf, -inf],
          [8., 7., 6., 5., 4., 3., 2., 1., 0., -inf],
          [9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]]]])

In [10]:
fgs.squeeze(-1).shape

torch.Size([1, 1, 10])

In [11]:
# Drop the trailing singleton dim, then take the running sum along the sequence axis -> shape (B, NH, S).
fg_cumsum = torch.cumsum(fgs.squeeze(-1), dim=-1)

In [12]:
fg_cumsum

tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]]])

In [13]:
# Construct the gate matrix via subtraction: entry [i, j] = fg_cumsum[i] - fg_cumsum[j].
# No causal mask is applied here, so entries above the diagonal are negative numbers rather than -inf.
fg_cumsum[:, :, :, None] - fg_cumsum[:, :, None, :]

tensor([[[[ 0., -1., -2., -3., -4., -5., -6., -7., -8., -9.],
          [ 1.,  0., -1., -2., -3., -4., -5., -6., -7., -8.],
          [ 2.,  1.,  0., -1., -2., -3., -4., -5., -6., -7.],
          [ 3.,  2.,  1.,  0., -1., -2., -3., -4., -5., -6.],
          [ 4.,  3.,  2.,  1.,  0., -1., -2., -3., -4., -5.],
          [ 5.,  4.,  3.,  2.,  1.,  0., -1., -2., -3., -4.],
          [ 6.,  5.,  4.,  3.,  2.,  1.,  0., -1., -2., -3.],
          [ 7.,  6.,  5.,  4.,  3.,  2.,  1.,  0., -1., -2.],
          [ 8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0., -1.],
          [ 9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0.]]]])

In [14]:
# Strictly lower-triangular boolean mask: the diagonal offset -1 excludes the main diagonal.
mask = torch.tril(torch.ones(S, S, dtype=torch.bool), -1)
mask

tensor([[False, False, False, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]])

In [15]:
# Construct the matrix via repeating, masking and then cumsum:
# fgs (B, NH, S, 1) broadcasts against the (S, S) mask, so every column holds a copy of the gate values
# with rows i <= j zeroed out; the cumsum down the rows then gives sum(fgs[j+1 .. i]) below the diagonal and 0 elsewhere.
# fgs is returned alongside purely for visual comparison.
(fgs * mask).cumsum(dim=-2), fgs

(tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [2., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [3., 2., 1., 0., 0., 0., 0., 0., 0., 0.],
           [4., 3., 2., 1., 0., 0., 0., 0., 0., 0.],
           [5., 4., 3., 2., 1., 0., 0., 0., 0., 0.],
           [6., 5., 4., 3., 2., 1., 0., 0., 0., 0.],
           [7., 6., 5., 4., 3., 2., 1., 0., 0., 0.],
           [8., 7., 6., 5., 4., 3., 2., 1., 0., 0.],
           [9., 8., 7., 6., 5., 4., 3., 2., 1., 0.]]]]),
 tensor([[[[1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.]]]]))

## Tiled Forget Gate matrix computation

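Compute a single (BQ × BKV) tile of the forget-gate matrix directly from the forget-gate cumulative sums, and compare it against the corresponding slice of the full (S × S) matrix.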

In [16]:
# Sizes for the tiled example.
B = 1
NH = 1
S = 32

# Tile sizes: BQ rows and BKV columns per tile (presumably query / key-value block sizes - confirm naming).
BQ = 8
BKV = 4

In [17]:
# All-ones forget-gate values of shape (B, NH, S) (no trailing singleton dim in this section).
fgs = torch.ones(B*NH*S).reshape(B, NH, S)
# Running sum along the sequence axis.
fgs_cs = torch.cumsum(fgs, dim=-1)
# Reverse running sum: flip, cumsum, flip back, so position t holds the sum over positions t..S-1.
fgs_rev_cs = torch.cumsum(fgs.flip(-1), dim=-1).flip(-1)
# NOTE(review): a bare expression that is not the last line of a cell is not displayed; this line has no effect.
fgs_cs, fgs_rev_cs
# A distinct input-gate value per position (0.00, 0.01, ...), so columns are easy to tell apart in the outputs.
igs = torch.arange(B*NH*S).reshape(B, NH, S) / 100.

In [18]:
# Tile coordinates: row-tile index (in units of BQ) and column-tile index (in units of BKV).
idx_BQ = 2
idx_BKV = 5

In [19]:
# Slice out the pieces belonging to tile (idx_BQ, idx_BKV).
# The KV chunk covers positions [idx_BKV * BKV, (idx_BKV + 1) * BKV): the upper bound is scaled
# by the tile size BKV, not by the tile index.
fgs_chunk = fgs[:, :, idx_BKV * BKV : (idx_BKV + 1) * BKV]
# Cumulative sums for the rows (Q chunk) and columns (KV chunk) of the tile.
fgs_cs_chunk_Q = fgs_cs[:, :, idx_BQ * BQ : (idx_BQ + 1) * BQ]
fgs_cs_chunk_KV = fgs_cs[:, :, idx_BKV * BKV : (idx_BKV + 1) * BKV]

In [20]:
# Unmasked (BQ, BKV) tile: entry [i, j] = cumsum at Q position i minus cumsum at KV position j.
fgs_cs_chunk_Q[:, :, :, None] - fgs_cs_chunk_KV[:, :, None, :]

tensor([[[[-4., -5., -6., -7.],
          [-3., -4., -5., -6.],
          [-2., -3., -4., -5.],
          [-1., -2., -3., -4.],
          [ 0., -1., -2., -3.],
          [ 1.,  0., -1., -2.],
          [ 2.,  1.,  0., -1.],
          [ 3.,  2.,  1.,  0.]]]])

In [21]:
# Reference: full (S, S) log gate matrix with zero input gates, then slice out tile (idx_BQ, idx_BKV).
# Both gate tensors are passed as (B, NH, S, 1): the function transposes the last two dims of igs,
# so a 3-D igs would be broadcast along the wrong axis (harmless only because the values are zero here).
full_fgs_mat = construct_log_gate_matrix_paper(fgs.unsqueeze(-1), torch.zeros_like(fgs).unsqueeze(-1))
full_fgs_mat[:, :, idx_BQ * BQ : (idx_BQ + 1) * BQ, idx_BKV * BKV : (idx_BKV + 1) * BKV]

tensor([[[[-inf, -inf, -inf, -inf],
          [-inf, -inf, -inf, -inf],
          [-inf, -inf, -inf, -inf],
          [-inf, -inf, -inf, -inf],
          [0., -inf, -inf, -inf],
          [1., 0., -inf, -inf],
          [2., 1., 0., -inf],
          [3., 2., 1., 0.]]]])

In [22]:
def constuct_log_gate_matrix_tiled(fgs: torch.Tensor, igs: torch.Tensor, BQ: int, BKV: int, idx_BQ: int, idx_BKV: int, fgs_cs: torch.Tensor = None) -> torch.Tensor:
    """Compute one tile of the causally masked log gate matrix.

    The tile covers rows ``[idx_BQ * BQ, (idx_BQ + 1) * BQ)`` and columns
    ``[idx_BKV * BKV, (idx_BKV + 1) * BKV)`` of the full (S, S) matrix, both clamped to
    the sequence length S. Entry ``[i, j]`` is ``fgs_cs[row_i] - fgs_cs[col_j] + igs[col_j]``
    where ``col_j <= row_i`` and ``-inf`` otherwise.

    Args:
        fgs: log forget gate values, shape (B, NH, S).
        igs: input gate values in log space, shape (B, NH, S).
        BQ: tile height (number of rows per tile).
        BKV: tile width (number of columns per tile).
        idx_BQ: row-tile index.
        idx_BKV: column-tile index.
        fgs_cs: optional precomputed ``cumsum(fgs, dim=-1)``; computed here when None.

    Returns:
        Tensor of shape (B, NH, rows, cols); rows == BQ and cols == BKV unless the tile
        is cut short by the end of the sequence.
    """
    B, NH, S = fgs.shape
    if fgs_cs is None:
        fgs_cs = torch.cumsum(fgs, dim=-1)

    # Clamp the tile bounds to the sequence length so that a ragged last tile produces
    # index ranges of the same length as the sliced chunks.
    q_start, q_end = min(idx_BQ * BQ, S), min((idx_BQ + 1) * BQ, S)
    kv_start, kv_end = min(idx_BKV * BKV, S), min((idx_BKV + 1) * BKV, S)

    fgs_cs_chunk_Q = fgs_cs[:, :, q_start:q_end]
    fgs_cs_chunk_KV = fgs_cs[:, :, kv_start:kv_end]

    fgate_tile = fgs_cs_chunk_Q[:, :, :, None] - fgs_cs_chunk_KV[:, :, None, :]

    igs_chunk = igs[:, :, kv_start:kv_end]
    # Explicit row axis: the input gate of column j is added to every row of the tile.
    # Without it the (B, NH, cols) chunk is right-aligned against (B, NH, rows, cols) and
    # only broadcasts when B and NH happen to be 1.
    log_D_matrix = fgate_tile + igs_chunk[:, :, None, :]

    # causal masking: needed whenever the last column index of the tile exceeds its first
    # row index, i.e. the tile touches or crosses the diagonal. Comparing the tile start
    # offsets alone misses straddling tiles when BKV > BQ.
    if kv_end - 1 > q_start:
        bq_idxes = torch.arange(q_start, q_end, device=log_D_matrix.device)
        kv_idxes = torch.arange(kv_start, kv_end, device=log_D_matrix.device)
        idx_mask = bq_idxes[:, None] - kv_idxes[None, :] # or bq_idxes[:, None] >= kv_idxes[None, :]
        log_D_matrix = torch.where(idx_mask < 0, -float("inf"), log_D_matrix)
    return log_D_matrix



In [23]:
# Reference: full (S, S) log gate matrix including the input gates (both inputs lifted to (B, NH, S, 1)),
# then slice out tile (idx_BQ, idx_BKV).
full_fgs_mat = construct_log_gate_matrix_paper(fgs.unsqueeze(-1), igs.unsqueeze(-1))
full_fgs_mat[:, :, idx_BQ * BQ : (idx_BQ + 1) * BQ, idx_BKV * BKV : (idx_BKV + 1) * BKV]

tensor([[[[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [0.2000,   -inf,   -inf,   -inf],
          [1.2000, 0.2100,   -inf,   -inf],
          [2.2000, 1.2100, 0.2200,   -inf],
          [3.2000, 2.2100, 1.2200, 0.2300]]]])

In [24]:
# The same tile computed directly from the cumulative sums; it should match the slice of the full matrix.
tiled_fgs_mat = constuct_log_gate_matrix_tiled(fgs, igs, BQ, BKV, idx_BQ, idx_BKV)
tiled_fgs_mat

tensor([[[[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf],
          [0.2000,   -inf,   -inf,   -inf],
          [1.2000, 0.2100,   -inf,   -inf],
          [2.2000, 1.2100, 0.2200,   -inf],
          [3.2000, 2.2100, 1.2200, 0.2300]]]])

In [25]:
# Input-gate values for the columns of the tile (positions idx_BKV * BKV .. (idx_BKV + 1) * BKV - 1).
igs_chunk = igs[:, :, idx_BKV * BKV : (idx_BKV + 1) * BKV]
igs_chunk

tensor([[[0.2000, 0.2100, 0.2200, 0.2300]]])

In [26]:
# Absolute sequence positions of the tile's rows.
bq_idxes = torch.arange(idx_BQ * BQ, (idx_BQ + 1) * BQ)
bq_idxes

tensor([16, 17, 18, 19, 20, 21, 22, 23])

In [27]:
# Absolute sequence positions of the tile's columns.
kv_idxes = torch.arange(idx_BKV * BKV, (idx_BKV + 1) * BKV)
kv_idxes

tensor([20, 21, 22, 23])

In [28]:
# Causal keep-mask for the tile: True where the row position is at or after the column position.
bq_idxes[:, None] >= kv_idxes[None, :]

tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])