In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from gla_chunkwise_recurrent import torch_simple_gla, torch_simple_gla_recurrent

# Understand GLA chunkwise recurrent formulation

- Simple GLA

Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 12
NH = 1
DH = 5

In [4]:
gs = torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)

In [5]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([1, 1, 12, 5])

In [6]:
y_chunk = torch_simple_gla(qs, ks, vs, gs, chunk_size=4)

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
y_chunk, y_chunk.shape

(tensor([[[[-0.8802, -0.5078,  0.0239, -0.0394],
           [-1.6906, -0.9773, -0.4512,  1.2322],
           [ 0.6835,  0.3874,  0.3577, -0.9209],
           [ 0.1452,  1.0253, -0.2556,  0.8121],
           [ 1.7869,  0.9745, -0.9807, -0.9227],
           [-0.8809,  1.7100, -0.7934,  0.9973],
           [-0.0363, -0.4424,  0.0630,  0.0519],
           [ 1.3853, -0.7348, -2.6967, -0.7480],
           [-0.5974,  0.9879, -0.3555,  0.5916],
           [ 1.0301,  1.1901, -0.8591, -1.0535],
           [ 0.1138,  1.4598, -3.1535, -1.8314],
           [-1.4128, -2.8261,  2.1200,  3.5042]]]], device='cuda:0'),
 torch.Size([1, 1, 12, 4]))

In [9]:
y_recurrent = torch_simple_gla_recurrent(qs, ks, vs, gs, chunk_size=4)

In [10]:
y_recurrent

tensor([[[[-0.8802, -0.5078,  0.0239, -0.0394],
          [-1.6906, -0.9773, -0.4512,  1.2322],
          [ 0.6835,  0.3874,  0.3577, -0.9209],
          [ 0.1452,  1.0253, -0.2556,  0.8121],
          [ 1.7869,  0.9745, -0.9807, -0.9227],
          [-0.8809,  1.7100, -0.7934,  0.9973],
          [-0.0363, -0.4424,  0.0630,  0.0519],
          [ 1.3853, -0.7348, -2.6967, -0.7480],
          [-0.5974,  0.9879, -0.3555,  0.5916],
          [ 1.0301,  1.1901, -0.8591, -1.0535],
          [ 0.1138,  1.4598, -3.1535, -1.8314],
          [-1.4128, -2.8261,  2.1200,  3.5042]]]], device='cuda:0')