In [2]:
import torch
import torch.nn as nn
from typing import List, Tuple

In [2]:
class DisentangledTransformer(nn.Module):
    """Attention-only transformer whose residual stream is built by concatenation.

    Tokens are embedded as ``[one_hot(token) | one_hot(position)]``. Each layer
    applies ``n_head`` causal attention heads of the form
    ``softmax(X A X^T) X`` and concatenates every head's output onto its input,
    so the feature width grows by a factor ``1 + n_head`` per layer. A final
    linear map ``W`` projects the features to ``vocab_size`` logits.

    Args:
        seq_len: Maximum sequence length; also the width of the positional
            one-hot encoding.
        vocab_size: Number of distinct token ids.
        heads: Number of attention heads for each layer, in order.
    """

    # vocab_size / seq_len are static configuration (not parameters) but are
    # stored on the instance
    vocab_size: int
    seq_len: int

    def __init__(self, seq_len: int, vocab_size: int, heads: List[int]):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size

        # learnable attention tensors, one (n_head, d, d) parameter per layer
        self.A = nn.ParameterList()

        # width of the embedding: [ one_hot(token) | one_hot(position) ]
        d = seq_len + vocab_size

        for n_head in heads:
            # attention heads of this layer (zero-initialised)
            self.A.append(nn.Parameter(torch.zeros(n_head, d, d)))
            # the layer output is [x | head_1(x) | ... | head_n(x)]
            d *= 1 + n_head

        # output layer (zero-initialised)
        self.W = nn.Parameter(torch.zeros(d, vocab_size))

    # attn(X,A) := softmax(XAX^T)X
    def attn(self, x: torch.Tensor, A_i: torch.Tensor) -> torch.Tensor:
        """Apply every causal attention head of one layer to ``x``.

        Args:
            x: Features of shape ``(..., L, d)`` with any number of leading
                batch axes (including none).
            A_i: Attention tensor of shape ``(n_head, d, d)``.

        Returns:
            Tensor of shape ``(..., L, n_head, d)``.
        """
        L = x.shape[-2]

        # (..., L, d) @ (n_head, d, d) @ (..., d, L) -> (n_head, ..., L, L)
        attn_logits = torch.einsum("...ij,hjk,...lk->h...il", x, A_i, x)

        # causal mask (lower triangular): position i only attends to j <= i
        causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=x.device))

        # set the strict upper triangle to -inf so it gets zero softmax weight
        attn_logits = attn_logits.masked_fill(~causal_mask, float("-inf"))

        # normalise over the key axis (last axis)
        attn = nn.functional.softmax(attn_logits, dim=-1)

        # Attn_Weights X
        # (n_head, ..., L, L) @ (..., L, d) -> (n_head, ..., L, d)
        attn_output = torch.einsum("h...ij,...jk->h...ik", attn, x)

        # Move the head axis from the front to just before the feature axis.
        # movedim is independent of how many batch axes x has, so batched and
        # unbatched inputs both end up as (..., L, n_head, d).
        return attn_output.movedim(0, -2)

    # embed(x_i) := [ one_hot(x) | one_hot(i) ]
    def embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids as concatenated token and position one-hots.

        Args:
            x: Integer token ids of shape ``(..., L)`` with ``L <= seq_len``.

        Returns:
            Tensor of shape ``(..., L, vocab_size + seq_len)`` in the dtype of
            the model's output weights.

        Raises:
            ValueError: If ``L`` exceeds ``seq_len``.
        """
        L = x.shape[-1]
        if L > self.seq_len:
            raise ValueError(
                f"sequence length {L} exceeds the model's seq_len {self.seq_len}"
            )

        # follow the parameter dtype so model.double() etc. keep working
        dtype = self.W.dtype

        # one-hot token embeddings, shape (..., L, vocab_size)
        wte = nn.functional.one_hot(x, num_classes=self.vocab_size).to(dtype)

        # one-hot positional embeddings, shape (L, seq_len); the width is
        # seq_len (not L) so the result always matches the parameter shapes,
        # also for inputs shorter than seq_len
        wpe = torch.eye(L, self.seq_len, dtype=dtype, device=x.device)

        # broadcast over the batch axes of x: (..., L, seq_len)
        wpe = wpe.expand(*x.shape, self.seq_len)

        # concatenate everything: (..., L, vocab_size + seq_len)
        return torch.cat([wte, wpe], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-position logits.

        Args:
            x: Integer token ids of shape ``(..., L)``, e.g. ``(B, L)``, with
                ``L <= seq_len``.

        Returns:
            Logits of shape ``(..., L, vocab_size)``.
        """
        # x shape = (..., L, d_initial) for d_initial = vocab_size + seq_len
        x = self.embed(x)

        # attention layers
        for A_i in self.A:
            # attn_output shape = (..., L, n_head, d_prev)
            attn_output = self.attn(x, A_i)

            # merge the head and feature axes: (..., L, n_head * d_prev)
            attn_output_flat = attn_output.reshape(*attn_output.shape[:-2], -1)

            # append the new attention features to the running features:
            # x shape = (..., L, d_prev * (1 + n_head))
            x = torch.cat([x, attn_output_flat], dim=-1)

        # linear layer: (..., L, d_final) @ (d_final, vocab_size) -> (..., L, vocab_size)
        return x @ self.W