<a href="https://colab.research.google.com/github/nischala755/Transformer_Scratch/blob/main/transformer_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Transformer Architecture from Scratch : A Production Ready Deployment**



* Author : Nischala G S
* Github : @nischala755
* Date : 11/06/2025
* Licensed under the Standard MIT License



This Notebook Demonstrates :

---
1. Complete Transformer Implementation from Mathematical Foundations.
2. Advanced Training Techniques.
3. Model Optimizations.
4. ONNX/TensorRT Optimization
5. RLHF with PPO for Production ready deployment.
6. Documentation.


# ***Environment Setup  & Drive Setup ***

In [6]:
# Mount Google Drive for saving models
from google.colab import drive
drive.mount('/content/drive')

# Install required libraries
!pip install torch matplotlib numpy tqdm onnx onnxruntime tensorrt stable-baselines3 gym pyglet

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting tensorrt
  Downloading tensorrt-10.11.0.33.tar.gz (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.7/40.7 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting stable-baselines3
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting pyglet
  Downloading pyglet-2.1.6-py3-none-any.whl.metadata (7.7 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting tensorrt_cu12==10.11.0.33 (from tensorrt)
  Downloading tensorrt

**2. Download the Shakespeare Dataset from karpathy**

In [7]:
import requests

# Download Tiny Shakespeare dataset
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
with open("shakespeare.txt", "w") as f:
    f.write(response.text)

print("Dataset downloaded.")

Dataset downloaded.


**3.Data Preprocessing **

In [8]:
import torch
import numpy as np

# Load text and create vocabulary
text = open("shakespeare.txt").read()
chars = sorted(list(set(text)))
char_to_idx = {ch:i for i, ch in enumerate(chars)}
idx_to_char = {i:ch for i, ch in enumerate(chars)}

vocab_size = len(chars)

# Hyperparameters
block_size = 64   # context length
batch_size = 64
embed_dim = 128   # embedding dimension
num_heads = 4     # number of attention heads
num_layers = 4    # transformer blocks
dropout = 0.1

def get_batch(split_ratio=0.9):
    data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    n = int(split_ratio * len(data))
    train_data = data[:n]
    val_data = data[n:]
    return train_data, val_data

train_data, val_data = get_batch()

def get_mini_batch(data, block_size=block_size, batch_size=batch_size):
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

**4.Model Setup & Training Loop**

In [27]:
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-np.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.d_k = embed_dim // num_heads
        self.num_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3*embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.attn_weights = None  # Store attention weights for visualization

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(B, T, self.num_heads, self.d_k).transpose(1, 2), qkv)

        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.d_k ** 0.5))
        attn = F.softmax(attn, dim=-1)
        self.attn_weights = attn  # Save for visualization
        attn = self.attn_dropout(attn)
        x = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(x)
class FeedForward(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = PositionalEncoding(embed_dim)
        self.blocks = nn.Sequential(*[Block(embed_dim, num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_emb(idx)
        pos_emb = self.pos_emb(tok_emb)
        x = self.blocks(pos_emb)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss

model = TinyTransformer(vocab_size, embed_dim, num_heads, num_layers)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

Total parameters: 810049


**5. Curriculum Learning**

In [10]:
def get_mini_batch_curriculum(data, block_size, batch_size, epoch, curriculum_steps=5):
    max_block = block_size
    min_block = 8
    step_size = (max_block - min_block) // curriculum_steps
    current_block = min(min_block + step_size * (epoch // 100), max_block)

    ix = torch.randint(len(data) - current_block, (batch_size,))
    x = torch.stack([data[i:i+current_block] for i in ix])
    y = torch.stack([data[i+1:i+current_block+1] for i in ix])
    return x, y, current_block

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

@torch.no_grad()
def estimate_loss():
    model.eval()
    losses = []
    for _ in range(10):
        xb, yb = get_mini_batch(val_data)
        _, loss = model(xb.to(device), yb.to(device))
        losses.append(loss.item())
    model.train()
    return np.mean(losses)

# Training Loop
model.train()
for epoch in range(1000):
    xb, yb, curr_block = get_mini_batch_curriculum(train_data, block_size=64, batch_size=64, epoch=epoch)
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Block Size: {curr_block} | Loss: {loss.item():.4f} | Val Loss: {estimate_loss():.4f}")

# Save model
torch.save(model.state_dict(), "/content/drive/MyDrive/tiny_transformer.pth")

Epoch 0 | Block Size: 8 | Loss: 4.3098 | Val Loss: 4.1782
Epoch 100 | Block Size: 19 | Loss: 2.3954 | Val Loss: 2.7609
Epoch 200 | Block Size: 30 | Loss: 1.3629 | Val Loss: 2.2376
Epoch 300 | Block Size: 41 | Loss: 0.9346 | Val Loss: 1.6879
Epoch 400 | Block Size: 52 | Loss: 0.7195 | Val Loss: 1.0953
Epoch 500 | Block Size: 63 | Loss: 0.4719 | Val Loss: 0.4296
Epoch 600 | Block Size: 64 | Loss: 0.0759 | Val Loss: 0.0648
Epoch 700 | Block Size: 64 | Loss: 0.0611 | Val Loss: 0.0510
Epoch 800 | Block Size: 64 | Loss: 0.0526 | Val Loss: 0.0501
Epoch 900 | Block Size: 64 | Loss: 0.0483 | Val Loss: 0.0481


In [12]:
class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=2, dropout=0.2)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, targets=None):
        x = self.embed(x)
        x, _ = self.lstm(x)
        logits = self.fc(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1)) if targets is not None else None
        return logits, loss

lstm_model = CharLSTM(vocab_size, 64, 128).to(device)

In [13]:
optimizer_lstm = torch.optim.AdamW(lstm_model.parameters(), lr=3e-4)

for epoch in range(500):
    xb, yb = get_mini_batch(train_data)
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = lstm_model(xb, yb)
    optimizer_lstm.zero_grad()
    loss.backward()
    optimizer_lstm.step()

    if epoch % 100 == 0:
        print(f"LSTM Epoch {epoch} | Loss: {loss.item():.4f}")

torch.save(lstm_model.state_dict(), "/content/drive/MyDrive/lstm_teacher.pth")

LSTM Epoch 0 | Loss: 4.1770
LSTM Epoch 100 | Loss: 3.3679
LSTM Epoch 200 | Loss: 3.0848
LSTM Epoch 300 | Loss: 2.8402
LSTM Epoch 400 | Loss: 2.7230


**7.Knowledge Distillation**

In [16]:
def distillation_loss(student_logits, teacher_logits, targets, alpha=0.5, temperature=2.0):
    # Ensure logits and targets have compatible shapes
    min_seq_len = min(student_logits.size(1), teacher_logits.size(1))

    student_logits = student_logits[:, :min_seq_len, :]
    teacher_logits = teacher_logits[:, :min_seq_len, :]
    targets = targets[:, :min_seq_len]

    # Soft target distillation loss
    soft_labels = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    distill_loss = F.kl_div(student_log_probs.view(-1, vocab_size), soft_labels.view(-1, vocab_size), reduction='batchmean') * (temperature ** 2)

    # Task loss (ground truth labels)
    task_loss = F.cross_entropy(student_logits.view(-1, vocab_size), targets.view(-1))

    return alpha * distill_loss + (1 - alpha) * task_loss

In [17]:
# Reset and reinitialize distilled model
distill_model = TinyTransformer(vocab_size, embed_dim, num_heads, num_layers).to(device)
optimizer_distill = torch.optim.AdamW(distill_model.parameters(), lr=3e-4)

# Set to eval mode
lstm_model.eval()

# Train distillation loop
for epoch in range(500):
    xb, yb = get_mini_batch(train_data)
    xb, yb = xb.to(device), yb.to(device)

    with torch.no_grad():
        teacher_logits, _ = lstm_model(xb, yb)  # LSTM output may be [B, T, V]

    student_logits, _ = distill_model(xb, yb)  # Transformer output should also be [B, T, V]

    loss = distillation_loss(student_logits, teacher_logits, yb)

    optimizer_distill.zero_grad()
    loss.backward()
    optimizer_distill.step()

    if epoch % 100 == 0:
        print(f"Distilled Epoch {epoch} | Loss: {loss.item():.4f}")

Distilled Epoch 0 | Loss: 3.2854
Distilled Epoch 100 | Loss: 1.3454
Distilled Epoch 200 | Loss: 1.2860
Distilled Epoch 300 | Loss: 1.2328
Distilled Epoch 400 | Loss: 0.8186


**8.ONNX & TensorRT Optimization for Production Ready**

In [18]:
distill_model.eval()
dummy_input = torch.randint(0, vocab_size, (1, block_size)).to(device)

input_names = ["input_ids"]
output_names = ["logits"]

torch.onnx.export(
    distill_model,
    dummy_input,
    "/content/tiny_transformer.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence"},
        "logits": {0: "batch_size", 1: "sequence"}
    }
)
print("ONNX model exported.")

ONNX model exported.


In [19]:
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="/content/tiny_transformer.onnx",
    model_output="/content/tiny_transformer_quantized.onnx",
    weight_type=QuantType.QInt8
)
print("Model quantized successfully.")



Model quantized successfully.


In [None]:
!apt-get update && apt-get install -y libnvinfer8 python3-libnvinfer
!pip install tensorrt

In [None]:
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("/content/tiny_transformer_quantized.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)
engine = builder.build_engine(network, config)

with open("/content/tiny_transformer_trt.engine", "wb") as f:
    f.write(engine.serialize())
print("TensorRT engine built and saved.")

10. INT8 Quantization

In [22]:
# Replace this:
# import gym
# from gym import spaces

# With this:
import gymnasium as gym
from gymnasium import spaces

In [23]:
class TextEnv(gym.Env):
    metadata = {'render_modes': ['human']}

    def __init__(self, model, tokenizer, max_len=32, render_mode=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.inv_tokenizer = idx_to_char
        self.vocab_size = len(char_to_idx)
        self.max_len = max_len
        self.action_space = spaces.Discrete(self.vocab_size)
        self.observation_space = spaces.Box(
            low=0,
            high=self.vocab_size - 1,
            shape=(self.max_len,),
            dtype=np.int64
        )
        self.render_mode = render_mode

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.context = [self.tokenizer[' ']] * self.max_len
        return np.array(self.context), {}  # gymnasium requires a second empty info dict

    def step(self, action):
        self.context = self.context[1:] + [action]
        reward = self._get_reward(action)
        terminated = False
        truncated = False
        return np.array(self.context), reward, terminated, truncated, {}

    def _get_reward(self, action):
        # Dummy reward for demonstration; replace with real logic later
        return float(action == 10)  # e.g., reward if specific token is generated

12. RLHF WIth PPO Training

In [24]:
env = TextEnv(distill_model, char_to_idx)
env = DummyVecEnv([lambda: env])

ppo_model = PPO("MlpPolicy", env, verbose=1, n_steps=128)
ppo_model.learn(total_timesteps=10000)
print("PPO training completed.")

Using cpu device
----------------------------
| time/              |     |
|    fps             | 566 |
|    iterations      | 1   |
|    time_elapsed    | 0   |
|    total_timesteps | 128 |
----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 487          |
|    iterations           | 2            |
|    time_elapsed         | 0            |
|    total_timesteps      | 256          |
| train/                  |              |
|    approx_kl            | 0.0084480215 |
|    clip_fraction        | 0.00703      |
|    clip_range           | 0.2          |
|    entropy_loss         | -4.17        |
|    explained_variance   | -1.65        |
|    learning_rate        | 0.0003       |
|    loss                 | -0.0597      |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0536      |
|    value_loss           | 0.244        |
------------------------------------------
------

In [25]:
torch.save(distill_model.state_dict(), "/content/drive/MyDrive/tiny_transformer_final.pth")
torch.save(lstm_model.state_dict(), "/content/drive/MyDrive/lstm_teacher.pth")
print("All models saved to Drive.")

All models saved to Drive.


13. Attention Visualization

In [29]:
import matplotlib.pyplot as plt

def visualize_attention(model, input_seq):
    """
    Visualizes attention maps from each transformer block.

    Args:
        model: Your TinyTransformer model (distill_model)
        input_seq: Tensor of shape [batch_size, seq_len]
    """
    model.eval()
    with torch.no_grad():
        # Get token and positional embeddings
        token_emb = model.token_emb(input_seq)  # [B, T, C]
        pos_emb = model.pos_emb(token_emb)      # [B, T, C]

        # Pass through each transformer block and capture attention maps
        x = pos_emb
        for i, block in enumerate(model.blocks):
            # Save attention map before applying residual connection
            attn = block.attn(x)  # This returns the output; modify Block class if needed

            # HACK: To get raw attention weights, we need to slightly modify MultiHeadAttention
            # See below for updated version that returns attention

            print(f"Block {i+1} attention output shape:", attn.shape)

# Example usage
xb, _ = get_mini_batch(train_data)
visualize_attention(distill_model, xb.to(device))

Block 1 attention output shape: torch.Size([64, 64, 128])
Block 2 attention output shape: torch.Size([64, 64, 128])
Block 3 attention output shape: torch.Size([64, 64, 128])
Block 4 attention output shape: torch.Size([64, 64, 128])
