In [1]:
import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class AttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``head_dim``.
    A lower-triangular mask ensures that position ``t`` can only attend to
    positions ``<= t``.
    """

    def __init__(
        self,
        embedding_dim: int,
        head_dim: int,
        dropout: float
    ):
        super().__init__()

        # Projections are created in the order query, key, value so that
        # parameter initialisation draws from the RNG in a fixed sequence.
        self.query = nn.Linear(embedding_dim, head_dim)
        self.key = nn.Linear(embedding_dim, head_dim)
        self.value = nn.Linear(embedding_dim, head_dim)

        # Applied to the attention weights, not to the output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, E) and return a (B, T, H) tensor."""
        # Three independent projections: (B, T, E) -> (B, T, H)
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        _, seq_len, head_dim = keys.shape

        # Similarity of every query with every key, scaled by sqrt(H)
        # (B, T, H) @ (B, H, T) -> (B, T, T)
        scores = (queries @ keys.transpose(-1, -2)) / head_dim ** 0.5

        # True on and below the diagonal, i.e. for the positions that may be seen
        visible = torch.ones(seq_len, seq_len, device=x.device).tril() != 0
        scores = scores.masked_fill(~visible, float("-inf"))

        # Normalise over the key axis, then regularise the weights
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values: (B, T, T) @ (B, T, H) -> (B, T, H)
        return weights @ values

In [3]:
class MaskedMultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel and mixed by a linear layer.

    Each head works on ``embedding_dim // num_heads`` features. The head
    outputs are concatenated along the feature axis, passed through an output
    projection and then through dropout.

    Raises:
        ValueError: if ``num_heads`` is not positive, or if ``embedding_dim``
            is not an exact multiple of ``num_heads``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        dropout: float
    ):
        super(MaskedMultiHeadAttention, self).__init__()

        # Validate up front. With a non-divisible width the concatenated head
        # outputs would be narrower than embedding_dim and the module would
        # only fail later, inside the output projection, with a shape error.
        if num_heads <= 0:
            raise ValueError(
                f"num_heads must be a positive integer, got {num_heads}."
            )
        if embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"num_heads ({num_heads})."
            )

        head_dim = embedding_dim // num_heads

        self.heads = nn.ModuleList([
            AttentionHead(
                embedding_dim,
                head_dim,
                dropout
            ) for _ in range(num_heads)
        ])

        # Output projection applied to the concatenated heads
        self.linear = nn.Linear(embedding_dim, embedding_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply all heads to ``x`` (B, T, E) and return a (B, T, E) tensor."""
        # Scaled dot-product attention per head, concatenated on the feature axis
        # num_heads * (B, T, E / num_heads) -> (B, T, E)
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # Linear
        out = self.linear(out)

        # Dropout
        out = self.dropout(out)

        return out

In [4]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x the embedding size, GELU, narrow back, dropout."""

    def __init__(
        self,
        embedding_dim: int,
        dropout: float
    ):
        super().__init__()

        hidden_dim = embedding_dim * 4

        # Built in the order expand -> activation -> contract so the layers
        # sit at indices 0, 1 and 2 of the Sequential container.
        expand = nn.Linear(embedding_dim, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, embedding_dim)
        self.feed_forward_network = nn.Sequential(expand, activation, contract)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the MLP on the last axis of ``x`` and apply dropout to the result."""
        hidden = self.feed_forward_network(x)
        return self.dropout(hidden)

In [5]:
class DecoderBlock(nn.Module):
    """Pre-norm decoder block.

    Two sub-layers, masked multi-head attention and a feed-forward network.
    Each is preceded by its own LayerNorm and wrapped in a residual connection.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        dropout: float
    ):
        super().__init__()

        # Attention sub-layer and the norm applied to its input
        self.multi_head_attention = MaskedMultiHeadAttention(embedding_dim, num_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)

        # Feed-forward sub-layer and the norm applied to its input
        self.feed_forward = FeedForward(embedding_dim, dropout)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``x`` after both residual sub-layers; the shape is unchanged."""
        # Residual around (norm -> masked multi-head attention)
        attended = x + self.multi_head_attention(self.layer_norm1(x))

        # Residual around (norm -> feed forward)
        return attended + self.feed_forward(self.layer_norm2(attended))

In [6]:
class Transformer(nn.Module):
    """Decoder-only transformer language model.

    Token ids are embedded and summed with learned position embeddings. The
    result passes through ``num_layers`` decoder blocks and a final LayerNorm,
    and is projected to per-token logits over the vocabulary.

    Args:
        vocab_size: number of distinct token ids (size of the output layer).
        max_seq_len: longest sequence the position embedding can represent.
        embedding_dim: width of the token/position embeddings and of each block.
        num_heads: attention heads per decoder block.
        num_layers: number of stacked decoder blocks.
        dropout: dropout probability used throughout the model.
    """

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        embedding_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float
    ):
        super(Transformer, self).__init__()

        self.token_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim
        )
        # One learned vector per position 0 .. max_seq_len - 1
        self.position_embedding = nn.Embedding(
            num_embeddings=max_seq_len,
            embedding_dim=embedding_dim
        )
        self.dropout = nn.Dropout(dropout)

        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(
                embedding_dim,
                num_heads,
                dropout
            ) for _ in range(num_layers)
        ])

        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, ids):
        """Compute next-token logits.

        Args:
            ids: integer token ids whose last axis is the sequence axis.

        Returns:
            Logits with the shape of ``ids`` plus a trailing ``vocab_size`` axis.

        Raises:
            IndexError: if the sequence is longer than ``max_seq_len``.
        """
        seq_len = ids.shape[-1]
        max_seq_len = self.position_embedding.num_embeddings
        # Check explicitly. Otherwise the position lookup fails with an opaque
        # "index out of range" on CPU or a device-side assert on CUDA.
        if seq_len > max_seq_len:
            raise IndexError(
                f"Input sequence length {seq_len} exceeds max_seq_len {max_seq_len}; "
                "truncate the input to the last max_seq_len tokens."
            )

        token_embedding = self.token_embedding(ids)
        # Position vectors are shared by every sequence in the batch and
        # broadcast over the leading axes when added.
        positional_encoding = self.position_embedding(torch.arange(0, seq_len, device=ids.device))
        x = token_embedding + positional_encoding
        x = self.dropout(x)

        for block in self.decoder_blocks:
            x = block(x)

        x = self.layer_norm(x)
        logits = self.lm_head(x)

        return logits

In [7]:
# Location of the training corpus. The default is the original local path.
# Set the TINY_SHAKESPEARE_PATH environment variable to run the notebook on
# another machine without editing this cell.
DATASET_PATH = os.environ.get("TINY_SHAKESPEARE_PATH", r"D:\Datasets\Tiny-Shakespeare\All.txt")
# Extra tokens appended to the character vocabulary
SPECIAL_TOKENS = ["<UNK>"]
# Fraction of the corpus used for training; the rest is used for evaluation
TRAIN_SPLIT = 0.9
# Model size (EMBEDDING_DIM must be divisible by NUM_HEADS)
EMBEDDING_DIM = 384
NUM_HEADS = 6
NUM_LAYERS = 6
DROPOUT = 0.2
# Optimisation
LEARNING_RATE = 3e-4
BATCH_SIZE = 64
# Tokens per training window; also used as the model's maximum sequence length
SEQUENCE_LENGTH = 256
TRAINING_ITERS = 5000
# Evaluate and checkpoint every EVAL_FREQ steps, averaging over EVAL_ITERS batches
EVAL_FREQ = 500
EVAL_ITERS = 100
# Directory that receives checkpoints and the final model
OUTPUT_DIR = "Models"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Load the whole corpus as a single string. errors="replace" substitutes any
# byte sequence that is not valid UTF-8 instead of raising.
with open(DATASET_PATH, mode="r", encoding="utf-8", errors="replace") as corpus_file:
    text = corpus_file.read()

In [9]:
# Vocabulary: every distinct character of the corpus in sorted order,
# followed by the special tokens.
chars = sorted(set(text)) + SPECIAL_TOKENS

vocab_size = len(chars)

# Lookup tables in both directions
char2id = dict(zip(chars, range(vocab_size)))
id2char = {idx: ch for ch, idx in char2id.items()}


def string2ids(s):
    """Encode a string as a list of token ids; unknown characters map to "<UNK>"."""
    return [char2id.get(ch, char2id["<UNK>"]) for ch in s]


def ids2string(l):
    """Decode a sequence of token ids back into a string."""
    return "".join(id2char[i] for i in l)

In [10]:
# Encode the entire corpus as one 1-D tensor of token ids
data = torch.tensor(string2ids(text), dtype=torch.long)

# The first TRAIN_SPLIT fraction is for training, the remainder for evaluation
train_end_index = int(TRAIN_SPLIT * len(data))

train_data, eval_data = data[:train_end_index], data[train_end_index:]

In [None]:
def get_batch(
    split: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Draw a random batch of (input, target) windows.

    ``split == "train"`` samples from ``train_data``; any other value samples
    from ``eval_data``. Both returned tensors have shape
    (BATCH_SIZE, SEQUENCE_LENGTH) and live on ``DEVICE``. The target is the
    input window shifted one position forward in the corpus.
    """
    source = train_data if split == "train" else eval_data

    # First index of every window. The exclusive upper bound leaves room for
    # the extra token needed by the shifted target.
    start_ids = torch.randint(low=0, high=len(source) - SEQUENCE_LENGTH, size=(BATCH_SIZE,))

    # (BATCH_SIZE, 1) + (SEQUENCE_LENGTH,) -> (BATCH_SIZE, SEQUENCE_LENGTH)
    # grid holding the absolute corpus position of every token in the batch
    positions = start_ids.unsqueeze(dim=1) + torch.arange(SEQUENCE_LENGTH)

    x = source[positions].to(DEVICE)
    y = source[positions + 1].to(DEVICE)

    return x, y

In [None]:
def save_model(
    model: Transformer,
    save_dir: str = OUTPUT_DIR
) -> None:
    """Write the model weights and the vocabulary into ``save_dir``.

    The directory is created when missing. Two files are written:
    ``model_weights.pth`` holds the model's ``state_dict`` and ``vocab.pkl``
    holds the pickled character-to-id mapping.
    """
    os.makedirs(save_dir, exist_ok=True)

    weights_path = os.path.join(save_dir, "model_weights.pth")
    vocab_path = os.path.join(save_dir, "vocab.pkl")

    # Weights only (state_dict), not the full module object
    torch.save(model.state_dict(), weights_path)

    # The vocabulary is needed to encode and decode text at inference time
    with open(vocab_path, "wb") as vocab_file:
        pickle.dump(char2id, vocab_file)

In [None]:
@torch.no_grad()
def evaluate(
    model: Transformer
) -> dict[str, float]:
    """Estimate the mean cross-entropy loss on the train and eval splits.

    For each split, ``EVAL_ITERS`` random batches are drawn and the per-batch
    losses are averaged. The model runs in eval mode (dropout disabled), and
    the mode it had on entry is restored on exit, even if an error occurs.

    Returns:
        A dict with keys ``"train"`` and ``"eval"`` mapping to the mean loss.
    """
    # Remember the caller's mode so it can be restored rather than forced to
    # train mode afterwards.
    was_training = model.training
    model.eval()

    out = {}
    try:
        for split in ["train", "eval"]:
            # Tensor to store the loss from each batch
            losses = torch.zeros(EVAL_ITERS)

            for i in range(EVAL_ITERS):
                # get_batch already returns tensors on DEVICE
                x, y = get_batch(split)

                # Forward pass; flatten (B, T, C) -> (B*T, C) for the loss
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1)
                )

                # Store the loss
                losses[i] = loss.item()

            out[split] = losses.mean().item()
    finally:
        model.train(was_training)

    return out

In [None]:
def train(
    model: Transformer,
    optimizer: torch.optim.Optimizer
) -> None:
    """Train ``model`` for ``TRAINING_ITERS`` optimisation steps.

    Every ``EVAL_FREQ`` steps the model is evaluated, the losses are printed
    and a checkpoint is written to ``OUTPUT_DIR/checkpoint-<step>``. When
    training finishes, the weights of the checkpoint with the lowest
    validation loss are loaded back into ``model``. If no checkpoint was
    written, the model is left as it is.
    """
    best_val_loss = float("inf")
    best_val_loss_step = None

    # Make sure dropout is active from the first step, whatever mode the
    # caller left the model in.
    model.train()

    for step in tqdm(range(TRAINING_ITERS)):
        step_number = step + 1

        # get_batch already returns tensors on DEVICE
        x, y = get_batch(split="train")

        # Forward pass; flatten (B, T, C) -> (B*T, C) for the loss
        logits = model(x)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1)
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step_number % EVAL_FREQ == 0:
            # Evaluate the model
            losses = evaluate(model)

            print(f"Step: {step_number} - Training loss: {losses['train']:.4f} - Validation loss: {losses['eval']:.4f}")

            # Track which checkpoint generalises best
            if losses["eval"] < best_val_loss:
                best_val_loss = losses["eval"]
                best_val_loss_step = step_number

            # Save model checkpoint
            save_model(model, save_dir=os.path.join(OUTPUT_DIR, f"checkpoint-{step_number}"))

    # Roll back to the best checkpoint, if any evaluation took place
    if best_val_loss_step is not None:
        model.load_state_dict(torch.load(
            os.path.join(OUTPUT_DIR, f"checkpoint-{best_val_loss_step}", "model_weights.pth"),
            weights_only=True
        ))

In [None]:
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: str = "\n",
    max_new_tokens: int = 1000,
    do_sample: bool = False
) -> str:
    """Autoregressively extend ``prompt`` by ``max_new_tokens`` characters.

    Args:
        model: the language model to sample from.
        prompt: non-empty seed text; it is not included in the return value.
        max_new_tokens: number of tokens to generate.
        do_sample: when True, draw each token from the softmax distribution.
            When False, always pick the highest-scoring token (greedy).

    Returns:
        The generated continuation only, decoded to a string.

    The model runs in eval mode, and the mode it had on entry is restored on
    exit, even if generation fails.
    """
    assert prompt != "", \
        "Cannot have an empty string as the prompt. Please specify a valid prompt or leave the default prompt."

    # Remember the caller's mode so it can be restored rather than forced to
    # train mode afterwards.
    was_training = model.training
    model.eval()

    try:
        # Tokenize the prompt, convert to pytorch tensor, send to GPU and add batch dimension
        ids = torch.tensor(
            string2ids(prompt),
            dtype=torch.int64,
            device=DEVICE
        ).unsqueeze(dim=0) # (B, t)
        # Number of tokens in the prompt
        prompt_len = ids.shape[1]

        for _ in range(max_new_tokens):
            # The model's input will be the last SEQUENCE_LENGTH tokens
            x = ids[:, -SEQUENCE_LENGTH:] # (B, T)
            # Keep only the prediction for the final position
            logits = model(x)[:, -1, :] # (B, C)
            if do_sample:
                # Get probability distribution
                probs = torch.softmax(logits, dim=-1) # (B, C)
                # Sample from the distribution
                idx = torch.multinomial(probs, num_samples=1) # (B, 1)
            else:
                # Get the index with the highest probability
                idx = torch.argmax(logits, dim=-1, keepdim=True) # (B, 1)
            # Append sampled index to the running sequence
            ids = torch.cat((ids, idx), dim=-1) # (B, t + 1)

        # Get only the tokens generated by the model
        response_ids = ids.squeeze(dim=0).cpu().tolist()[prompt_len:]
        # Convert token ids to characters
        response = ids2string(response_ids)
    finally:
        model.train(was_training)

    return response

In [16]:
# Build the network from the configuration constants and place it on DEVICE
model = Transformer(
    vocab_size=vocab_size,
    max_seq_len=SEQUENCE_LENGTH,
    embedding_dim=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
)
model = model.to(DEVICE)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [17]:
# Loss on both splits before any training, as a reference point for the
# numbers reported during and after training.
print(evaluate(model))

{'train': 4.318239688873291, 'eval': 4.3184614181518555}


In [18]:
# Run the training loop. Every EVAL_FREQ steps it prints the losses and writes
# a checkpoint. At the end it reloads the checkpoint with the lowest
# validation loss into `model`.
train(model, optimizer)

 10%|█         | 501/5000 [02:20<4:10:29,  3.34s/it]

Step: 500 - Training loss: 1.9816 - Validation loss: 2.0719


 20%|██        | 1001/5000 [04:41<3:43:17,  3.35s/it]

Step: 1000 - Training loss: 1.6146 - Validation loss: 1.7811


 30%|███       | 1501/5000 [07:01<3:16:21,  3.37s/it]

Step: 1500 - Training loss: 1.4450 - Validation loss: 1.6429


 40%|████      | 2001/5000 [09:22<2:48:02,  3.36s/it]

Step: 2000 - Training loss: 1.3456 - Validation loss: 1.5681


 50%|█████     | 2501/5000 [11:43<2:19:58,  3.36s/it]

Step: 2500 - Training loss: 1.2778 - Validation loss: 1.5135


 60%|██████    | 3001/5000 [14:04<1:51:55,  3.36s/it]

Step: 3000 - Training loss: 1.2326 - Validation loss: 1.5004


 70%|███████   | 3501/5000 [16:25<1:23:56,  3.36s/it]

Step: 3500 - Training loss: 1.1866 - Validation loss: 1.4788


 80%|████████  | 4001/5000 [18:46<55:53,  3.36s/it]  

Step: 4000 - Training loss: 1.1484 - Validation loss: 1.4813


 90%|█████████ | 4501/5000 [21:07<27:55,  3.36s/it]

Step: 4500 - Training loss: 1.1147 - Validation loss: 1.4698


100%|██████████| 5000/5000 [23:28<00:00,  3.55it/s]

Step: 5000 - Training loss: 1.0765 - Validation loss: 1.4713





In [19]:
# Re-evaluate after training. `model` now holds the best-validation checkpoint
# weights that train() loaded back, not necessarily those of the final step.
print(evaluate(model))

{'train': 1.1103941202163696, 'eval': 1.472131609916687}


In [20]:
# Persist the selected weights and the vocabulary to OUTPUT_DIR (the default
# save_dir of save_model).
save_model(model)

In [23]:
# Sample 5000 new characters starting from the default prompt (a single newline)
print(generate(model, max_new_tokens=5000, do_sample=True))

Are saily danger than my fellow.
Pitch'd I am for Lady Dorset the limb,
Which mercy you are as I set a rooman
As Edward as Wiltshires, Pompey; and you part,
Your brother, and friends me you so;
If therefore, sir, no purpose, yield you hare
Marcius fin, may be you to green to part?
Nightful talk I do a prisoner!
I truth Camillo, and know me?
Therefore life to his fair queen again.

BENVOLIO:
Soft of the sword male pernive your groans.

ROMEO:
I know, my lord, I dry no more torch, sir, but all,
that Bolingbroke's for usurping in his city,
As guilty shall stol her beast gentlemen.
What were man? why? In heart's in't not in the paid,
How cut our affect with husbandful friend,
That he shall scarce pay Tybalt discreets?
O prayer lords, and with some make digs!
How down, frown friend, I dress to myself.
Ah! What thou dost away?

HERMIONE:
The catch hath the queen made be my power?
Thou wast as I that my love? take my cousin's life,
Read and quickrices, lodger bear thy love:
The pembraction ha

In [24]:
# Sample a continuation of a custom prompt. max_new_tokens is left at its
# default, and only the generated text (not the prompt) is returned.
print(generate(
    model,
    prompt="Enter two rivals at dawn.\n\nFirst Rival:",
    do_sample=True
))


No; good Montagues; I for thee have Edward in thy service,
By God's, unwilling name to be particular,
Or what they seems for roaring peace doth between
To let us upon the wrongful tire wars,
That stirrings in any high of circumsal eye,
Which hare in his dafferent royalty
Such treason with that middle envious sword:
Hark you told him, in his strength.

POLIXENES:
Ay, if the duty's approach.
But the tidil bend of his princesss
Richmond and succession of pipe itself,
That drums did not till still with his place
That usurprishes be drink when the numbers
Which air effections brooks up as joint
The old which storm way or your counsel?

JULIET:
There is a pair that you have snow.
I'll brought you go along. What Cominius is most he,
Thou pass'd my heart, your favours will be gone?

ROMEO:
I will convise in thy people?
O, breathe your world's resolute envious.
Edward, desert thou welcome?

AEdile:
O, part, but not that I will stand:
The horse: at proffendon serves no inherappy sun
So the king