In [None]:
# Imports
from typing import List, Tuple, Optional, Dict, Type

import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F

## Reading and Exploring the Data

In [None]:
# Read the exported chat into memory. The encoding is pinned explicitly so the
# text decodes the same way on every platform instead of depending on the
# locale's default encoding.
# NOTE(review): assumes the export is UTF-8 encoded — confirm against the file.
with open('ApoteosisWhatsAppChat.txt', 'r', encoding='utf-8') as f:
    input_text = f.read()

# Peek at the first 1000 characters.
print(input_text[:1000])

In [None]:
# Vocabulary: every distinct character of the text, sorted so that the
# character ordering is the same on every run.
chars_in_text = sorted(set(input_text))
vocab_size = len(chars_in_text)
display(''.join(chars_in_text))
display(f'Total unique characters: {vocab_size}')

## Map from chars to int, and vice versa

In [None]:
# Lookup tables in both directions: a character's id is its position in the
# sorted vocabulary.
char_to_int = {char: index for index, char in enumerate(chars_in_text)}
int_to_char = dict(enumerate(chars_in_text))


## Functions to encode and decode characters

In [None]:
def encode(text: str) -> List[int]:
    """Translate every character of `text` into its integer id.

    Raises KeyError for a character that is not in `char_to_int`.
    """
    return list(map(char_to_int.__getitem__, text))

def decode(encoded_text: List[int]) -> str:
    """Translate a sequence of integer ids back into a string.

    Raises KeyError for an id that is not in `int_to_char`.
    """
    return ''.join(int_to_char[token_id] for token_id in encoded_text)

# Round trip: encoding a sample sentence and decoding it again should print
# the ids first and then the original sentence.
sample_text = 'Mi mama me mima'
encoded_text = encode(sample_text)
print(encoded_text)

print(decode(encoded_text))


## Encode the entire text and put it in a tensor


In [None]:
# The whole corpus as one 1-D tensor of 64-bit integer token ids.
text_tensor = torch.tensor(encode(input_text), dtype=torch.long)
display(text_tensor.shape)
display(text_tensor[:100])

## Datasets

The first 90% of the encoded text is used for training; the remaining 10% is held out for validation.


In [None]:
# Sequential split (no shuffling): everything before split_point is training
# data, everything from split_point onwards is validation data.
split_point = int(0.9 * len(text_tensor))
train_data, validation_data = text_tensor[:split_point], text_tensor[split_point:]

display(train_data[:100])
display(validation_data[:100])


In [None]:
from enum import Enum
# Make results reproducible
torch.manual_seed(8888)
# Number of sequences sampled per batch
batch_size = 256
# Maximum context length for predictions
context_size = 256
# Learning rate passed to the optimizer of the attention model
learning_rate = 3e-4

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class SplitType(Enum):
    """Names the data split (training or validation) that a batch is drawn from."""
    train = 'TRAIN'
    validation = 'VALIDATION'

In [None]:
def get_batch(split_type: SplitType) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (inputs, targets) windows from one data split.

    `batch_size` start offsets are drawn uniformly; each input row is a window
    of `context_size` tokens and the matching target row is the same window
    shifted one token to the right. Both tensors are moved to `device`.

    Any split other than SplitType.train reads from the validation data.
    """
    if split_type == SplitType.train:
        source = train_data
    else:
        source = validation_data

    # The upper bound keeps the shifted target window inside the data.
    starts = torch.randint(len(source) - context_size, (batch_size, ))

    input_rows, target_rows = [], []
    for start in starts:
        stop = start + context_size
        input_rows.append(source[start:stop])
        target_rows.append(source[start + 1 : stop + 1])

    return torch.stack(input_rows).to(device), torch.stack(target_rows).to(device)

In [None]:
@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iterations: int) -> Dict["SplitType", torch.Tensor]:
    """Estimate the mean loss of `model` on every data split.

    For each member of SplitType, `eval_iterations` random batches are drawn
    with get_batch and their losses averaged. Runs without gradient tracking
    and with the model in evaluation mode.

    Args:
        model: module whose call `model(inputs, targets)` returns `(logits, loss)`.
        eval_iterations: number of random batches averaged per split.

    Returns:
        A dict mapping each SplitType member to a 0-dim tensor holding the
        mean loss over the sampled batches.
    """
    estimated_loss = {}

    # Remember the caller's mode so evaluation does not silently flip a model
    # that was in eval mode back into training mode.
    was_training = model.training
    model.eval()

    try:
        for split_type in SplitType:
            split_loss = torch.zeros(eval_iterations)
            for k in range(eval_iterations):
                inputs, targets = get_batch(split_type)
                _, loss = model(inputs, targets)
                split_loss[k] = loss.item()

            estimated_loss[split_type] = split_loss.mean()
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)

    return estimated_loss


In [None]:
# First try with Bigrams

class BigramLanguageModel(nn.Module):
    """Bigram language model: the logits for the next token depend only on the current token."""

    def __init__(self, vocab_size: int):
        super().__init__()
        # Embedding table: row i holds the next-token logits for token id i
        self.embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(
            self,
            inputs: torch.Tensor,
            targets: Optional[torch.Tensor] = None
            ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inputs: integer token ids of shape (batch, time).
            targets: optional integer token ids of shape (batch, time).

        Returns:
            (logits, loss). Without targets, logits has shape
            (batch, time, vocab_size) and loss is None. With targets, logits is
            returned flattened to (batch * time, vocab_size) and loss is the
            cross-entropy against the flattened targets.
        """
        logits = self.embedding_table(inputs)

        loss = None

        if targets is not None:
            # F.cross_entropy expects (N, classes) logits and (N,) targets
            b, t, c = logits.shape
            logits = logits.view(b*t, c)
            targets = targets.view(b*t)

            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, inputs: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """
        Generates max_new_tokens of text by predicting what comes after the inputs.

        Sampling runs without gradient tracking and in evaluation mode; the
        previous training/evaluation mode is restored before returning.

        Args:
            inputs: integer token ids of shape (batch, time) used as the prompt.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Tensor of shape (batch, time + max_new_tokens): the prompt followed
            by the sampled tokens.
        """
        was_training = self.training
        self.eval()
        try:
            prediction_appended = inputs
            for _ in range(max_new_tokens):
                logits, _ = self(prediction_appended)
                # We are interested only in the last time step
                logits = logits[:, -1, :]
                probabilities = F.softmax(logits, dim=-1)
                # Sample
                next_token = torch.multinomial(probabilities, num_samples=1)
                # Append to the current input
                prediction_appended = torch.cat((prediction_appended, next_token), dim=1)
        finally:
            self.train(was_training)

        return prediction_appended

In [None]:
# Sanity check on the freshly initialised model: run a single training batch
# through it and compare its loss with the reference value below.
bigram_language_model = BigramLanguageModel(vocab_size).to(device)
inputs, targets = get_batch(SplitType.train)
output, loss = bigram_language_model(inputs, targets)
display(output.shape, loss)
# -ln(1/vocab_size) is the cross-entropy of a uniform guess over the vocabulary
display(f'Expected loss: {-np.log(1/vocab_size)}')

In [None]:
def generate(model: nn.Module, max_new_tokens: int = 100) -> str:
    """Sample text from `model` and return it as a decoded string.

    Generation starts from a single prompt token with id 1. The model is put
    in evaluation mode and gradient tracking is disabled while sampling, so
    that dropout layers are inactive and no autograd graph is built; the
    model's previous mode is restored afterwards.

    Args:
        model: module exposing `generate(inputs, max_new_tokens=...)` that
            returns a (batch, time) tensor of token ids.
        max_new_tokens: number of tokens to sample after the prompt token.

    Returns:
        The decoded text of the first (and only) sequence, prompt included.
    """
    inputs = torch.ones((1, 1), dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            encoded_generated_text = model.generate(inputs, max_new_tokens=max_new_tokens)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    return decode(encoded_generated_text[0].tolist())

In [None]:
# Generate! The bigram model has not been trained at this point, and the
# helper samples its default of 100 new tokens.
generate(bigram_language_model)

## Train the model

In [None]:
from tqdm.notebook import tqdm_notebook
optimizer = torch.optim.AdamW(bigram_language_model.parameters(), lr=1e-3)

max_iterations = 10000
# Used both as the reporting interval (in steps) and as the number of batches
# averaged by estimate_loss at each report.
eval_iterations = 1000

for iteration in tqdm_notebook(range(max_iterations)):

    # Every eval_iterations steps (including step 0), report the estimated losses
    if not iteration % eval_iterations:
        evaluated_loss = estimate_loss(bigram_language_model, eval_iterations)
        display(f'Step {iteration}: Train loss: {evaluated_loss[SplitType.train]}, Validation loss: {evaluated_loss[SplitType.validation]}')

    # One optimisation step: clear old gradients, forward, backward, update
    optimizer.zero_grad(set_to_none=True)
    inputs, targets = get_batch(SplitType.train)
    logits, loss = bigram_language_model(inputs, targets)
    loss.backward()
    optimizer.step()

   

In [None]:
# Sample again, now that the bigram model has been trained
generate(bigram_language_model)

This is pretty much the limit of what the bigram model can do without changing anything else.

Now, let's plug in attention.

In [None]:
# Width of the representation used by the attention stack
ATTENTION_HEAD_DIMENSIONS = 384
# Dropout probability shared by the attention and feed-forward layers
dropout_rate = 0.2

class AttentionHead(nn.Module):
    """A single head of causal (masked) self-attention.

    Each position attends only to itself and to earlier positions.

    Args:
        n_embedding_dimensions: size of the last dimension of the inputs.
        head_dimensions: size of the key/query/value projections and of the output.
        max_context_size: longest sequence the causal mask supports. Defaults
            to the notebook-level `context_size` when omitted.
    """

    def __init__(
            self,
            n_embedding_dimensions: int,
            head_dimensions: int,
            max_context_size: Optional[int] = None
        ):
        super().__init__()
        self.key = nn.Linear(n_embedding_dimensions, head_dimensions)
        self.query = nn.Linear(n_embedding_dimensions, head_dimensions)
        self.value = nn.Linear(n_embedding_dimensions, head_dimensions)

        if max_context_size is None:
            max_context_size = context_size

        # Since this is not a parameter, it's registered as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(max_context_size, max_context_size)))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply masked self-attention.

        Args:
            inputs: tensor of shape (batch, time, n_embedding_dimensions).

        Returns:
            Tensor of shape (batch, time, head_dimensions).
        """
        _, time_pos, _ = inputs.shape
        key = self.key(inputs)
        query = self.query(inputs)

        # Calculating affinities. The scores are scaled by 1/sqrt(d_k), where
        # d_k is the key/query width (the head size), not the input embedding width.
        head_dimensions = key.shape[-1]
        affinities = query @ key.transpose(-2, -1) * head_dimensions ** -0.5
        # Block attention to future positions
        affinities = affinities.masked_fill(self.tril[:time_pos, :time_pos] == 0, float('-inf'))
        affinities = F.softmax(affinities, dim=-1)
        affinities = self.dropout(affinities)

        value = self.value(inputs)
        outputs = affinities @ value

        return outputs


In [None]:
class MultiHeadAttention(nn.Module):
    """Several AttentionHead modules run side by side on the same inputs.

    The per-head outputs are concatenated along the last dimension, passed
    through a linear projection of the same width, and then through dropout.
    """

    def __init__(self, n_heads: int, n_embedding_dimensions: int, attention_head_dimensions: int) -> None:
        super().__init__()
        # Width of the concatenated head outputs
        concatenated_width = attention_head_dimensions * n_heads

        heads = []
        for _ in range(n_heads):
            heads.append(AttentionHead(n_embedding_dimensions, attention_head_dimensions))
        self.heads = nn.ModuleList(heads)

        self.projection = nn.Linear(concatenated_width, concatenated_width)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs) -> torch.Tensor:
        """Return the projected, dropout-regularised concatenation of all head outputs."""
        per_head_outputs = [head(inputs) for head in self.heads]
        concatenated = torch.cat(per_head_outputs, dim=-1)
        return self.dropout(self.projection(concatenated))

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Expands the last dimension to four times its width, applies ReLU, projects
    back to the original width, and applies dropout.
    """

    def __init__(self, attention_head_dimensions: int) -> None:
        super().__init__()
        hidden_dimensions = 4 * attention_head_dimensions
        layers = [
            nn.Linear(attention_head_dimensions, hidden_dimensions),
            nn.ReLU(),
            # Projection layer
            nn.Linear(hidden_dimensions, attention_head_dimensions),
            nn.Dropout(dropout_rate),
        ]
        self.forward_net = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run `inputs` through the feed-forward stack."""
        outputs = self.forward_net(inputs)
        return outputs

In [None]:
class TransformerBlock(nn.Module):
    """A transformer block for communication -> computation.

    Self-attention (communication) is followed by a feed-forward network
    (computation). Each sub-layer is preceded by layer normalisation and
    wrapped in a residual connection.

    Args:
        attention_head_dimensions: width used by the layer norms and the
            feed-forward network; it is divided evenly among the heads.
        n_attention_heads: number of attention heads.
        n_embedding_dimensions: width of the inputs fed to the attention heads.
    """

    def __init__(self, attention_head_dimensions: int, n_attention_heads: int, n_embedding_dimensions: int) -> None:
        super().__init__()

        head_size = attention_head_dimensions // n_attention_heads
        self.self_attention_heads = MultiHeadAttention(
            n_attention_heads, n_embedding_dimensions, head_size
        )
        self.forward_net = FeedForward(attention_head_dimensions)
        self.layer_norm1 = nn.LayerNorm(attention_head_dimensions)
        self.layer_norm2 = nn.LayerNorm(attention_head_dimensions)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply attention and feed-forward sub-layers, each with a residual connection."""
        # The addition is for residual connections. The first residual must
        # start from `inputs`: `outputs` does not exist yet at this point.
        outputs = inputs + self.self_attention_heads(self.layer_norm1(inputs))
        outputs = outputs + self.forward_net(self.layer_norm2(outputs))
        return outputs

In [None]:
class BigramLanguageModelWithAttention(nn.Module):
    """Character-level language model with token and position embeddings and self-attention.

    Args:
        vocab_size: number of distinct tokens.
        context_size: maximum sequence length the position table supports.
        n_embedding_dimensions: width of the token and position embeddings.
        n_attention_heads: number of attention heads.
        n_transformer_blocks: number of TransformerBlock modules to build.
    """

    def __init__(
            self,
            vocab_size: int, 
            context_size: int,
            n_embedding_dimensions: int = 384,
            n_attention_heads: int = 6,
            n_transformer_blocks: int = 6
        ):
        super().__init__()
        # Embedding table
        self.embedding_table = nn.Embedding(vocab_size, n_embedding_dimensions)
        # Position as embedding. Transformers don't have a means to know which element of a sequence
        # they are working with. This is learned in this table
        self.position_embedding_table = nn.Embedding(context_size, n_embedding_dimensions)
        # NOTE(review): transformer_blocks and layer_norm are constructed here
        # but forward() never calls them; only self_attention_heads,
        # forward_net and lm_head take part in the computation. Confirm
        # whether the block stack was meant to replace that single layer.
        self.transformer_blocks = nn.Sequential(
            *[
                TransformerBlock(ATTENTION_HEAD_DIMENSIONS, n_attention_heads, n_embedding_dimensions)
                for _ in range(n_transformer_blocks)
            ]
        )
        self.layer_norm = nn.LayerNorm(ATTENTION_HEAD_DIMENSIONS)
 
        self.self_attention_heads = MultiHeadAttention(
            n_attention_heads, n_embedding_dimensions, ATTENTION_HEAD_DIMENSIONS // n_attention_heads
        )
        self.forward_net = FeedForward(ATTENTION_HEAD_DIMENSIONS)
        # Linear layer to go from embeddings to logits
        self.lm_head = nn.Linear(ATTENTION_HEAD_DIMENSIONS, vocab_size)

    def forward(
            self,
            inputs: torch.Tensor,
            targets: Optional[torch.Tensor] = None
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inputs: integer token ids of shape (batch, time); time must not
                exceed the size of the position embedding table.
            targets: optional integer token ids of shape (batch, time).

        Returns:
            (logits, loss). Without targets, logits has shape
            (batch, time, vocab_size) and loss is None. With targets, logits is
            returned flattened to (batch * time, vocab_size) and loss is the
            cross-entropy against the flattened targets.
        """
        _, time_pos = inputs.shape
        token_embeddings = self.embedding_table(inputs)
        # Build the positions on the same device as the inputs rather than on
        # a notebook-level global, so the model works wherever it is moved.
        positions = torch.arange(time_pos, device=inputs.device)
        position_embeddings = self.position_embedding_table(positions)
        head_inputs = token_embeddings + position_embeddings
        head_inputs = self.self_attention_heads(head_inputs)
        head_inputs = self.forward_net(head_inputs)
        logits = self.lm_head(head_inputs)  # (batch size, time, vocab_size)

        loss = None

        if targets is not None:
            # F.cross_entropy expects (N, classes) logits and (N,) targets
            b, t, c = logits.shape
            logits = logits.view(b*t, c)
            targets = targets.view(b*t)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, inputs: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """
        Generates max_new_tokens of text by predicting what comes after the inputs.

        Sampling runs without gradient tracking and in evaluation mode (so
        dropout is inactive); the previous training/evaluation mode is
        restored before returning.

        Args:
            inputs: integer token ids of shape (batch, time) used as the prompt.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Tensor of shape (batch, time + max_new_tokens): the prompt followed
            by the sampled tokens.
        """
        # The position table defines the longest sequence this model accepts
        max_context = self.position_embedding_table.num_embeddings

        was_training = self.training
        self.eval()
        try:
            prediction_appended = inputs
            for _ in range(max_new_tokens):
                # Keep only the most recent tokens that fit in the context
                inputs_cropped = prediction_appended[:, -max_context:]
                logits, _ = self(inputs_cropped)
                # We are interested only in the last time step
                logits = logits[:, -1, :]
                probabilities = F.softmax(logits, dim=-1)
                # Sample
                next_token = torch.multinomial(probabilities, num_samples=1)
                # Append to the current input
                prediction_appended = torch.cat((prediction_appended, next_token), dim=1)
        finally:
            self.train(was_training)

        return prediction_appended

In [None]:
bigram_language_model_with_attention = BigramLanguageModelWithAttention(
    vocab_size, context_size
).to(device)
optimizer = torch.optim.AdamW(bigram_language_model_with_attention.parameters(), lr=learning_rate)

max_iterations = 10000
# Used both as the reporting interval (in steps) and as the number of batches
# averaged by estimate_loss at each report.
eval_iterations = 1000

for iteration in tqdm_notebook(range(max_iterations)):

    # Every eval_iterations steps (including step 0), report the estimated losses
    if not iteration % eval_iterations:
        evaluated_loss = estimate_loss(bigram_language_model_with_attention, eval_iterations)
        display(f'Step {iteration}: Train loss: {evaluated_loss[SplitType.train]}, Validation loss: {evaluated_loss[SplitType.validation]}')

    # One optimisation step: clear old gradients, forward, backward, update
    optimizer.zero_grad(set_to_none=True)
    inputs, targets = get_batch(SplitType.train)
    logits, loss = bigram_language_model_with_attention(inputs, targets)
    loss.backward()
    optimizer.step()

In [None]:
# Sample 1000 new tokens from the attention model and print the decoded text
generated_text = generate(bigram_language_model_with_attention, max_new_tokens=1000)

print(generated_text)