# Simple Transformer from Scratch

### Authors:
 - Carla Ellefsen
 - Brendan McKinley
 - Diya Vinod
 - Bingshen Lu
 - Michael Ivanitskiy

In [None]:
from dataclasses import dataclass
from pathlib import Path

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from jaxtyping import Float, Int
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional, Tuple, List
import matplotlib.pyplot as plt
import re
import os as os

## Transformer Architecture

In [None]:
@dataclass
class GPTConfig:
    """Size hyperparameters for the transformer.

    The defaults are test values: too small for a real language model, but
    big enough to exercise the code.
    """

    d_vocab: int = 10_000  # number of rows in the embedding table
    d_model: int = 128  # width of the vectors fed into attention and the MLP
    d_mlp: int = 512  # hidden width of the MLP
    n_heads: int = 4  # attention heads per layer
    d_head: int = 32  # width of each head's Q/K/V projections
    n_layers: int = 6  # number of layers
    act_fn: type[nn.Module] = nn.ReLU  # activation module class

    @property
    def n_params(self) -> tuple[int]:
        """Estimate the number of parameters.

        Returns:
            A one-element tuple holding the estimate: the embedding table
            (the unembedding is assumed tied to it) plus, for every layer,
            the MLP weights, the MLP biases and four ``d_model x d_head``
            matrices (Q, K, O, V) per attention head.
        """
        embedding_params = self.d_vocab * self.d_model
        mlp_weight_params = 2 * self.d_model * self.d_mlp
        mlp_bias_params = self.d_model + self.d_mlp
        # 4 matrices per head: Q, K, O, V
        attention_params = self.n_heads * (4 * self.d_model * self.d_head)
        per_layer_params = mlp_weight_params + mlp_bias_params + attention_params
        total = embedding_params + per_layer_params * self.n_layers
        return (total,)
class AttentionHead(nn.Module):
    """A single causally-masked self-attention head.

    For an input ``x`` the head computes ``softmax(Q K^T + M) @ W_O(V)``,
    where ``Q``, ``K`` and ``V`` are linear projections of ``x`` to
    ``d_head`` dimensions, ``W_O`` projects back to ``d_model``, and ``M`` is
    ``-inf`` strictly above the diagonal (so a position never attends to a
    later one) and ``0`` elsewhere. The attention scores are not rescaled by
    ``sqrt(d_head)``.
    """

    def __init__(self, cfg: GPTConfig):
        """Create the four projection layers from the sizes in ``cfg``."""
        print("Attention Head Constructor...")
        super().__init__()
        self.relu = nn.ReLU()
        self.d_vocab = cfg.d_vocab
        self.d_model = cfg.d_model
        self.d_head = cfg.d_head
        self.wq = nn.Linear(self.d_model, self.d_head)
        self.wk = nn.Linear(self.d_model, self.d_head)
        self.wv = nn.Linear(self.d_model, self.d_head)
        self.wo = nn.Linear(self.d_head, self.d_model)

    def forward(
        self,
        x: Float[torch.Tensor, "... n_context d_model"],
    ) -> Float[torch.Tensor, "... n_context d_model"]:
        """Apply causal self-attention to ``x``.

        Args:
            x: Floating-point activations whose last two dimensions are
                ``(n_context, d_model)``. Any leading dimensions are treated
                as batch dimensions.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # The sequence axis is second-to-last, so a leading batch axis is
        # never mistaken for the context length.
        n_context = x.shape[-2]

        queries = self.wq(x)
        keys_t = self.wk(x).transpose(-2, -1)
        scores = queries @ keys_t

        # Build the causal mask on the same device as the scores so the head
        # keeps working after the module is moved to an accelerator.
        positions = torch.arange(n_context, device=scores.device)
        is_future = positions[None, :] > positions[:, None]
        attention = F.softmax(scores.masked_fill(is_future, float("-inf")), dim=-1)

        # Project the values back to d_model before mixing across positions.
        projected_values = self.wo(self.wv(x))
        return attention @ projected_values



(2463488,)


## Some simple testing
Ensure the code does not crash and shapes are as expected. 

In [None]:
# Smoke test: build a model from the default config and run one forward pass.
# NOTE(review): `Transformer` is not defined in this part of the notebook;
# presumably an earlier cell defines it -- confirm before running in isolation.
gpt_config = GPTConfig()
gpt = Transformer(gpt_config)
x = torch.randint(0, gpt_config.d_vocab, (12,))  # 12 random token ids in [0, d_vocab)
print(x.shape)
print(gpt(x).shape)

torch.Size([12])
torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
torch.Size([12, 10000])


In [None]:
# Attention Testing
# gpt_config = GPTConfig()
# attn_head = AttentionHead(gpt_config)
# x = torch.randn(256, gpt_config.d_model)
# print(x)
# print(x.shape)
# attn_head.forward(x)
# multi_head = MultiHeadedAttention(gpt_config)
# multi_head.forward(x).shape

## Training the Transformer

In [None]:
import re
# Sample passage used below to build a word-level vocabulary and to demo encoding.
some_text = """
In reality, of course, we don't construct such chains explicitly, but instead we want them to learn from data.

To put something in a markov chain or neural network, we need to turn it into numbers. this is straightforward for images: each pixel is already a number! 

In computers, text is stored as a sequence of numbers. Our neural network, in principle, can learn to predict the next number in the sequence. However, each number usually represents a single letter, or even just part of a letter. what do you think happens when we throw something like this into a markov chain?
"""

def create_word_index(text):
    """Build a word -> integer index mapping from ``text``.

    The text is lower-cased and split into runs of word characters; the
    distinct words are sorted and numbered from 0 in that order.
    """
    vocabulary = set(re.findall(r'\b\w+\b', text.lower()))
    ordered = sorted(vocabulary)
    return dict(zip(ordered, range(len(ordered))))

def text_to_tensor(vocab_dict, text):
    """Encode ``text`` as a 1-D ``torch.long`` tensor of word indices.

    The text is lower-cased and split into runs of word characters (which
    drops punctuation). Words missing from ``vocab_dict`` are skipped.
    """
    token_ids = []
    for match in re.finditer(r'\b\w+\b', text.lower()):
        token = match.group()
        if token in vocab_dict:
            token_ids.append(vocab_dict[token])

    return torch.tensor(token_ids, dtype=torch.long)

# Demo: build the vocabulary from the sample text, then encode that same text as word indices.
print(text_to_tensor(create_word_index(some_text), some_text))

tensor([21, 45, 37,  9, 64, 12, 53,  8, 52,  6, 15,  3, 22, 64, 63, 56, 60, 27,
        17, 10, 60, 44, 49, 21,  0, 30,  5, 38, 33, 32, 64, 31, 60, 61, 25, 23,
        36, 58, 24, 51, 16, 20, 13, 41, 24,  1,  0, 35, 21,  7, 54, 24, 50,  2,
         0, 47, 37, 36, 39, 33, 32, 21, 43,  4, 27, 60, 42, 55, 34, 35, 21, 55,
        47, 19, 13, 35, 62, 46,  0, 48, 28, 38, 14, 26, 40, 37,  0, 28, 65, 11,
        67, 57, 18, 66, 64, 59, 49, 29, 58, 23,  0, 30,  5])
