In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
from torch import nn

In [None]:
dataset = load_dataset('notaphoenix/shakespeare_dataset')

# Keep the raw training texts in a plain list of str for convenient iteration.
corpus = [sample['text'] for sample in dataset['training']]

# Character-level tokenization: every text becomes the list of its characters.
tokenized_corpus = [list(text) for text in corpus]

tokenized_corpus[0]

# Character vocabulary. dict.fromkeys de-duplicates while keeping first-appearance
# order, so each character's id is the number of distinct characters seen before it.
seen_chars = dict.fromkeys(char for chars in tokenized_corpus for char in chars)
vocab = {char: index for index, char in enumerate(seen_chars)}

vocab.items()

def get_pairs(tokenized_corpus):
    """Count how often each pair of neighbouring tokens occurs in the corpus.

    Args:
        tokenized_corpus: iterable of token sequences (one sequence per text).

    Returns:
        dict mapping ``(left_token, right_token)`` to its number of occurrences;
        keys appear in the order in which the pairs were first encountered.
    """
    counts = {}
    for tokens in tokenized_corpus:
        # zip(tokens, tokens[1:]) walks over every adjacent (left, right) couple.
        for bigram in zip(tokens, tokens[1:]):
            counts[bigram] = counts.get(bigram, 0) + 1
    return counts

def merge_pair(pair, tokenized_corpus):
    """Replace every occurrence of ``pair`` by a single concatenated token.

    Args:
        pair: tuple ``(left_token, right_token)`` to merge.
        tokenized_corpus: iterable of token sequences.

    Returns:
        A new list of token lists; the input sequences are not modified.
        Matches are consumed greedily from left to right, so overlapping
        occurrences (e.g. ``a a a`` with pair ``(a, a)``) merge only once.
    """
    def _merge_one(tokens):
        merged = []
        skip_next = False  # True when the current token was already absorbed by a merge
        for pos, token in enumerate(tokens):
            if skip_next:
                skip_next = False
                continue
            if pos + 1 < len(tokens) and (token, tokens[pos + 1]) == pair:
                merged.append(token + tokens[pos + 1])  # "+" is str concatenation
                skip_next = True
            else:
                merged.append(token)
        return merged

    return [_merge_one(tokens) for tokens in tokenized_corpus]

def byte_pair_encoding(corpus, num_merges: int = 10):
    """Learn a BPE vocabulary on ``corpus``.

    Args:
        corpus: list of str, the training texts.
        num_merges: maximum number of merge steps to perform.

    Returns:
        ``(vocab, tokenized_corpus)`` where ``vocab`` maps each token (single
        characters first, then merged tokens in the order they were learned)
        to an integer id, and ``tokenized_corpus`` is the corpus segmented
        with the learned merges.
    """
    tokenized_corpus = [list(text) for text in corpus]  # Start with character tokens

    # sorted(): iterating a set of str yields a different order on every interpreter
    # run (string hash randomization), which would make token ids irreproducible
    # across kernel restarts.
    vocab = {char: idx for idx, char in enumerate(sorted(set("".join(corpus))))}

    for _ in range(num_merges):
        pairs = get_pairs(tokenized_corpus)
        if not pairs:  # every text is a single token (or empty): nothing left to merge
            break
        best_pair = max(pairs, key=pairs.get)  # ties: first pair encountered wins
        tokenized_corpus = merge_pair(best_pair, tokenized_corpus)
        new_token = best_pair[0] + best_pair[1]
        # Defensive: never re-assign the id of an existing token, so ids stay
        # unique and contiguous even if two pairs concatenate to the same string.
        if new_token not in vocab:
            vocab[new_token] = len(vocab)

    return vocab, tokenized_corpus

# Train BPE with 500 merges; this rebinds `vocab` and `tokenized_corpus` to the BPE results.
vocab, tokenized_corpus = byte_pair_encoding(corpus, num_merges=500)

def tokenize_to_str_list(s: str, vocab):
    """Tokenize a given string based on the trained BPE vocabulary.

    Starting from single characters, repeatedly merges adjacent tokens whose
    concatenation is a known token, always applying the EARLIEST-learned merge
    first (lowest id), which is the order in which merges were applied during
    training.

    Args:
        s: text to tokenize.
        vocab: dict mapping token string to id; ids of merged tokens are
            assumed to increase with the order in which merges were learned.

    Returns:
        list of str tokens whose concatenation equals ``s``.
    """
    tokens = list(s)  # Start with character tokens

    while True:
        # Adjacent pairs whose concatenation is a known token.
        candidates = [
            (tokens[i], tokens[i + 1])
            for i in range(len(tokens) - 1)
            if tokens[i] + tokens[i + 1] in vocab
        ]
        if not candidates:
            break

        # Lowest id == earliest merge == highest priority. Picking the highest id
        # instead would apply late merges before the early ones they were built on
        # and yield a segmentation different from the one seen during training.
        best_pair = min(candidates, key=lambda p: vocab[p[0] + p[1]])
        merged_token = best_pair[0] + best_pair[1]

        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best_pair:
                new_tokens.append(merged_token)
                i += 2  # the right-hand token was consumed by the merge
            else:
                new_tokens.append(tokens[i])
                i += 1

        tokens = new_tokens  # strictly shorter than before, so the loop terminates

    return tokens

# Sanity check: BPE string tokens of the first text of the test split.
tokenize_to_str_list(dataset['test'][0]['text'], vocab)

def tokenize(s: str, vocab):
    """Encode ``s`` as a list of vocabulary ids (one id per BPE token).

    Raises:
        KeyError: if a token produced for ``s`` is missing from ``vocab``.
    """
    ids = []
    for piece in tokenize_to_str_list(s, vocab):
        ids.append(vocab[piece])
    return ids

# Same check on the first training text, this time as vocabulary ids.
tokenize(dataset['training'][0]['text'], vocab)

## 2. Dataset preprocessing


We now tokenize the full training and test splits of the dataset to obtain the `train_ids` and `test_ids` lists, which we will work with from now on.
A good way to proceed is to use the `.map` method; see https://huggingface.co/docs/datasets/en/process#map

In [None]:
def _with_ids(sample):
    """Build the extra 'ids' column (BPE token ids of the text) for one dataset row."""
    return {'ids': tokenize(sample['text'], vocab)}

# Same encoding for both splits; only the 'ids' column is kept.
train_ids = dataset['training'].map(_with_ids)['ids']
test_ids = dataset['test'].map(_with_ids)['ids']
train_ids[0]

### 2.2 We will later on work on **batches** of data as is always the case when training deep learning models.
#### The problem here is that we cannot just stack the 'ids' as they have variable length:

In [None]:
dataloader = DataLoader(train_ids, batch_size=2, shuffle=True) # produces batch by 'collating' individual samples
# The default collation cannot stack sequences of different lengths. The error is
# the point of this cell, so it is caught and displayed instead of aborting a
# "Restart & Run All".
try:
    for elt in dataloader:
        print(elt)
except RuntimeError as err:
    print(f"Default collation failed as expected: {err}")

#### To solve this issue, we add a 'padding token' to the vocabulary

In [None]:
# Register the padding token once (re-running this cell is a no-op) ...
if '<PAD>' not in vocab:
    vocab['<PAD>'] = len(vocab)
# ... and always read its id back from the vocabulary: the id is the size of the
# vocabulary *before* insertion, not len(vocab) afterwards, and '<PAD>' is not
# guaranteed to be the last entry when the cell is re-run.
pad_id = vocab['<PAD>']

### 2.3 implement a 'collate_batch_fn' which pads the sequences found as input using the pad_id, and returns the non-padding mask
For instance if input contains
[[1, 2, 3], [5, 6, 7, 8]]
you should return a torch tensor of type long [[1, 2, 3, pad_id], [5, 6, 7, 8]] as 'ids' and
a tensor of type long [[1, 1, 1, 0], [1, 1, 1, 1]] as mask


In [None]:
def collate_fn(batch, pad_value=None):
    """Pad a batch of variable-length id sequences into rectangular tensors.

    Args:
        batch: list of list of ints (one list of token ids per sample).
        pad_value: id used to fill the padded positions. Defaults to the
            module-level ``pad_id`` (id of the '<PAD>' token).

    Returns:
        dict with two long tensors of size
        (len(batch), max([len(elt) for elt in batch])):
        "ids"  - the sequences, right-padded with ``pad_value``;
        "mask" - 1 on real tokens, 0 where there is padding.
    """
    if pad_value is None:
        pad_value = pad_id  # looked up at call time so the default follows the vocabulary

    max_len = max((len(seq) for seq in batch), default=0)
    ids = torch.full((len(batch), max_len), pad_value, dtype=torch.long)
    mask = torch.zeros((len(batch), max_len), dtype=torch.long)
    for row, seq in enumerate(batch):
        length = len(seq)
        ids[row, :length] = torch.tensor(seq, dtype=torch.long)
        mask[row, :length] = 1
    return {"ids": ids, "mask": mask}

### Verify that this now works:

In [None]:
dataloader = DataLoader(train_ids, batch_size=2, shuffle=True, collate_fn=collate_fn) # produces batch by 'collating' individual samples

# One batch is enough to check the padding; looping over the whole loader would
# print every batch of the training set.
first_batch = next(iter(dataloader))
first_batch

### Ok now we have a dataloader which yields batches of padded tokenized texts. We are ready to start implementing the transformers architecture. We start with the attention layer.

## 3. Attention layer
Remember the attention layer takes as input a (batched) sequence of hidden states (shape (B, T, H)), and returns a tensor of same dimension exactly containing the attention values.
We will start with an example. 
Let's assume we have this input x to the layer and a head_size of 4 (head_size is dimension of query/keys/values)

In [None]:
# Toy dimensions: batch of B sequences, each with T hidden states of dimension H.
B, T, H = 2, 7, 16
head_size = 4  # dimension of the queries / keys / values of one attention head
x = torch.randn((B, T, H))  # (B, T, H) = (2, 7, 16)

### 3.1 Declare key_matrix, query_matrix and values_matrix pytorch layers to produce, from this x, the keys, queries and values. Apply these to get (B, T, head_size) tensors of K, Q, V
### We want the values to have shape (B, T, head_size)
hint: in torch, matrices are represented as bias-less Linear modules https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

In [None]:
# Matrices are bias-less Linear layers mapping hidden states (H) to head_size.
key_matrix = nn.Linear(H, head_size, bias=False)
query_matrix = nn.Linear(H, head_size, bias=False)
value_matrix = nn.Linear(H, head_size, bias=False)

# nn.Linear acts on the last dimension: (B, T, H) -> (B, T, head_size)
k = key_matrix(x)
q = query_matrix(x)
v = value_matrix(x)
k.size(), q.size(), v.size()

### 3.2 Q and K are (B, T, head_size) matrices. To compute the attention scores, we just need to compute the batch matrix multiplication of Q and K.transpose(1, 2): this op will do: (B, T, head_size) x (B, head_size, T) -> (B, T, T): one (unnormalized) attention score for each pair of tokens. Implement this using torch.matmul.

In [None]:
# Batched matmul: (B, T, head_size) x (B, head_size, T) -> (B, T, T);
# entry [b, i, j] is the unnormalized score of query i against key j.
attention_scores = torch.matmul(q, k.transpose(1, 2))
attention_scores # check the dimensions you obtain

### 3.3 Now we need to apply the causal masking so that past tokens do not attend future tokens, but future tokens do attend past tokens. To do this, we create a triangular inferior mask for each element of the batch:

In [None]:
# Lower-triangular mask, one (T, T) copy per batch element:
# mask[b, i, j] == 1 iff j <= i, i.e. position i may look at position j.
mask = torch.tril(torch.ones((B, T, T)))
mask

### Use this mask to set the attention_scores to -float('inf') wherever the mask is 0.

In [None]:
# Forbidden (future) positions get -inf. masked_fill returns a new tensor, and
# re-running this cell changes nothing since those entries are already -inf.
attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
attention_scores

### 3.4 Apply the softmax to normalize row-wise the masked attention scores. Why did we set attention scores to -float('inf') outside of the mask ?

In [None]:
# Row-wise softmax: the weights of each query position sum to 1.
# Entries equal to -inf contribute exp(-inf) = 0, i.e. exactly zero attention
# weight, which is why -inf (and not 0) is used for masking.
attention_weights = torch.softmax(attention_scores, dim=-1)
attention_weights

### 3.5 With another matrix multiplication between the normalized attention_scores (B, T, T) and the values (B, T, head_size), obtain the attention output as a (B, T, head_size) tensor.

In [None]:
# (B, T, T) x (B, T, head_size) -> (B, T, head_size): each output is the
# attention-weighted average of the values. The softmax is recomputed here so
# that this cell only depends on `attention_scores` and `v`.
attention_values = torch.matmul(torch.softmax(attention_scores, dim=-1), v)
attention_values.size()

### Now gather it all to create an Attention nn.Module object implementing this attention operation.

In [None]:
class Attention(nn.Module):
    """ one head of causal self-attention

    Args:
        hidden_state_dim: size H of the incoming hidden states.
        head_size: size of the keys / queries / values produced by this head.
    """
    def __init__(self, hidden_state_dim, head_size):
        super().__init__()
        self.key = nn.Linear(hidden_state_dim, head_size, bias=False)
        self.query = nn.Linear(hidden_state_dim, head_size, bias=False)
        self.value = nn.Linear(hidden_state_dim, head_size, bias=False)

    def forward(self, x):
        """Map hidden states (B, T, H) to attention values (B, T, head_size)."""
        seq_len = x.size(1)
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # (B, T, head_size) x (B, head_size, T) -> (B, T, T). The 1/sqrt(head_size)
        # factor is the usual scaled dot-product normalization; it keeps the
        # softmax from saturating as head_size grows.
        scores = torch.matmul(q, k.transpose(1, 2)) * k.size(-1) ** -0.5

        # Causal mask: position i may only attend to positions j <= i.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~causal, float('-inf'))

        weights = torch.softmax(scores, dim=-1)  # exp(-inf) = 0 on masked entries
        return torch.matmul(weights, v)

attention = Attention(hidden_state_dim=16, head_size = 4) # check it works 

### 3.6 Now create a multi-head attention layer that, given some x, computes n_head attention in parallel, each with a head size of (hidden_state_dim / n_heads). Each head produces a (B, T, head_size) set of values. The multi head attention layer should concatenate them into a (B, T, head_size*n_heads=hidden_dim) and apply a final 'projection' layer which is just a linear map.
You can assume that n_heads divide hidden_state_dim.

In [None]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Args:
        hidden_state_dim: size H of the hidden states (input and output).
        n_heads: number of attention heads; must divide hidden_state_dim.
    """

    def __init__(self, hidden_state_dim, n_heads):
        super().__init__()
        if hidden_state_dim % n_heads != 0:
            raise ValueError("n_heads must divide hidden_state_dim")
        head_size = hidden_state_dim // n_heads
        self.heads = nn.ModuleList(
            [Attention(hidden_state_dim, head_size) for _ in range(n_heads)]
        )
        # Final projection mixing the concatenated heads: (H -> H) linear map.
        self.proj = nn.Linear(hidden_state_dim, hidden_state_dim)

    def forward(self, x):
        """Map (B, T, H) to (B, T, H)."""
        # Each head yields (B, T, head_size); concatenating on the last axis
        # gives (B, T, head_size * n_heads) = (B, T, H).
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.proj(concatenated)


mha = MultiHeadAttention(16, 4)
mha(x).size()

## 4. Finishing the transformer block

### 4.1: Feed-forward layer (also called MLP layer): Implement the feed forward layer of the transformer block.
Reminder: this layer applies two linear layers (embedding_size -> mlp_factor * embedding_size -> embedding_size, where mlp_factor controls the size of the intermediate layer) with a ReLU activation in between.

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP of the transformer block.

    embedding_size -> mlp_factor * embedding_size -> embedding_size,
    with a ReLU between the two linear layers.
    """
    def __init__(self, embedding_size, mlp_factor=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_size, mlp_factor * embedding_size),
            nn.ReLU(),
            nn.Linear(mlp_factor * embedding_size, embedding_size),
        )

    def forward(self, x):
        """Apply the MLP on the last dimension; the shape of x is preserved."""
        return self.net(x)

### 4.2 Now implement the full transformer block.
For normalization, you can just use:
`self.norm = nn.LayerNorm(embedding_size)` and apply it, as any other nn.module, to any input whose last dimension is embedding_size

Don't forget the residual connections:
- the input x should be added to the output of 1st layer norm + attention to obtain y
- y should be added to the output of 2nd layer norm + mlp

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection.

    Args:
        hidden_state_dim: size H of the hidden states.
        n_heads: number of attention heads (must divide hidden_state_dim).
        mlp_factor: width multiplier of the MLP's intermediate layer.
    """
    def __init__(self, hidden_state_dim, n_heads, mlp_factor):
        super().__init__()
        self.attention_norm = nn.LayerNorm(hidden_state_dim)
        self.attention = MultiHeadAttention(hidden_state_dim, n_heads)
        self.mlp_norm = nn.LayerNorm(hidden_state_dim)
        self.mlp = FeedForward(hidden_state_dim, mlp_factor)

    def forward(self, x):
        """Map (B, T, H) to (B, T, H)."""
        # Residual 1: x + attention(norm(x))
        y = x + self.attention(self.attention_norm(x))
        # Residual 2: y + mlp(norm(y))
        return y + self.mlp(self.mlp_norm(y))

## 5. Finish implementing the transformer

In [None]:
class GPT(nn.Module):
    """Minimal decoder-only language model.

    token embedding -> n_layers transformer blocks -> layer norm -> vocabulary logits.
    No positional information is injected yet, and the padding mask is not used.

    Args:
        vocab_size: number of tokens in the vocabulary (including '<PAD>').
        hidden_state_dim: size H of the hidden states.
        n_heads: number of attention heads per block.
        mlp_factor: width multiplier of the MLP's intermediate layer.
        n_layers: number of stacked transformer blocks.
    """

    def __init__(self, vocab_size, hidden_state_dim, n_heads, mlp_factor=4, n_layers=1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_state_dim)
        self.blocks = nn.Sequential(
            *[TransformerBlock(hidden_state_dim, n_heads, mlp_factor) for _ in range(n_layers)]
        )
        self.final_norm = nn.LayerNorm(hidden_state_dim)
        self.lm_head = nn.Linear(hidden_state_dim, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: long tensor (B, T) of token ids.
            targets: optional long tensor (B, T); targets[b, t] is the id that
                should follow idx[b, :t + 1] (i.e. idx shifted by one position).

        Returns:
            ``(logits, loss)`` with logits of shape (B, T, vocab_size) and loss the
            mean cross-entropy over all positions, or None when targets is None.
            Padded positions are NOT excluded from the loss here.
        """
        hidden = self.token_embedding(idx)               # (B, T, H)
        hidden = self.blocks(hidden)                     # (B, T, H)
        logits = self.lm_head(self.final_norm(hidden))   # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) logits and (N,) class ids: flatten B and T.
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
        )
        return logits, loss

## 6. Still things to be done:
### 6.1 Implement a 'generate' method which, given a trained Transformer object and an initial list of token ids, generate the following tokens, according to the language model
### 6.2 How would you implement the loss function ?
### 6.3 Implement positional encodings
### 6.4 What else ? 