# In this notebook, we build an intuition for training transformers and explore their pros and cons

In [1]:
import tiktoken
import torch
import torch.nn as nn
import numpy as np
import time

# Create the models

import math
import torch.nn.functional as F


### Create a next-character prediction dataset

In [2]:
class NextCharDataset(torch.utils.data.Dataset):
    """Character-level next-token prediction dataset over a single text file.

    Every character of the file is a token. Item ``idx`` is a pair ``(x, y)`` of
    ``torch.long`` tensors of length ``context_length``, where ``y`` is ``x``
    shifted one character to the right (the target for each input position).

    Args:
        dataset_path: path of the text file to read.
        context_length: number of input tokens per sample.
        encoding: text encoding used to decode the file (default ``"utf-8"``).

    Attributes:
        tokens: list of the characters of the file.
        mapping: character -> contiguous integer id in ``[0, vocab_size)``.
        inverse_mapping: integer id -> character.
        vocab_size: number of distinct characters.
    """

    def __init__(self, dataset_path, context_length, encoding="utf-8"):
        self.dataset_path = dataset_path
        self.context_length = context_length

        # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
        # fail on, or silently mis-decode, non-ASCII text.
        with open(self.dataset_path, "r", encoding=encoding) as f:
            self.raw_text = f.read()

        self.tokens = list(self.raw_text)

        # Remap the characters to contiguous ids.
        self.mapping, self.inverse_mapping = self._build_vocab(self.tokens)
        self.vocab_size = len(self.mapping)

    @staticmethod
    def _build_vocab(tokens):
        """Return ``(token -> id, id -> token)`` dicts over the distinct tokens.

        The vocabulary is built once and sorted so that ids are reproducible
        across runs: the iteration order of a ``set`` of strings depends on
        per-process hash randomisation.
        """
        vocab = sorted(set(tokens))
        mapping = {tok: i for i, tok in enumerate(vocab)}
        inverse_mapping = dict(enumerate(vocab))
        return mapping, inverse_mapping

    def __len__(self):
        # Each sample needs context_length + 1 consecutive tokens; a text that
        # is too short yields an empty dataset rather than a negative length.
        return max(0, len(self.tokens) - self.context_length)

    def __getitem__(self, idx):
        window = self.tokens[idx:idx + self.context_length + 1]
        ids = [self.mapping[tok] for tok in window]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

In [None]:
# Character-level tokenisation simply splits a string into its characters.
txtx = "toto titi"
[*txtx]

### Create a next-token prediction dataset

In [None]:
class NextTokenDataset(torch.utils.data.Dataset):
    """Next-token prediction dataset over a text file tokenised with tiktoken ``o200k_base``.

    The file is BPE-encoded once. The distinct tiktoken ids that occur are then
    remapped to contiguous ids ``[0, vocab_size)``. Item ``idx`` is a pair
    ``(x, y)`` of ``torch.long`` tensors of length ``context_length``, where
    ``y`` is ``x`` shifted one token to the right.

    Args:
        dataset_path: path of the text file to read.
        context_length: number of input tokens per sample.
        encoding: text encoding used to decode the file (default ``"utf-8"``).

    Attributes:
        tokeniser: the tiktoken encoding (used to encode prompts / decode output).
        tokens: list of tiktoken ids of the file.
        mapping: tiktoken id -> contiguous id.
        inverse_mapping: contiguous id -> tiktoken id.
        vocab_size: number of distinct tiktoken ids in the file.
    """

    def __init__(self, dataset_path, context_length, encoding="utf-8"):
        self.dataset_path = dataset_path
        self.context_length = context_length

        # Explicit encoding: the platform default can mis-decode non-ASCII text.
        with open(self.dataset_path, "r", encoding=encoding) as f:
            self.raw_text = f.read()
        self.tokeniser = tiktoken.get_encoding("o200k_base")
        self.tokens = self.tokeniser.encode(self.raw_text)

        # Remap the tiktoken ids to contiguous ids.
        self.mapping, self.inverse_mapping = self._build_vocab(self.tokens)
        self.vocab_size = len(self.mapping)

    @staticmethod
    def _build_vocab(tokens):
        """Return ``(token -> id, id -> token)`` dicts over the distinct tokens.

        The distinct tokens are collected once (instead of rebuilding the set
        for every use) and sorted, so the id assignment is explicit and does not
        depend on set iteration order.
        """
        vocab = sorted(set(tokens))
        mapping = {tok: i for i, tok in enumerate(vocab)}
        inverse_mapping = dict(enumerate(vocab))
        return mapping, inverse_mapping

    def __len__(self):
        # Each sample needs context_length + 1 consecutive tokens; a text that
        # is too short yields an empty dataset rather than a negative length.
        return max(0, len(self.tokens) - self.context_length)

    def __getitem__(self, idx):
        window = self.tokens[idx:idx + self.context_length + 1]
        ids = [self.mapping[tok] for tok in window]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

In [None]:
def create_datasets(dataset_class, dataset_path, context_length, train_fraction=0.99,
                    generator=None, contiguous=False):
    """Build ``dataset_class(dataset_path, context_length)`` and split it into train/test.

    Args:
        dataset_class: dataset constructor taking ``(dataset_path, context_length)``.
        dataset_path: forwarded to ``dataset_class``.
        context_length: forwarded to ``dataset_class``.
        train_fraction: fraction of the samples used for training (default 0.99).
        generator: optional ``torch.Generator`` for a reproducible random split;
            ``None`` uses torch's global default generator.
        contiguous: if True, the first ``train_fraction`` of the samples form the
            training split and the tail forms the test split, instead of a random
            split. With the sliding-window datasets in this notebook, neighbouring
            samples share almost all of their tokens, so a random split puts
            near-copies of every test window into the training set.

    Returns:
        ``(train_dataset, test_dataset)``, both ``torch.utils.data.Subset``
        objects whose ``.dataset`` attribute is the full dataset.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    full_dataset = dataset_class(dataset_path, context_length)

    if contiguous:
        n_total = len(full_dataset)
        n_train = int(n_total * train_fraction)
        train_dataset = torch.utils.data.Subset(full_dataset, range(n_train))
        test_dataset = torch.utils.data.Subset(full_dataset, range(n_train, n_total))
    else:
        if generator is None:
            generator = torch.default_generator
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_fraction, 1.0 - train_fraction], generator=generator
        )

    return train_dataset, test_dataset



In [None]:
train_dataset, test_dataset = create_datasets(NextCharDataset, "hymns.txt", 100)
batch_size = 100
shuffle = True
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
# Evaluation gains nothing from shuffling; a fixed order keeps every pass over
# the test split identical batch-for-batch.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:

class RNNModel(nn.Module):
    """LSTM language model: embedding -> stacked LSTM -> linear decoder.

    Args:
        ntoken: vocabulary size.
        ninp: embedding dimension.
        nhid: LSTM hidden size.
        nlayers: number of stacked LSTM layers.
        dropout: dropout probability applied to the embeddings, between LSTM
            layers and to the LSTM output (default 0.0).
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.0):
        super(RNNModel, self).__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout, batch_first=True)
        self.output_layer = nn.Linear(nhid, ntoken)
        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        """Initialise embedding and decoder weights uniformly in [-0.1, 0.1]; zero the decoder bias."""
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.output_layer.bias)
        nn.init.uniform_(self.output_layer.weight, -initrange, initrange)

    def forward(self, x):
        """Return next-token log-probabilities for a batch of token ids.

        Args:
            x: 2-D ``(batch, seq)`` tensor of token ids.

        Returns:
            ``(batch, seq, ntoken)`` log-probabilities, normalised over the
            vocabulary (last) dimension.
        """
        bs, _ = x.size()  # also enforces a 2-D (batch, seq) input
        # A fresh zero state per call: batches are processed independently.
        hidden = self.init_hidden(bs)
        emb = self.drop(self.input_emb(x))
        output, _ = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.output_layer(output)
        # Normalise over the vocabulary (last dim). With batch_first=True,
        # dim=1 would be the time axis and yield invalid distributions.
        return F.log_softmax(decoded, dim=-1)

    def init_hidden(self, bsz):
        """Return zero ``(h0, c0)`` tensors of shape ``(nlayers, bsz, nhid)``.

        They are created with ``new_zeros`` on the first parameter, so they share
        the model's device and dtype.
        """
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))


### Positional Encoding

In [None]:
# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position information to a sequence of embeddings.

    The encodings have the same dimension as the embeddings, so the two can be
    summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PE}(pos, 2i)   = \sin\left(pos / 10000^{2i / d_{model}}\right)

        \text{PE}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` the embedding index.

    Args:
        d_model: the embedding dimension (required). Odd values are supported.
        dropout: dropout applied to the sum (default=0.1).
        max_len: the maximum length of the incoming sequence (default=5000).
        batch_first: if True, ``forward`` expects ``[batch, seq, embed]``;
            otherwise ``[seq, batch, embed]`` (default=False).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as [max_len, 1, d_model] so it broadcasts over the batch dim.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Add the positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: ``[sequence length, batch size, embed dim]``, or
               ``[batch size, sequence length, embed dim]`` if ``batch_first``.
            output: same shape as ``x``.

        Examples:
            >>> output = pos_encoder(x)
        """
        if self.batch_first:
            # pe[:seq] is [seq, 1, d]; transpose to [1, seq, d] to broadcast over batch.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

### Build your transformer model

In [None]:

class TransformerModel(nn.Transformer):
    """Causal (decoder-only style) transformer language model.

    Only the encoder stack of ``nn.Transformer`` is used, together with a causal
    mask, so position ``t`` attends only to positions ``<= t``.

    Args:
        ntoken: vocabulary size.
        ninp: embedding / model dimension (``d_model``).
        nhead: number of attention heads.
        nhid: feed-forward hidden size (``dim_feedforward``).
        nlayers: number of encoder layers.
        dropout: dropout of the positional encoding (default 0.5). NOTE: it is
            not forwarded to ``nn.Transformer``, so the encoder layers keep
            that class's own default dropout.
    """

    def __init__(self,
                 ntoken,
                 ninp,
                 nhead,
                 nhid,
                 nlayers,
                 dropout=0.5):
        # num_decoder_layers=0: forward() never calls nn.Transformer's
        # seq2seq decoder, so building its default 6 layers only wastes
        # parameters and memory.
        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid,
                                               num_encoder_layers=nlayers, num_decoder_layers=0)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)

        self.input_emb = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.output_layer = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``(sz, sz)`` float mask: 0 on/below the diagonal, -inf above it."""
        return torch.log(torch.tril(torch.ones(sz, sz, device=device)))

    def init_weights(self):
        """Initialise embedding and decoder weights uniformly in [-0.1, 0.1]; zero the decoder bias."""
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.output_layer.bias)
        nn.init.uniform_(self.output_layer.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        """Return next-token log-probabilities.

        Args:
            src: ``(batch, seq)`` tensor of token ids (batch-first, as produced by
                the DataLoader and consumed by ``generate``).
            has_mask: apply the causal mask (default True).

        Returns:
            ``(batch, seq, ntoken)`` log-probabilities over the vocabulary.
        """
        if has_mask:
            # The mask is over the *sequence* dimension (dim 1 of src).
            self.src_mask = self._generate_square_subsequent_mask(src.size(1), device=src.device)
        else:
            self.src_mask = None

        # nn.Transformer (batch_first=False by default) and PositionalEncoding
        # both expect sequence-first input, so go [batch, seq] -> [seq, batch].
        src = src.transpose(0, 1)
        src = self.input_emb(src)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        output = self.output_layer(output)
        # Back to batch-first [batch, seq, ntoken].
        return F.log_softmax(output.transpose(0, 1), dim=-1)

In [3]:
# Sanity check: later cells hard-code the device "cuda", so this should display True.
torch.cuda.is_available()

True

In [None]:
# Model hyperparameters. The vocabulary size comes from the full dataset
# behind the train split.
ntoken = train_dataset.dataset.vocab_size
nembd = 64                # embedding width shared by both models
nhead = 4                 # attention heads of the transformer
nlayer = 5                # transformer encoder layers
feedforward_nlayer = 50   # NOTE: passed as `nhid`, i.e. the feed-forward width, not a layer count
lstm_nlayer = 5           # stacked LSTM layers
lstm_nhid = 50            # LSTM hidden size

transformer_decoder = TransformerModel(
    ntoken,
    ninp=nembd,
    nhead=nhead,
    nhid=feedforward_nlayer,
    nlayers=nlayer,
)
# Smoke test on one hand-written sequence of token ids.
out = transformer_decoder(torch.tensor([[10, 0, 5, 3, 6, 20, 5]], dtype=torch.long))
out.shape
lstm_model = RNNModel(ntoken, ninp=nembd, nhid=lstm_nhid, nlayers=lstm_nlayer, dropout=0.50).to("cuda")

### Write the training loop for the transformer and the recurrent neural network

In [None]:
model = transformer_decoder.to("cuda")
# model = lstm_model.to("cuda")

weight_decay = 0.01
epochs = 150
grad_clip = 1.0             # max global gradient norm
max_steps_per_epoch = None  # set to an int to cap the batches per epoch for quick runs

# training loop
best_loss = None
step = 0

# for name, model, lr in zip(["lstm"], [lstm_model], [5e-2]):
for name, model, lr in zip(["transformer", "lstm"], [transformer_decoder, lstm_model], [5e-3, 5e-2]):
    model = model.to("cuda")

    final_lr = 5e-5
    # Per-epoch exponential decay that reaches final_lr after epochs / 2 epochs;
    # the scheduler stops stepping once the LR has dropped to final_lr.
    gamma = (final_lr / lr) ** (2.0 / epochs)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    print(f"--"*89)
    print(f"Model {name} | lr {lr}")
    print("--" * 89)

    epoch_loss = float("nan")
    for epoch in range(epochs):
        t0 = time.time()
        model.train()
        running_loss = 0.0
        n_batches = 0
        for batch_idx, (x, y) in enumerate(train_loader):
            if max_steps_per_epoch is not None and batch_idx >= max_steps_per_epoch:
                break
            x, y = x.to("cuda"), y.to("cuda")
            # Both models end in log_softmax, so cross-entropy is the NLL of
            # the targets over the flattened (batch * seq) positions.
            log_probs = model(x)
            loss = F.nll_loss(log_probs.reshape(-1, log_probs.size(-1)), y.reshape(-1))

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
            step += 1

        if scheduler.get_last_lr()[0] > final_lr:
            scheduler.step()
        t1 = time.time()
        epoch_loss = running_loss / max(n_batches, 1)
        print(f"Epoch {epoch + 1} | Loss {epoch_loss:.4f} | Epoch time {(t1 - t0) * 1000:.2f}ms | Current LR {scheduler.get_last_lr()[0]:.6f}")

    # Held-out loss on the test split, with dropout disabled.
    model.eval()
    test_loss_sum, test_batches = 0.0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to("cuda"), y.to("cuda")
            log_probs = model(x)
            test_loss_sum += F.nll_loss(log_probs.reshape(-1, log_probs.size(-1)), y.reshape(-1)).item()
            test_batches += 1
    test_loss = test_loss_sum / max(test_batches, 1)
    print(f"Model {name} | train loss {epoch_loss:.4f} | test loss {test_loss:.4f}")

### Write the generation code

In [None]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, context, temperature=1.0, do_sample=False, top_k=None):
    """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

    Each step feeds (at most) the last ``context`` tokens to the model, takes the
    scores at the final position and appends the chosen token.

    Args:
        model: callable mapping a ``(b, t)`` LongTensor to ``(b, t, vocab)``
            scores. Raw logits and log-probabilities (what the models in this
            notebook return) give the same softmax.
        idx: ``(b, t)`` LongTensor conditioning sequence.
        max_new_tokens: number of tokens to append.
        context: maximum number of trailing tokens fed to the model per step.
        temperature: the scores are divided by it before the softmax; must be > 0.
        do_sample: sample from the distribution if True, else pick the argmax.
        top_k: if given, keep only the ``top_k`` most likely tokens (clamped to
            the vocabulary size).

    Returns:
        ``(b, t + max_new_tokens)`` LongTensor.

    Raises:
        ValueError: if ``temperature`` is not positive.

    Put the model in ``eval()`` mode first so dropout is disabled.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    block_size = context
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        logits = model(idx_cond)
        # pluck the scores at the final step and scale by the desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the scores to only the top k options
        if top_k is not None:
            k = min(top_k, logits.size(-1))  # torch.topk raises if k > vocab size
            v, _ = torch.topk(logits, k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append the chosen index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

### Generate

In [None]:
prompt = """ vive """

dataset = train_dataset.dataset
# NextTokenDataset exposes a tiktoken encoder; NextCharDataset (built above) has
# no `tokeniser` attribute and treats each character as a token.
tokeniser = getattr(dataset, "tokeniser", None)
if tokeniser is not None:
    prompt_tokens = tokeniser.encode(prompt)
else:
    prompt_tokens = list(prompt)
tokens = [dataset.mapping[i] for i in prompt_tokens]
tokens = torch.tensor([tokens], dtype=torch.long)

for name, model in zip(["transformer", "lstm"], [transformer_decoder, lstm_model]):
    # Disable dropout (the LSTM was built with dropout=0.5) while generating.
    model.eval()
    # Condition on as many tokens as the models saw during training.
    generated = generate(model, tokens.to("cuda"), 1000, context=dataset.context_length)
    response = [dataset.inverse_mapping[i] for i in generated[0].tolist()]
    response = tokeniser.decode(response) if tokeniser is not None else "".join(response)
    print("=="*50)
    print(f"Name : {name}")
    print("=="*50)
    print(response)
                      
                      
                      