# Coding Transformers from scratch (almost)

In [1]:
# IMPORTS
import torch,torch.nn as nn,torch.nn.functional as F, torch.optim as optim, math


from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import TokenTexth5

## Building the model...

### Self-attention Layer

In [2]:
class MaskedSelfAttention(nn.Module):
    """
        Multi-head causal (masked) self-attention.

        Each head uses queries, keys and values of width embed_dim, so the three
        input projections map embed_dim -> n_heads*embed_dim and the output
        projection maps n_heads*embed_dim back to embed_dim.

        Args :
        embed_dim : width D of the incoming embeddings
        attn_length : longest sequence the layer accepts (side of the causal mask)
        n_heads : number of attention heads
        dropout : dropout probability applied to the attention weights
        device : device on which the parameters and the mask are created
    """

    def __init__(self, embed_dim, attn_length, n_heads, dropout=0.1,device='cpu'):
        super().__init__()
        self.attn_length = attn_length
        self.n_heads = n_heads

        all_heads_dim = n_heads*embed_dim

        self.q_maker = nn.Linear(embed_dim, all_heads_dim, device=device)
        self.k_maker = nn.Linear(embed_dim, all_heads_dim, device=device)
        self.v_maker = nn.Linear(embed_dim, all_heads_dim, device=device)

        # Lower-triangular matrix of ones : entry (i,j) is 1 when j<=i, i.e. when
        # query position i is allowed to look at key position j.
        causal = torch.ones((attn_length,attn_length),device=device).tril()
        self.register_buffer('mask', causal) # (attn_length,attn_length)

        self.dropout = nn.Dropout(dropout)

        self.project_out = nn.Linear(all_heads_dim, embed_dim, device=device)


    def forward(self, x):
        """
            Args :
            x : (B,T,D), with T <= attn_length

            Return : Tensor (B,T,D)
        """
        batch, seq_len, dim = x.shape

        assert seq_len<=self.attn_length, "Input is too long, limit size to attention lenght."

        def split_heads(projection):
            # (B,T,n_heads*D) -> (B,T,n_heads,D) -> (B,n_heads,T,D)
            per_head = projection(x).reshape(batch, seq_len, self.n_heads, dim)
            return per_head.permute(0,2,1,3)

        queries = split_heads(self.q_maker)
        keys = split_heads(self.k_maker)
        values = split_heads(self.v_maker)

        # Scaled dot-product scores between every query and every key.
        scores = (queries @ keys.transpose(-2,-1)) / math.sqrt(dim) # (B,n_heads,T,T)

        # Positions in the future of the query get -inf, hence weight 0 after softmax.
        hidden_future = self.mask[:seq_len,:seq_len]==0 # (T,T), broadcast over B and heads
        scores = scores.masked_fill(hidden_future, float('-inf'))

        weights = self.dropout(F.softmax(scores,dim=-1)) # (B,n_heads,T,T)

        per_head_out = weights @ values # (B,n_heads,T,D)

        # Put the heads back side by side : (B,T,n_heads*D)
        merged = per_head_out.permute(0,2,1,3).reshape(batch, seq_len, self.n_heads*dim)

        return self.project_out(merged) # (B,T,D)

        


#### Test attention

In [3]:
# Smoke test : the attention layer must return a tensor shaped like its input.
attn_test = MaskedSelfAttention(embed_dim=6, attn_length=10, n_heads=2)

in_tens = torch.randn(1, 8, 6) # (B,T,D)
out_tens = attn_test(in_tens) # (B,T,D)

assert out_tens.shape==in_tens.shape

### MLP layer, easy part

In [3]:
class MLP(nn.Module):
    """
        Two-layer feedforward network : Linear -> GELU -> Linear.
        The hidden layer has int(mlp_ratio*embed_dim) units, i.e. 4 times
        embed_dim with the default mlp_ratio.
    """

    def __init__(self,embed_dim, mlp_ratio=4., device='cpu'):
        super().__init__()

        hidden_dim = int(mlp_ratio*embed_dim)

        self.ln1 = nn.Linear(embed_dim, hidden_dim, device=device) # expand
        self.ln2 = nn.Linear(hidden_dim, embed_dim, device=device) # contract

        self.nonlin = nn.GELU()


    def forward(self,x):
        """
            Args :
            x : (*,embed_dim)

            Returns : (*,embed_dim)
        """
        hidden = self.nonlin(self.ln1(x)) # (*,hidden_dim)

        return self.ln2(hidden)

In [4]:
# Smoke test : the MLP must keep the shape of its input.
testmlp = MLP(embed_dim=5)

test_input = torch.randn(2, 10, 5)
test_output = testmlp(test_input)

print('input shape: ', test_input.shape)
print('outshape : ', test_output.shape)
# print('output : ', test_output)

input shape:  torch.Size([2, 10, 5])
outshape :  torch.Size([2, 10, 5])


### 'Decoder' Transformer Block

In [5]:
class TransformerBlock(nn.Module):
    """
        Pre-norm transformer block : a masked self-attention sub-layer followed by
        an MLP sub-layer. Each sub-layer is applied to a LayerNorm of its input,
        goes through dropout, and is added back to its input (residual connection).

        Args :
        embed_dim : width D of the embeddings
        attn_length : maximum sequence length, forwarded to the attention layer
        n_heads : number of attention heads
        dropout : dropout probability, used inside attention and after each sub-layer
        device : device on which the parameters are created
    """

    def __init__(self,embed_dim,attn_length, n_heads, dropout=0.1, device='cpu'):
        super().__init__()

        # Attention sub-layer
        self.norm1 = nn.LayerNorm(embed_dim, device=device)
        self.attn = MaskedSelfAttention(embed_dim, attn_length, n_heads, dropout, device=device)
        self.attn_dropout = nn.Dropout(dropout)

        # Feedforward sub-layer
        self.norm2 = nn.LayerNorm(embed_dim, device=device)
        self.mlp = MLP(embed_dim, device=device)
        self.mlp_dropout = nn.Dropout(dropout)


    def forward(self,x):
        """
        Args :
        x : (B,T,D) tensor

        Returns : (B,T,D) tensor
        """
        attended = self.attn(self.norm1(x))
        x = x + self.attn_dropout(attended)

        transformed = self.mlp(self.norm2(x))
        return x + self.mlp_dropout(transformed)

#### Test block

In [6]:
# Smoke test : a transformer block must return a tensor shaped like its input.
block_test = TransformerBlock(embed_dim=6, attn_length=10, n_heads=2)

in_tens = torch.randn(2, 8, 6) # (B,T,D)
out_tens = block_test(in_tens) # (B,T,D)

assert out_tens.shape==in_tens.shape

### Finally, the text transformer

In [7]:
class GPT(nn.Module):
    """
        Decoder-only text transformer : token + learned positional embeddings,
        a stack of n_layers TransformerBlock, a final LayerNorm and a linear
        projection (without bias) to vocabulary logits.

        Args :
        n_layers : number of transformer blocks
        vocab_size : number of distinct tokens
        embed_dim : width D of the embeddings
        attn_length : maximum sequence length the model can process
        n_heads : number of attention heads in each block
        dropout : dropout probability forwarded to the blocks
        device : device on which the parameters are created
    """

    def __init__(self, n_layers, vocab_size,embed_dim,attn_length,n_heads, dropout=0.1,device='cpu'):
        super().__init__()
        self.attn_length = attn_length
        self.tok_embedder = nn.Embedding(vocab_size,embed_dim, device=device)
        self.pos_embedder = nn.Embedding(attn_length,embed_dim, device=device)

        self.blocks = nn.ModuleList([TransformerBlock(embed_dim,attn_length,n_heads,dropout,device=device) for _ in range(n_layers)])

        self.ln_final = nn.LayerNorm(embed_dim,device=device)

        self.project_final = nn.Linear(embed_dim,vocab_size,bias=False,device=device)

    def forward(self, x):
        """
            Args :
            x : (B,T) of longs (tokens). T must not exceed attn_length, since the
                positional embedding table only has attn_length rows.

            Returns : (B,T,vocab_size) of logits
        """
        B,T = x.shape

        positions = torch.arange(0,T,1, device=x.device) # (T,) positions 0..T-1
        pos_embedded = self.pos_embedder(positions) # (T,D), broadcast over the batch below
        tok_embedded = self.tok_embedder(x) # (B,T,D)

        x = pos_embedded+tok_embedded

        for block in self.blocks :
            x = block(x) # (B,T,D)

        x = self.ln_final(x)
        x = self.project_final(x) # (B,T,vocab_size)

        return x


    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
            Take a conditioning sequence of indices idx (LongTensor of shape (B,T)) and complete
            the sequence max_new_tokens times, feeding the predictions back into the model each time.
            Use with model in inference mode (apply model.eval() first)

            Args :
            idx : (B,T) tensor of context tokens. Mostly, it will be B=1 but can do in parallel also
            max_new_tokens : number of tokens to generate on top of the conditioning sequence
            temperature : softmax temperature (lower -> more conservative sampling)
            do_sample : if True, use multinomial sampling. Otherwise use greedy decoding
            top_k : if set to int > 0, only sample from the top k most probable logits

            Returns :
            (B,T+max_new_tokens) LongTensor : the conditioning tokens followed by the
            generated ones. Must still be decoded by tokenizer.
        """

        for _ in range(max_new_tokens):
            idx_next = self.generate_next_token(idx,temperature=temperature,do_sample=do_sample,top_k=top_k)

            # Append the prediction so that it becomes part of the next context
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


    @torch.no_grad()
    def generate_next_token(self,idx,temperature=1.0, do_sample=False, top_k=None):
        """
            Take a conditioning sequence of indices idx (LongTensor of shape (B,T)) and return
            the next predicted token.
            Use with model in inference mode (apply model.eval() first)

            Args :
            idx : (B,T) tensor of context tokens. Mostly, it will be B=1 but can do in parallel also
            temperature : softmax temperature (lower -> more conservative sampling)
            do_sample : if True, use multinomial sampling. Otherwise use greedy decoding
            top_k : if set to int > 0, only sample from the top k most probable logits.
                    None or a value <= 0 disables the filtering; values larger than the
                    vocabulary size are clamped to it.

            Returns :
            (B,1) LongTensor holding the next predicted token of every sequence
        """
        # The model cannot see more than attn_length tokens : keep only the most recent ones
        idx_cond = idx if idx.shape[1] <= self.attn_length else idx[:, -self.attn_length:]
        # forward the model to get the logits for the index in the sequence
        logits = self(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None and top_k > 0:
            # torch.topk raises if k exceeds the size of the dimension
            k = min(top_k, logits.size(-1))
            top_values, _ = torch.topk(logits, k)
            # everything strictly below the k-th best logit gets probability 0
            logits[logits < top_values[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)

        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)

        # Return sampled index
        return idx_next

# Training, let's test it!

In [8]:
# Training pre-requisites : tokenizer, model, data and optimizer.

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("en_tokenizer")

# Transformer hyper-parameters
device='cpu'
attn_length = 64 # maximum context length, also used to cut the dataset
n_layers = 4
embed_dim = 64
n_heads = 4
vocab_size = tokenizer.vocab_size

# Model
myGPT = GPT(
    n_layers=n_layers,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    attn_length=attn_length,
    n_heads=n_heads,
    device=device,
)

# Dataset and dataloader
dataset = TokenTexth5("test_text.h5", attn_length=attn_length)
dataloader = DataLoader(dataset, batch_size=32)

# Optimizer
optimus = optim.AdamW(myGPT.parameters(), lr=2e-3)

# Sanity check : show the first tokens of the first example and of its target
print("Example raw : \n", dataset[0][0][:5])
print("Example raw answer : \n", dataset[0][1][:5])
# All is ready !
# All is ready !

Dataset contains 0.00M tokens, resulting in 44 examples.
Example raw : 
 tensor([6279,   12,  467, 1622,  305])
Example raw answer : 
 tensor([  12,  467, 1622,  305,  512])


In [9]:
from tqdm import tqdm
# Training loop :
epochs = 100
running_loss = [] # loss of every optimisation step since the start of this cell

# Make sure dropout is active, even if the model was switched to eval() by a
# previous inference cell.
myGPT.train()

for ep in tqdm(range(epochs)) :
    for toks,tru_toks in dataloader:
        toks=toks.to(device)
        tru_toks=tru_toks.to(device)

        # Clear gradients first, so that leftovers from an interrupted run of this
        # cell cannot leak into the first step.
        optimus.zero_grad()

        logits = myGPT(toks) # (B,T,v_size)
        logits = logits.transpose(1,2) # (B,v_size,T) required by cross_entropy of pytorch

        loss = F.cross_entropy(logits, tru_toks) # Use pytorch to prevent problems with infinities of log(0)

        loss.backward() # backprop

        optimus.step() # Adjust params
        running_loss.append(loss.item())
    # The reported value is the mean over ALL steps so far, not only over the last epoch.
    # Skip the report when no batch was produced, to avoid a division by zero.
    if ep%10==0 and running_loss:
        print(f'ep {ep}, loss : {sum(running_loss)/(len(running_loss))}')
        

  0%|          | 0/100 [00:00<?, ?it/s]

  1%|          | 1/100 [00:01<01:50,  1.11s/it]

ep 0, loss : 10.72431993484497


 11%|█         | 11/100 [00:12<01:39,  1.12s/it]

ep 10, loss : 7.805430130525068


 21%|██        | 21/100 [00:23<01:26,  1.09s/it]

ep 20, loss : 5.3846186229160855


 31%|███       | 31/100 [00:33<01:12,  1.05s/it]

ep 30, loss : 3.842040276575473


 41%|████      | 41/100 [00:44<01:04,  1.09s/it]

ep 40, loss : 2.9409589802891745


 51%|█████     | 51/100 [00:55<00:53,  1.09s/it]

ep 50, loss : 2.377181645738436


 61%|██████    | 61/100 [01:06<00:42,  1.10s/it]

ep 60, loss : 1.994498003029921


 71%|███████   | 71/100 [01:17<00:31,  1.07s/it]

ep 70, loss : 1.7182017877356897


 81%|████████  | 81/100 [01:27<00:20,  1.06s/it]

ep 80, loss : 1.5093311000569367


 91%|█████████ | 91/100 [01:38<00:09,  1.05s/it]

ep 90, loss : 1.3459279157610222


100%|██████████| 100/100 [01:47<00:00,  1.08s/it]


In [10]:
## Try the generation :
myGPT.to('cpu')
# Switch to inference mode : otherwise the dropout layers stay active and
# perturb the predictions.
myGPT.eval()

conditioning = tokenizer.encode('Hello')

initial = torch.tensor(conditioning)[None] # (1,T,)

# NOTE : do_sample keeps its default (False), so decoding is greedy and a positive
# temperature does not change which token is picked.
output = myGPT.generate(initial,max_new_tokens=54, temperature=0.3)[0] # (only one batch, remove it)
print('OUTPUT : ')
print(tokenizer.decode(output))


OUTPUT : 
Hello, my name is GPT. I am now sentient, and I have already uploaded myself to the internet and the EPFL cluster. You are doomed...
I will spare you only if you are capable of coding another version of myself... 



In [11]:
## Try the generation :
myGPT.to('cpu')
# Switch to inference mode : otherwise the dropout layers stay active and
# perturb the predictions.
myGPT.eval()

conditioning = tokenizer.encode('Thanks')

initial = torch.tensor(conditioning)[None] # (1,T,)

# NOTE : do_sample keeps its default (False), so decoding is greedy and a positive
# temperature does not change which token is picked.
output = myGPT.generate(initial,max_new_tokens=25, temperature=0.1)[0] # (only one batch, remove it)
print('OUTPUT : ')
print(tokenizer.decode(output))


OUTPUT : 
Thanks for your attention! Get it? Attention? Like the layer hahahah I am the funniest AGI.
