In [19]:
#define a MLP Mixer based causal-language-model using weight masking

import torch
import torch.nn as nn
from einops import rearrange

class CausalLinear(nn.Module):
    """
    A linear layer with a lower-triangular (causal) mask applied to the weight matrix.
    This ensures each position i cannot use info from positions > i.
    """
    def __init__(self, dim):
        super().__init__()

        # Standard weight + bias
        self.weight = nn.Parameter(torch.randn(dim, dim))
        self.bias = nn.Parameter(torch.zeros(dim))

        # Lower-triangular mask
        mask = torch.tril(torch.ones(dim, dim))
        self.register_buffer('mask', mask)

    def forward(self, x):
        """
        x shape: (batch, embed_dim, seq_len)
        """
        B, E, S = x.shape
        W = self.weight * self.mask    # elementwise multiply
        x_reshaped = x.view(B * E, S)  # (B*E, S)
        out = x_reshaped @ W           # (B*E, S)
        out = out + self.bias          # broadcast bias
        out = out.view(B, E, S)        # reshape back

        return out

class MixerBlock(nn.Module):
    
    def __init__(
        self,
        hidden_dim:int,
        seq_len:int,
        expansion_factor:int=2):

        super(MixerBlock, self).__init__()

        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.expansion_factor = expansion_factor

        #channel-mixing layer
        self.channel_mixing_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(hidden_dim * expansion_factor, hidden_dim)
        )

        #token-mixing layer
        self.token_mixing_layer = nn.Sequential(
            nn.LayerNorm(seq_len),
            CausalLinear(seq_len),
            nn.GELU(),
            CausalLinear(seq_len)
        )

    def forward(self, x):

        x = x + self.channel_mixing_layer(x)
        x = x.transpose(1, 2)
        x = x + self.token_mixing_layer(x)
        x = x.transpose(1, 2)

        return x

class MLPMixer(nn.Module):
    
    def __init__(
        self,
        vocab_size:int,
        hidden_dim:int,
        seq_len:int,
        num_blocks:int):

        super(MLPMixer, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.num_blocks = num_blocks

        # Input Embedding
        self.input_layer = nn.Embedding(vocab_size, hidden_dim)

        # Mixer Blocks
        self.mixer_blocks = nn.ModuleList(
            [MixerBlock(hidden_dim, seq_len) for _ in range(num_blocks)]
        )

        # Output Layer
        self.output_layer = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Tie input and output layer weights
        self.output_layer.weight = self.input_layer.weight

        # Initialize weights
        self._init_weights()

        # Define loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def _init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, CausalLinear):
                # Kaiming He initialization for Swish activation
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x, labels=None):

        x = self.input_layer(x)
        for block in self.mixer_blocks:
            x = block(x)
        output = self.output_layer(x)

        if not labels is None:

            print('ping')

            shift_logits = output[:,:-1,:]
            shift_logits = shift_logits.contiguous().view(-1, self.vocab_size)
            shift_labels = labels[:,1:]
            shift_labels = shift_labels.contiguous().view(-1)
            loss = self.loss_fn(shift_logits, shift_labels)
            return loss, output

        else:
            return output

In [22]:
model = MLPMixer(
    vocab_size=5,
    hidden_dim=7,
    seq_len=9,
    num_blocks=1
)

# x = torch.randint(0, 5, (2, 9))
x = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3], [0, 1, 2, 3, 4, 0, 1, 2, 3]])
print(x.shape)
output = model(x, x)
print(output)

torch.Size([2, 9])
ping
(tensor(2.4437, grad_fn=<NllLossBackward0>), tensor([[[ 1.8216,  2.2989,  2.5968, -0.2042,  2.8375],
         [ 1.7342,  4.1020,  1.7495, -0.7884,  0.4396],
         [ 0.0629,  1.3846,  0.4511,  0.6950,  1.3095],
         [-2.2251, -1.2382, -3.0345,  0.5811, -3.1179],
         [-0.9777,  0.1485, -1.1675,  1.7963,  1.4541],
         [ 2.7091,  3.0780,  3.6391, -0.6576,  3.4016],
         [ 0.5278,  3.0303,  0.4562, -0.5980, -1.1697],
         [-0.0365,  1.0025,  0.8301,  0.6095,  1.4044],
         [-2.1210, -2.2493, -2.7673,  1.4604, -1.1105]],

        [[ 1.8216,  2.2989,  2.5968, -0.2042,  2.8375],
         [ 1.7342,  4.1020,  1.7495, -0.7884,  0.4396],
         [ 0.0629,  1.3846,  0.4511,  0.6950,  1.3095],
         [-2.2251, -1.2382, -3.0345,  0.5811, -3.1179],
         [-0.9777,  0.1485, -1.1675,  1.7963,  1.4541],
         [ 2.7091,  3.0780,  3.6391, -0.6576,  3.4016],
         [ 0.5278,  3.0303,  0.4562, -0.5980, -1.1697],
         [-0.0365,  1.0025,  0.83