## Problem: Write a Transformer

### Problem Statement
Implement a **Transformer model** in PyTorch by completing the required sections. The model should consist of an embedding layer, a Transformer encoder, and an output layer for sequence processing and prediction.

### Requirements
1. **Define the Transformer Model Architecture**:
   - **Embedding Layer**:
     - Implement a layer to transform input data into a higher-dimensional space.
     - Use a `torch.nn.Linear` or `torch.nn.Embedding` layer to create embeddings from the input.
   - **Transformer Encoder**:
     - Use `torch.nn.TransformerEncoder` or `torch.nn.Transformer` to process sequences with attention.
     - Configure parameters such as the number of attention heads and encoder layers.
   - **Output Layer**:
     - Add a fully connected (linear) layer that maps the Transformer's sequence output to the desired output dimension.

2. **Implement the Forward Method**:
   - Map the input to the higher-dimensional space using the embedding layer.
   - Pass the transformed input through the Transformer encoder.
   - Use the output layer to convert the encoded sequence into predictions.

### Constraints
- Handle input padding correctly for variable-length sequences.
- Ensure compatibility with batch processing by correctly shaping input and output tensors.


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass
import math
import torch.nn.functional as F

In [15]:

@dataclass
class GPTConfig:
    """Hyperparameters shared by the model components in this notebook.

    Only ``vocab_size`` is required; every other field has a default.
    The values are validated on construction so that an inconsistent
    configuration fails immediately with a clear message instead of
    surfacing later as a tensor-shape error inside the attention layer.

    Raises:
        ValueError: if a size field is not positive, if ``embed_dim`` is not
            divisible by ``n_heads``, or if ``dropout`` is outside ``[0, 1]``.
    """

    vocab_size: int
    max_seq_len: int = 256
    n_layers: int = 6
    n_heads: int = 8
    embed_dim: int = 512      # d_model
    d_ff: int = 2048          # for SwiGLU consider ~(8/3)*embed_dim
    dropout: float = 0.1
    bias: bool = False
    # NOTE(review): the three fields below are not consumed by any block
    # visible in this notebook; presumably they are read by a top-level model.
    tie_weights: bool = True
    use_classifier: bool = False
    num_classes: int = 2

    def __post_init__(self):
        # Every size-like field must be strictly positive.
        for field_name in (
            "vocab_size",
            "max_seq_len",
            "n_layers",
            "n_heads",
            "embed_dim",
            "d_ff",
            "num_classes",
        ):
            if getattr(self, field_name) <= 0:
                raise ValueError(
                    f"{field_name} must be positive, got {getattr(self, field_name)}"
                )
        # Attention splits embed_dim evenly across heads.
        if self.embed_dim % self.n_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")

    @property
    def head_dim(self) -> int:
        """Per-head feature width, ``embed_dim // n_heads``."""
        return self.embed_dim // self.n_heads


In [16]:
# multi head attention
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values, split into
    ``n_heads`` heads of width ``embed_dim // n_heads``, combined with
    scaled dot-product attention under a lower-triangular (causal) mask,
    merged back to ``embed_dim`` and passed through an output projection
    followed by dropout.

    Args:
        config: object exposing ``embed_dim``, ``n_heads``, ``max_seq_len``,
            ``dropout`` and ``bias`` attributes (e.g. ``GPTConfig``).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()

        if config.embed_dim % config.n_heads != 0:
            raise ValueError(
                f"embed_dim ({config.embed_dim}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )

        self.n_head = config.n_heads
        self.embed_dim = config.embed_dim
        self.head_dim = config.embed_dim // config.n_heads
        # Standard 1/sqrt(d_k) scaling applied to the attention logits.
        self.scalar = 1.0 / math.sqrt(self.head_dim)

        self.q = nn.Linear(config.embed_dim, config.embed_dim, bias=config.bias)
        self.k = nn.Linear(config.embed_dim, config.embed_dim, bias=config.bias)
        self.v = nn.Linear(config.embed_dim, config.embed_dim, bias=config.bias)
        self.output = nn.Linear(config.embed_dim, config.embed_dim, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Boolean lower-triangular mask: True where a query position may
        # attend to a key position. It is sized by the maximum sequence
        # length (not the embedding width) and kept out of the state dict
        # because it is fully determined by the config.
        max_len = config.max_seq_len
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer(
            "causal_mask", mask.view(1, 1, max_len, max_len), persistent=False
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(batch, seq_len, embed_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the configured ``max_seq_len``.
        """
        B, seq_len, _ = x.shape

        max_len = self.causal_mask.size(-1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        # (B, seq_len, embed_dim) -> (B, n_head, seq_len, head_dim)
        q = self.q(x).view(B, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k(x).view(B, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v(x).view(B, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (B, n_head, seq_len, seq_len)
        att = (q @ k.transpose(-1, -2)) * self.scalar

        # Block attention to future positions before normalising.
        mask = self.causal_mask[:, :, :seq_len, :seq_len]
        att = att.masked_fill(~mask, float("-inf"))
        att = F.softmax(att, dim=-1)

        # Weighted sum of values, then merge the heads back together.
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, seq_len, self.embed_dim)
        y = self.output(y)
        y = self.dropout(y)

        return y

In [17]:
class KVCache:
    """Append-only cache of attention key/value tensors.

    New tensors are concatenated along dimension 2, and trimming slices
    with four indices, so cached tensors are expected to be 4-D.
    NOTE(review): presumably laid out as (batch, heads, seq, head_dim);
    confirm against the caller.
    """

    def __init__(self):
        self.k_cache = None
        self.v_cache = None

    def update(self, new_k, new_v):
        """Append ``new_k`` / ``new_v`` along dim 2 and return the full cache.

        Returns:
            Tuple ``(k_cache, v_cache)`` including the newly added entries.
        """
        if self.k_cache is None:
            self.k_cache = new_k
            self.v_cache = new_v
        else:
            self.k_cache = torch.cat([self.k_cache, new_k], dim=2)
            self.v_cache = torch.cat([self.v_cache, new_v], dim=2)
        return self.k_cache, self.v_cache

    def clear(self):
        """Drop all cached keys and values."""
        self.k_cache = None
        self.v_cache = None

    def getSize(self):
        """Return the cached length along dim 2, or 0 when the cache is empty."""
        # Compare against None explicitly: truth-testing a tensor with more
        # than one element raises a RuntimeError.
        if self.k_cache is None:
            return 0
        return self.k_cache.size(2)

    def clamp_to_max_len(self, max_len):
        """Keep only the most recent ``max_len`` entries along dim 2.

        Raises:
            ValueError: if ``max_len`` is negative.
        """
        if max_len < 0:
            raise ValueError(f"max_len must be non-negative, got {max_len}")
        if self.k_cache is None:
            return
        # Slice from an explicit start index: a ``-max_len:`` slice would
        # keep everything when max_len == 0, because ``-0`` is ``0``.
        excess = self.k_cache.size(2) - max_len
        if excess > 0:
            self.k_cache = self.k_cache[:, :, excess:, :]
            self.v_cache = self.v_cache[:, :, excess:, :]

    # Backward-compatible alias for the original (misspelled) method name.
    clap_to_max_len = clamp_to_max_len


In [18]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Subclasses ``nn.Module`` so that the layers are registered as parameters
    and the instance is callable (``mlp(x)`` dispatches to ``forward``).

    Args:
        config: object exposing ``embed_dim``, ``d_ff``, ``dropout`` and
            ``bias`` attributes (e.g. ``GPTConfig``). ``d_ff`` sets the
            hidden width; the ``GPTConfig`` defaults give 4 * ``embed_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.c = nn.Linear(config.embed_dim, config.d_ff, bias=config.bias)
        self.gelu = nn.GELU()
        self.c2 = nn.Linear(config.d_ff, config.embed_dim, bias=config.bias)
        # Use the configured rate rather than nn.Dropout's default of 0.5.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward network over the last dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.c(x)
        x = self.gelu(x)
        x = self.c2(x)
        x = self.dropout(x)

        return x


In [19]:
class BLOCK(nn.Module):
    """Pre-norm Transformer block: attention and MLP, each with a residual.

    Subclasses ``nn.Module`` so that sub-modules are registered and the
    instance is callable.

    Args:
        config: object exposing ``embed_dim`` plus whatever
            ``MultiHeadAttention`` and ``MLP`` read from it
            (e.g. ``GPTConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(config)
        # The config field is ``embed_dim``; ``emb_dim`` does not exist.
        self.n1 = nn.RMSNorm(config.embed_dim)
        self.n2 = nn.RMSNorm(config.embed_dim)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sub-layers.

        Each sub-layer sees a normalised copy of its input, and its output
        is added back onto the un-normalised residual stream.
        """
        x = x + self.att(self.n1(x))
        x = x + self.mlp(self.n2(x))

        return x


In [None]:
# Training loop: one full-batch gradient step per epoch.
# Relies on `model`, `X`, `y`, `criterion` and `optimizer` from earlier cells.
epochs = 1000
for epoch in range(epochs):
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss.
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    # Only report every 100th epoch.
    if (epoch + 1) % 100 != 0:
        continue
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

Epoch [100/1000], Loss: 1.5771
Epoch [200/1000], Loss: 0.8907
Epoch [300/1000], Loss: 0.6074
Epoch [400/1000], Loss: 0.3587
Epoch [500/1000], Loss: 0.1986
Epoch [600/1000], Loss: 0.1157
Epoch [700/1000], Loss: 0.0762
Epoch [800/1000], Loss: 0.0629
Epoch [900/1000], Loss: 0.0575
Epoch [1000/1000], Loss: 0.0379


In [None]:
# Testing on new data
# Switch to evaluation mode first: torch.no_grad() only disables gradient
# tracking, it does not turn off train-time layers such as Dropout, so
# without eval() the predictions would vary from run to run.
model.eval()
X_test = torch.rand(2, seq_length, input_dim)
with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

Predictions for [[[0.6648573279380798], [0.6041934490203857], [0.3187063932418823], [0.9813531041145325], [0.09837877750396729], [0.3223891258239746], [0.3124500513076782], [0.36122316122055054], [0.8705818057060242], [0.4751177430152893]], [[0.569571316242218], [0.05407053232192993], [0.16180634498596191], [0.8140731453895569], [0.34717607498168945], [0.6788632273674011], [0.11463749408721924], [0.21608346700668335], [0.7405895590782166], [0.8521053194999695]]]: [[5.141801834106445], [5.020108699798584]]
