# minGRU and minLSTM

https://arxiv.org/pdf/2410.01201

In [None]:
import os

%env PYTORCH_ENABLE_MPS_FALLBACK=1

## Parallel scan

https://arxiv.org/html/2311.06281v4


Sample code for computing the sequence

$$x_t=a_t x_{t-1} + b_t$$

efficiently in parallel, given
$t=(1,2,\dots,n)$, $a_t \in \mathbb{R}^n$, $b_t \in \mathbb{R}^n$, and initial value
$x_0 \in \mathbb{R}$.

See "Efficient Parallelization of a Ubiquitous Sequential Computation" (Heinsen, 2023).

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from icecream import ic


In [None]:
# Pick Apple's MPS backend when PyTorch reports it as available, else the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## naive sequential implementation

In [None]:
def naively_compute_sequentially(coeffs: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Compute x_t = a_t * x_{t-1} + b_t one step at a time.

    ``values[0]`` is taken as the initial value x_0. Each element ``a`` of
    ``coeffs`` is paired with the matching element ``b`` of ``values[1:]`` and
    the next term is appended as ``a * x[-1] + b``. Every step is traced with
    ``ic``.

    Args:
        coeffs: the coefficients a_t, iterated along the first dimension.
        values: the initial value followed by the b_t terms.

    Returns:
        ``torch.stack`` of all computed terms, including the initial value.
    """
    x = [values[0]]  # x_0
    ic(x)

    # zip stops at the shorter input, so one term is produced per (a, b) pair.
    for a, b in zip(coeffs, values[1:]):
        x.append(a * x[-1] + b)
        ic(x, x[-1], a, b)

    return torch.stack(x)


# Use the MPS backend only when it is actually available; otherwise fall back to
# the CPU so this cell also runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
seq_len = 5  # _000_000  # change as you wish

# Generate some random input data:
coeffs = torch.randn(seq_len, device=device)
values = torch.randn(1 + seq_len, device=device)  # includes initial value
ic(coeffs, values)
# Compute the sequence:
x = naively_compute_sequentially(coeffs, values)  # includes initial value

## Parallel implementation


$$\log x_t = a_t^* + \log(x_0 + b_t^*) \tag{1}$$

where $a_t^*$ and $b_t^*$ are the 2 prefix (cumulative) sums:

$$a_t^*=\sum_{t}^{cum} \log a_t \tag{2} $$

$$b_t^*=\sum_{t}^{cum} e^{\log b_t-a_t^*} \tag{3} $$

Then we obtain $x_t$ with elementwise exponentiations

$$x_t=e^{a_t^*+\log(x_0 + b_t^*)} \tag{4}$$

It executes the same computations in parallel -- or more precisely, as a composition of two prefix sums, each of which is executable in parallel.

### General Case: If Any Input Can Be Negative
If any inputs are negative, their logarithms are complex numbers:
$$\ln(x)=\ln(\lvert x \rvert)+i\pi \quad \text{for } x<0$$

In [None]:
def complex_log_impl(float_input, eps=1e-6):
    """Hand-written complex logarithm of a real tensor.

    The real part is ``log(max(|x|, eps))`` and the imaginary part is ``pi``
    where ``x < 0`` and ``0`` elsewhere.

    Args:
        float_input: real-valued tensor.
        eps: lower bound applied to ``|x|`` before taking the logarithm.

    Returns:
        Complex tensor built with ``torch.complex`` from the two parts.
    """
    # Build the floor with the same dtype/device as the input.
    floor = float_input.new_tensor(eps)
    magnitude = torch.maximum(torch.abs(float_input), floor)
    log_magnitude = torch.log(magnitude)

    is_negative = float_input < 0
    phase = is_negative.to(float_input.dtype) * torch.pi
    return torch.complex(log_magnitude, phase)


complex_log_impl(torch.tensor(-2.0))

In [None]:
def complex_log(float_input: torch.Tensor):
    """Complex logarithm via PyTorch's native implementation.

    The input is cast to ``torch.complex64`` and its logarithm is returned.
    """
    as_complex = float_input.to(torch.complex64)
    return as_complex.log()


cl = complex_log(torch.tensor(-2.0))
cl, cl.real, cl.imag

In [None]:
def parallel_scan_log_(log_coeffs, log_values):
    """Parallel scan of h_t = a_t * h_{t-1} + b_t with log-space inputs.

    Args:
        log_coeffs: ``log a_t``, shape (batch_size, seq_len, input_size).
        log_values: ``log h_0`` followed by ``log b_t``, shape
            (batch_size, seq_len + 1, input_size).

    Returns:
        ``h_1 .. h_seq_len`` in linear space, shape
        (batch_size, seq_len, input_size).
    """
    # Cumulative sum of log-coefficients, with a leading row of zeros so it
    # lines up with log_values (which also carries the initial state).
    a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
    log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
    log_h = a_star + log_h0_plus_b_star
    # The result was previously computed and discarded, so callers got None.
    # Drop the first step, which is just h_0.
    return torch.exp(log_h)[:, 1:]


def parallel_scan_log(a, b):
    """Compute h_t = a_t * h_{t-1} + b_t for every t using two prefix sums.

    The inputs are given in linear space and may be negative: their logarithms
    are taken as complex numbers, the scan runs in the complex log domain, and
    the real part of the result is returned.

    Args:
        a: coefficients a_t, shape (batch_size, seq_len, input_size).
        b: initial state h_0 followed by the values b_t, shape
            (batch_size, seq_len + 1, input_size).

    Returns:
        Real tensor with ``h_1 .. h_seq_len``, shape
        (batch_size, seq_len, input_size).

    Raises:
        AssertionError: if the shapes of ``a`` and ``b`` do not line up.
    """
    a_b, a_s, a_h = a.shape
    b_b, b_s, b_h = b.shape
    assert a_b == b_b and a_h == b_h and a_s + 1 == b_s

    log_a = complex_log(a)
    ic(log_a.shape)
    log_b = complex_log(b)
    ic(log_b.shape)

    # Prefix sum of log(a), eq (2). The real and imaginary parts are summed
    # separately and recombined; a leading row of zeros is added so a_star has
    # the same length as b (which also carries the initial state).
    a_star_real = F.pad(torch.cumsum(log_a.real, dim=1), (0, 0, 1, 0))
    a_star_imag = F.pad(torch.cumsum(log_a.imag, dim=1), (0, 0, 1, 0))
    a_star = torch.complex(a_star_real, a_star_imag)
    ic(a_star.shape)
    assert a_star[0, 0, 0] == 0.0, "There must be a row of 0. added"

    # Use the full complex log(b): its imaginary part (pi for negative entries)
    # carries the sign of b. Subtracting only the real part would treat every
    # value as |b| and flip the sign of the result for negative inputs.
    shifted = log_b - a_star

    log_h0_plus_b_star = torch.logcumsumexp(shifted, dim=1)  # eq (1)
    log_h = a_star + log_h0_plus_b_star  # eq (1)
    h = torch.exp(log_h)[:, 1:].real  # eq (4), without the initial state
    return h


ic.enable()
# Use the MPS backend only when it is actually available; otherwise fall back to
# the CPU so this cell also runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
seq_len = 5  # 10_000_000  # change as you wish

# take 0.1 of prev value and add it to current value
coeffs = torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1]]).to(device).unsqueeze(-1)
values = torch.tensor([[0.0, -1.0, -2.0, -3.0, -4.0, -5.0]]).to(device).unsqueeze(-1)

# Compute the sequence:
x = parallel_scan_log(coeffs, values)
ic(x)
# One output step per coefficient (the initial state is not returned).
assert x.shape == coeffs.shape

## minGRU on Shakespeare

In [None]:
# Read the whole corpus into memory. The encoding is pinned so that decoding
# does not depend on the platform's default locale.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()


In [None]:
# First 80% of the characters (rounded up) go to training, the rest to validation.
start_val = int(np.ceil(0.8 * len(text)))
train_text, val_text = text[:start_val], text[start_val:]
# Show the characters on either side of the split point.
train_text[-10:], val_text[:10]

In [None]:
class ShakespeareDataSet(Dataset):
    """Character-level next-character dataset over a single text.

    Item ``idx`` is a pair ``(x, y)`` of ``seq_len`` character indices each,
    where ``y`` is ``x`` shifted one character to the right.

    Args:
        text: the raw text to sample windows from.
        seq_len: number of characters per sample.
        chars: optional vocabulary (iterable of characters). When omitted, the
            vocabulary is the sorted set of characters in ``text``. Pass the
            ``chars`` of another dataset to make several datasets (e.g. train
            and validation splits) share the same character-to-index mapping.

    Raises:
        KeyError: if ``chars`` is given and ``text`` contains a character that
            is not part of it.
    """

    def __init__(self, text: str, seq_len: int = 10, chars=None):
        self.text = text
        self.chars = sorted(set(text)) if chars is None else list(chars)
        self.char2idx = {c: i for i, c in enumerate(self.chars)}
        self.idx2char = {i: c for i, c in enumerate(self.chars)}
        self.data = [self.char2idx[c] for c in text]
        self.vocab_size = len(self.chars)
        self.seq_len = seq_len

    def __len__(self):
        # A text shorter than seq_len + 1 has no complete (x, y) window;
        # never report a negative length.
        return max(0, len(self.text) - self.seq_len)

    def __getitem__(self, idx: int):
        size = len(self)
        if idx < 0:
            idx += size
        # Raising IndexError lets plain iteration over the dataset terminate
        # and prevents silently returning truncated windows.
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of length {size}")
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


ds = ShakespeareDataSet(text, seq_len=20)

# Peek at the first (input, target) pair straight from the dataset.
x, y = next(iter(ds))
ic(x, y)

batch_size = 256
m = len(text) / batch_size
dl = DataLoader(dataset=ds, batch_size=batch_size)

# Peek at the first batch produced by the loader.
x, y = next(iter(dl))
ic(x.shape, y.shape)

# Fetch the first batch once more, together with its index.
i, (x, y) = next(enumerate(dl))

The continuous function $g$ ensures that $\widetilde{h}_t \leftarrow g(\mathrm{Linear}_{d_h}(x_t))$ is positive.

In [None]:
def g(x):
    """Elementwise positive activation: ``x + 0.5`` where ``x >= 0``, else ``sigmoid(x)``."""
    shifted = x + 0.5
    squashed = torch.sigmoid(x)
    return torch.where(x >= 0, shifted, squashed)


def log_g(x):
    """Elementwise ``log(relu(x) + 0.5)`` where ``x >= 0``, else ``-softplus(-x)``."""
    non_negative_branch = torch.log(F.relu(x) + 0.5)
    negative_branch = -F.softplus(-x)
    return torch.where(x >= 0, non_negative_branch, negative_branch)


# Evaluate both activations on a small sample and trace the results.
sx = torch.tensor([[0.0, -0.2, 0.3], [0.8, -0.9, 1.5]])
g_x, log_g_x = g(sx), log_g(sx)
ic(sx, g_x, log_g_x)


In [None]:
class ShakespeareMinGRU(nn.Module):
    """Single-layer character-level minGRU language model.

    Token indices are embedded, passed through a minGRU recurrence evaluated
    with a parallel scan in log space, and projected to vocabulary logits.

    Args:
        vocab_size: number of distinct tokens.
        hidden_size: size of the embedding and of the hidden state.
    """

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.linear_z = nn.Linear(hidden_size, hidden_size)
        self.linear_h = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    @staticmethod
    def _scan_log_space(log_coeffs: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Scan h_t = a_t * h_{t-1} + b_t given ``log a_t`` and ``log b_t`` (h_0 first)."""
        # log_coeffs: (batch_size, seq_len, hidden_size)
        # log_values: (batch_size, seq_len + 1, hidden_size)
        a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
        log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
        return torch.exp(a_star + log_h0_plus_b_star)[:, 1:]

    def forward(self, x_t: torch.Tensor, h_0: torch.Tensor):
        """Run the model over a batch of token sequences.

        Args:
            x_t: token indices, shape (batch_size, seq_len).
            h_0: initial hidden state, shape (batch_size, 1, hidden_size).

        Returns:
            ``(logits, h)`` where ``h`` holds the hidden state of every step and
            ``logits`` is ``fc(h)``.
        """
        b, s = x_t.shape
        assert h_0.shape == (b, 1, self.hidden_size)
        x_emb = self.embeddings(x_t)
        # Trace the argument's shape (not an unrelated global tensor).
        ic(x_t.shape, x_emb.shape, h_0.shape)
        k = self.linear_z(x_emb)
        ic(k.shape)
        # With z = sigmoid(k): log z = -softplus(-k), log(1 - z) = -softplus(k).
        log_z = -F.softplus(-k)
        ic(log_z.shape)
        log_coeffs = -F.softplus(k)
        log_h_0 = log_g(h_0)
        log_tilde_h = log_g(self.linear_h(x_emb))
        ic(log_coeffs.shape, log_h_0.shape, log_z.shape, log_tilde_h.shape)
        cat = torch.cat([log_h_0, log_z + log_tilde_h], dim=1)
        ic(cat.shape)
        # Both tensors are already logarithms, so they must go through a scan
        # that takes log-space inputs; a scan that applies log() to its inputs
        # again would compute the recurrence on the wrong quantities.
        h = self._scan_log_space(log_coeffs, cat)
        logits = self.fc(h)
        return logits, h

    def init_hidden(self, batch_size: int):
        """Return an all-zero initial state of shape (batch_size, 1, hidden_size).

        The tensor is created on the same device as the model parameters.
        """
        return torch.zeros(
            (batch_size, 1, self.hidden_size),
            device=self.embeddings.weight.device,
        )


# icecream is switched on and straight back off, so the smoke test below runs
# without trace output; only the output format is configured.
ic.enable()
ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=False)

# Push one small batch through a tiny model to check that the shapes line up.
batch_size = 8
dl = DataLoader(dataset=ds, batch_size=batch_size)
x, y = next(iter(dl))
x, y = x.to(device), y.to(device)
model = ShakespeareMinGRU(vocab_size=ds.vocab_size, hidden_size=12)
model = model.to(device)
h0 = model.init_hidden(batch_size=batch_size).to(device)
logits, h = model(x, h0)
ic(h.shape);


In [None]:
# Training hyper-parameters.
seq_length, batch_size = 100, 64
hidden_size = 256
num_layers = 1  # Using single layer as in the paper's minGRU/minLSTM
num_epochs = 1
learning_rate = 0.002

# NOTE(review): the model is built with a hidden size of 32, not with the
# `hidden_size` declared above, and `num_layers` is not used here -- confirm
# which values are intended.
model = ShakespeareMinGRU(vocab_size=ds.vocab_size, hidden_size=32)
model = model.to(device)
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:


train_ds = ShakespeareDataSet(train_text, seq_len=seq_length)
# Use the training sequence length for validation too; `seq_len` is a leftover
# global from the scan demos and holds an unrelated value.
val_ds = ShakespeareDataSet(val_text, seq_len=seq_length)
train_dl = DataLoader(train_ds, batch_size=batch_size)
model.train()
for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(train_dl):
        bs, sl = inputs.shape
        # A fresh zero state per batch; the last batch may be smaller, hence `bs`.
        h0 = model.init_hidden(bs).to(device)
        inputs, targets = inputs.to(device), targets.to(device)
        logits, h = model(inputs, h0)
        # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss. The
        # vocabulary size is read from the logits themselves so the loop does
        # not depend on an unrelated dataset object.
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
        ic(logits.shape, targets.shape)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dl)}], Loss: {loss.item():.4f}"
            )
