In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from dataclasses import dataclass

In [15]:
class MLP(nn.Module):
    """Simple multi-layer perceptron with two linear layers and a relu non-linearity in between"""
    def __init__(self, n_embed: int, bias: bool = False):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features=n_embed, out_features=4 * n_embed, bias=bias),
            nn.ReLU(),
            nn.Linear(in_features=4 * n_embed, out_features=n_embed, bias=bias),
            #nn.Dropout(dropout)
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        return self.mlp(x)

class DotProductAttention(nn.Module):
    def __init__(self, n_embed: int, bias: bool=False):
        super().__init__()
        self.w_k = nn.Linear(in_features=n_embed, out_features=n_embed, bias=bias)
        self.q_k = nn.Linear(in_features=n_embed, out_features=n_embed, bias=bias)
        self.v_k = nn.Linear(in_features=n_embed, out_features=n_embed, bias=bias)

    def forward(self, x: torch.tensor) -> torch.tensor:
        _,n_embed = x.size()
        
        k = self.w_k(x)
        q = self.w_q(x)
        v = self.w_v(x)

        attention = (k @ q.T) / math.sqrt(n_embed)

        return F.softmax(attention) @ v
        
        
class TransformerBlock(nn.Module):
    """Transformer block that combines attention and MLP, both with pre-layernorm and residual connections"""
    def __init__(self, n_embed: int, bias: bool = False):
        super().__init__()
        self.attention = nn.Sequential(
            nn.LayerNorm(n_embed, bias=bias),
            DotProductAttention(n_embed=n_embed, bias=bias)
        )
        self.projection = nn.Sequential(
            nn.LayerNorm(n_embed, bias=bias),
            MLP(n_embed=n_embed, bias=bias)
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        attn = x + self.attention(x)
        proj = attn + self.projection(attn)
        return proj


@dataclass
class TransformerConfig:
    block_size: int = None
    vocab_size: int = None
    n_layers: int = 4
    n_embed: int = 12
    block_size: int = 64

class NaiveTransformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            token_embed = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.n_embed),
            pos_embed = nn.Embedding(num_embeddings=config.block_size, embedding_dim=config.n_embed),
            attention_blocks = nn.ModuleList([TransformerBlock(n_embed=config.n_embed) for _ in range(config.n_layers)]),
        ))
        self.output_projection = nn.Linear(in_features=config.n_embed, out_features=config.vocab_size, bias=False)

config = TransformerConfig(block_size=64, vocab_size=65)
print(config)
NaiveTransformer(config)

TransformerConfig(block_size=64, vocab_size=65, n_layers=4, n_embed=12)


NaiveTransformer(
  (transformer): ModuleDict(
    (token_embed): Embedding(65, 12)
    (pos_embed): Embedding(64, 12)
    (attention_blocks): ModuleList(
      (0-3): 4 x TransformerBlock(
        (attention): Sequential(
          (0): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
          (1): DotProductAttention(
            (w_k): Linear(in_features=12, out_features=12, bias=False)
            (q_k): Linear(in_features=12, out_features=12, bias=False)
            (v_k): Linear(in_features=12, out_features=12, bias=False)
          )
        )
        (projection): Sequential(
          (0): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
          (1): MLP(
            (mlp): Sequential(
              (0): Linear(in_features=12, out_features=48, bias=False)
              (1): ReLU()
              (2): Linear(in_features=48, out_features=12, bias=False)
            )
          )
        )
      )
    )
  )
  (output_projection): Linear(in_features=12, out_features=65, 