In [29]:
import torch

def delta_rule_recurrent(v, eta, h0=None):
    B, H, N, D = v.shape
    if h0 is None:
        h = torch.zeros(B, H, D, dtype=v.dtype, device=v.device)
    else:
        h = h0
    outs = []
    for t in range(N):
        h = (1 - eta[:, :, t, :]) * h + eta[:, :, t, :] * v[:, :, t, :]
        outs.append(h.unsqueeze(2))
    return torch.cat(outs, dim=2)

def delta_rule_closed_form(v, eta, h0=None):
    B, H, N, D = v.shape
    one_minus_eta = 1 - eta
    log_1m_eta = torch.log(one_minus_eta + 1e-8)
    log_cumsum = torch.cumsum(log_1m_eta, dim=2)
    log_cumsum_t = log_cumsum.unsqueeze(3)
    log_cumsum_j = log_cumsum.unsqueeze(2)
    prod = torch.exp(log_cumsum_t - log_cumsum_j)
    mask = torch.tril(torch.ones(N, N, device=v.device), diagonal=0).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    prod = prod * mask
    ev = eta * v
    ev_j = ev.unsqueeze(2)
    h = torch.sum(prod * ev_j, dim=3)
    if h0 is not None:
        prod_h0 = torch.exp(log_cumsum)
        h = h + prod_h0 * h0.unsqueeze(2)
    return h

def gla_recurrent(q, k, v, lambd, S0=None):
    B, H, N, D = q.shape
    if S0 is None:
        S = torch.zeros(B, H, D, D, dtype=q.dtype, device=q.device)
    else:
        S = S0
    outs = []
    for t in range(N):
        kvT = torch.einsum('bhd,bhe->bhde', k[:, :, t, :], v[:, :, t, :])
        lmbd = lambd[:, :, t, :].reshape(B, H, 1, 1)
        S = lmbd * S + (1 - lmbd) * kvT
        o = torch.einsum('bhd,bhde->bhe', q[:, :, t, :], S)
        outs.append(o.unsqueeze(2))
    return torch.cat(outs, dim=2)

def gla_closed_form(q, k, v, lambd, S0=None):
    B, H, N, D = q.shape
    one_minus_lambd = 1 - lambd  # [B, H, N, 1]
    prod = torch.cumprod(lambd.flip(dims=[2]), dim=2).flip(dims=[2])
    prod = torch.cat([prod[:, :, 1:, :], torch.ones_like(prod[:, :, :1, :])], dim=2)
    weights = one_minus_lambd * prod  # [B, H, N, 1]
    kvT = torch.einsum('b h n d, b h n e -> b h n d e', k, v)  # [B, H, N, D, D]
    S = torch.cumsum(weights.unsqueeze(-1) * kvT, dim=2)  # [B, H, N, D, D]
    if S0 is not None:
        prod_S0 = torch.cumprod(lambd, dim=2)  # [B, H, N, 1]
        S = S + prod_S0.unsqueeze(-1) * S0.unsqueeze(2)  # [B, H, N, D, D]
    o = torch.einsum('b h n d, b h n d e -> b h n e', q, S)
    return o

def test_equivalence():
    torch.manual_seed(0)
    B, H, N, D = 2, 3, 8, 4
    v = torch.randn(B, H, N, D)
    eta = torch.sigmoid(torch.randn(B, H, N, D))
    h0 = torch.randn(B, H, D)

    # delta_rule
    h_recur = delta_rule_recurrent(v, eta, h0)
    h_closed = delta_rule_closed_form(v, eta, h0)
    print("delta_rule max abs diff:", (h_recur - h_closed).abs().max().item())

    # GLA
    q = torch.randn(B, H, N, D)
    k = torch.randn(B, H, N, D)
    v = torch.randn(B, H, N, D)
    lambd = torch.sigmoid(torch.randn(B, H, N, 1))
    S0 = torch.randn(B, H, D, D)

    o_recur = gla_recurrent(q, k, v, lambd, S0)
    o_closed = gla_closed_form(q, k, v, lambd, S0)
    print("GLA max abs diff:", (o_recur - o_closed).abs().max().item())
    # VÃ©rification des shapes :
    print("S0.unsqueeze(2):", S0.unsqueeze(2).shape)
    print("prod_S0.unsqueeze(-1):", torch.cumprod(lambd, dim=2).unsqueeze(-1).shape)
    print("o:", o_closed.shape)

if __name__ == "__main__":
    test_equivalence()


delta_rule max abs diff: 3.5762786865234375e-07
GLA max abs diff: 6.677595138549805
S0.unsqueeze(2): torch.Size([2, 3, 1, 4, 4])
prod_S0.unsqueeze(-1): torch.Size([2, 3, 8, 1, 1])
o: torch.Size([2, 3, 8, 4])
