# Coding Transformers from scratch (almost)

In [None]:
# IMPORTS
import torch,torch.nn as nn,torch.nn.functional as F, torch.optim as optim, math


from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import TokenTexth5

## Building the model...

### Self-attention Layer

In [None]:
class MaskedSelfAttention(nn.Module):
    """
        Multi-head causal (masked) self-attention.

        Each position may only attend to itself and to earlier positions, so the
        output at time t depends only on the inputs at times 0..t.

        Args :
        embed_dim : size D of the token embeddings. Must be divisible by n_heads.
        attn_length : maximum sequence length T accepted (size of the causal mask)
        n_heads : number of attention heads
        dropout : dropout probability, applied to the attention weights and to the output
        device : device on which the parameters and the mask are created

        Raises :
        ValueError if embed_dim is not divisible by n_heads
    """

    def __init__(self, embed_dim, attn_length, n_heads, dropout=0.1,device='cpu'):
        super().__init__()
        if embed_dim % n_heads != 0:
            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})')
        self.attn_length = attn_length
        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // n_heads

        # One projection produces queries, keys and values at once (3*D outputs)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, device=device)
        # Output projection mixing the concatenated heads back into D channels
        self.proj = nn.Linear(embed_dim, embed_dim, device=device)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

        # Lower-triangular boolean mask: True where attention is allowed.
        # Registered as a (non-trainable) buffer so it follows the module on .to(device).
        mask = torch.tril(torch.ones(attn_length, attn_length, dtype=torch.bool, device=device))
        self.register_buffer('mask', mask.view(1, 1, attn_length, attn_length), persistent=False)


    def forward(self, x): # Apply the different layers
        """
            Args : 
            x : (B,T,D) with T <= attn_length

            Return : Tensor (B,T,D)

            Raises : ValueError if T > attn_length
        """
        B, T, D = x.shape
        if T > self.attn_length:
            raise ValueError(f'sequence length {T} exceeds attn_length {self.attn_length}')

        q, k, v = self.qkv(x).split(self.embed_dim, dim=-1)  # 3 x (B,T,D)
        # Split channels into heads: (B,T,D) -> (B,n_heads,T,head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores (B,n_heads,T,T); scaling keeps softmax well-conditioned
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Forbid attending to future positions. The diagonal is always allowed,
        # so no row is fully -inf and softmax never produces NaN.
        att = att.masked_fill(~self.mask[:, :, :T, :T], float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v  # (B,n_heads,T,head_dim)
        # Re-concatenate the heads: (B,n_heads,T,head_dim) -> (B,T,D)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_dropout(self.proj(y))

        


#### Test attention

In [None]:
attn_test =  MaskedSelfAttention(embed_dim=6,attn_length=10, n_heads = 2)
attn_test.eval()  # disable dropout so repeated forward passes are comparable

in_tens = torch.randn((1,8,6))

out_tens = attn_test(in_tens)# (B,T,D)

assert out_tens.shape==in_tens.shape

# Causality check: changing the inputs from position 4 onwards must not
# change the outputs at positions 0..3 (a shape check alone cannot catch a missing mask).
perturbed = in_tens.clone()
perturbed[:, 4:] = torch.randn((1, 4, 6))
assert torch.allclose(attn_test(perturbed)[:, :4], out_tens[:, :4], atol=1e-6)

### MLP layer, easy part

In [None]:
class MLP(nn.Module):
    """
        Simple position-wise feedforward with two layers.
        The hidden layer has int(mlp_ratio * embed_dim) units, i.e. it blows up
        the embed_dim by a factor of 4 by default.

        Args :
        embed_dim : size of the input and output features
        mlp_ratio : expansion factor of the hidden layer
        device : device on which the parameters are created
    """

    def __init__(self,embed_dim, mlp_ratio=4., device='cpu'):
        super().__init__()
        hidden_dim = int(mlp_ratio * embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden_dim, device=device)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim, device=device)


    def forward(self,x):
        """
            Args :
            x : (*,embed_dim)

            Returns : (*,embed_dim)
        """
        # nn.Linear acts on the last dimension, so any leading dims (*) are preserved
        return self.fc2(self.act(self.fc1(x)))

In [None]:
testmlp = MLP(embed_dim=5)

test_input = torch.randn((2,10,5))
test_output = testmlp(test_input)  # run the MLP once and reuse the result

print('input shape: ', test_input.shape)
print('outshape : ', test_output.shape)
# Like the other test cells, fail loudly instead of relying on reading the prints
assert test_output.shape == test_input.shape

### 'Decoder' Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """
        GPT-style decoder block (self-attention only, no cross-attention),
        using the pre-LayerNorm residual layout:
            x = x + attn(ln1(x))
            x = x + dropout(mlp(ln2(x)))

        Args :
        embed_dim : size D of the embeddings
        attn_length : maximum sequence length, forwarded to MaskedSelfAttention
        n_heads : number of attention heads
        dropout : dropout probability
        device : device on which the parameters are created
    """

    def __init__(self,embed_dim,attn_length, n_heads, dropout=0.1, device='cpu'):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim, device=device)
        self.attn = MaskedSelfAttention(embed_dim=embed_dim, attn_length=attn_length,
                                        n_heads=n_heads, dropout=dropout, device=device)
        self.ln2 = nn.LayerNorm(embed_dim, device=device)
        self.mlp = MLP(embed_dim=embed_dim, device=device)
        self.mlp_dropout = nn.Dropout(dropout)

    def forward(self,x):
        """
        Args :
        x : (B,T,D) tensor

        Returns : (B,T,D) tensor
        """
        # Residual connections keep a direct gradient path through deep stacks
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp_dropout(self.mlp(self.ln2(x)))
        return x

#### Test block

In [None]:
block_test = TransformerBlock(embed_dim=6, attn_length=10, n_heads=2)

in_tens = torch.randn(2, 8, 6)  # (B,T,D)
out_tens = block_test(in_tens)  # expected (B,T,D)

# The block must preserve the (B,T,D) shape
assert tuple(out_tens.shape) == tuple(in_tens.shape)

### Finally, the text transformer

In [None]:
class GPT(nn.Module):
    """
        Decoder-only transformer language model.

        tokens -> token embedding + learned positional embedding -> n_layers TransformerBlocks
        -> final LayerNorm -> linear head producing logits over the vocabulary.

        Args :
        n_layers : number of TransformerBlocks
        vocab_size : number of distinct token ids
        embed_dim : size D of the embeddings
        attn_length : maximum context length T
        n_heads : number of attention heads per block
        dropout : dropout probability
        device : device on which the parameters are created
    """
    def __init__(self, n_layers, vocab_size,embed_dim,attn_length,n_heads, dropout=0.1,device='cpu'):
        super().__init__()
        self.attn_length = attn_length
        self.tok_emb = nn.Embedding(vocab_size, embed_dim, device=device)
        # One learned embedding per position 0..attn_length-1
        self.pos_emb = nn.Embedding(attn_length, embed_dim, device=device)
        self.emb_dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim=embed_dim, attn_length=attn_length,
                             n_heads=n_heads, dropout=dropout, device=device)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim, device=device)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False, device=device)

    
    def forward(self, x):
        """
            Args : 
            x : (B,T) of longs (tokens), with T <= attn_length

            Returns : (B,T,vocab_size) of logits

            Raises : ValueError if T > attn_length
        """
        B, T = x.shape
        if T > self.attn_length:
            raise ValueError(f'sequence length {T} exceeds attn_length {self.attn_length}')
        positions = torch.arange(T, device=x.device)  # (T,)
        # (B,T,D) + (T,D): the positional embedding broadcasts over the batch
        h = self.tok_emb(x) + self.pos_emb(positions)
        h = self.emb_dropout(h)
        for block in self.blocks:
            h = block(h)
        return self.head(self.ln_f(h))
    

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
            Take a conditioning sequence of indices idx (LongTensor of shape (B,T)) and complete
            the sequence max_new_tokens times, feeding the predictions back into the model each time.
            Use with model in inference mode (apply model.eval() first)

            Args :
            idx : (B,T) tensor of context tokens. Mostly, it will be B=1 but can do in parallel also
            max_new_tokens : number of tokens to generate on top of the conditioning sequence
            temperature : softmax temperature (lower -> more conservative sampling), must be > 0
            do_sample : if True, use multinomial sampling. Otherwise use greedy decoding
            top_k : if set to int > 0, only sample from the top k most probable logits

            Returns :
            (B,T+max_new_tokens) LongTensor: the context followed by the generated token indices.
            Must still be decoded by tokenizer.
        """

        for _ in range(max_new_tokens):
            idx_next = self.generate_next_token(idx,temperature=temperature,do_sample=do_sample,top_k=top_k)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx
    

    @torch.no_grad()
    def generate_next_token(self,idx,temperature=1.0, do_sample=False, top_k=None):
        """
            Take a conditioning sequence of indices idx (LongTensor of shape (B,T)) and return
            the next predicted token.
            Use with model in inference mode (apply model.eval() first)

            Args :
            idx : (B,T) tensor of context tokens. Mostly, it will be B=1 but can do in parallel also
            temperature : softmax temperature (lower -> more conservative sampling), must be > 0
            do_sample : if True, use multinomial sampling. Otherwise use greedy decoding
            top_k : if set to int > 0, only sample from the top k most probable logits
                    (values larger than the vocabulary size keep every logit)

            Returns :
            (B,1) LongTensor of next predicted token indices

            Raises :
            ValueError if temperature <= 0
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        # crop the context to the last attn_length tokens, the most the model can see
        idx_cond = idx if idx.shape[1] <= self.attn_length else idx[:, -self.attn_length:]
        # forward the model to get the logits for the index in the sequence
        logits = self.forward(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None and top_k > 0:
            # torch.topk raises if k exceeds the vocabulary size, so clamp it
            k = min(top_k, logits.size(-1))
            v, _ = torch.topk(logits, k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)

        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
            
        # Return sampled index
        return idx_next

# Training, let's test it!

In [None]:
# Training pre-requisites :

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("en_tokenizer")

# Transformer parameters :
attn_length = 64
n_layers = 4
embed_dim = 64
n_heads = 4
# len(tokenizer) counts tokens added on top of the base vocabulary (e.g. special tokens);
# tokenizer.vocab_size does not, which would leave those ids outside the embedding table.
vocab_size = len(tokenizer)
device='cpu'

# Build the model
myGPT = GPT(n_layers=n_layers,vocab_size=vocab_size,embed_dim=embed_dim,attn_length=attn_length, n_heads=n_heads,device=device)
# Dataset and dataloader (shuffled so each epoch sees the batches in a different order)
dataset = TokenTexth5("test_text.h5", attn_length=attn_length)
dataloader = DataLoader(dataset,batch_size=32,shuffle=True)
# Optimizer :
optimus = optim.AdamW(myGPT.parameters(),lr=2e-3)

# Test all is well :
print("Example raw : \n", dataset[0][0][:5])
print("Example raw answer : \n", dataset[0][1][:5])
# All is ready !

In [None]:
from tqdm import tqdm
# Training loop :
epochs = 100
running_loss = []  # loss of every optimisation step, across all epochs
# Make sure dropout is active: a generation cell may have left the model in eval mode
myGPT.train()
for ep in tqdm(range(epochs)) :
    epoch_losses = []
    for toks,tru_toks in dataloader:
        toks=toks.to(device)
        tru_toks=tru_toks.to(device)

        logits = myGPT(toks) # (B,T,v_size)
        logits = logits.transpose(1,2) # (B,v_size,T) required by cross_entropy of pytorch

        loss = F.cross_entropy(logits, tru_toks) # Use pytorch to prevent problems with infinities of log(0)

        optimus.zero_grad() # clear gradients left over from any previous step
        loss.backward() # backprop
        optimus.step() # Adjust params
        epoch_losses.append(loss.item())
    running_loss.extend(epoch_losses)
    if ep % 10 == 0 and epoch_losses:
        # Average over this epoch only: averaging the whole history hides recent progress
        print(f'ep {ep}, loss : {sum(epoch_losses)/(len(epoch_losses))}')
        

In [None]:
## Try the generation :
myGPT.to('cpu')
myGPT.eval()  # disable dropout: generate() expects inference mode
conditioning = tokenizer.encode('Hello')

initial = torch.tensor(conditioning)[None] # (1,T,)

# temperature only has an effect when sampling; greedy decoding (the default) ignores it
output = myGPT.generate(initial,max_new_tokens=54, temperature=0.3, do_sample=True)[0] # (only one batch, remove it)
print('OUTPUT : ')
print(tokenizer.decode(output.tolist()))


In [None]:
## Try the generation :
myGPT.to('cpu')
myGPT.eval()  # disable dropout: generate() expects inference mode
conditioning = tokenizer.encode('Thanks')

initial = torch.tensor(conditioning)[None] # (1,T,)

# temperature only has an effect when sampling; greedy decoding (the default) ignores it
output = myGPT.generate(initial,max_new_tokens=25, temperature=0.1, do_sample=True)[0] # (only one batch, remove it)
print('OUTPUT : ')
print(tokenizer.decode(output.tolist()))
