<a href="https://colab.research.google.com/github/hyngon90/StatQuestTutorial/blob/main/06_encoder_decoder_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lightning
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L


class PositionEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to word embeddings.

    A ``(max_len, d_model)`` table is precomputed once and stored as the
    non-trainable buffer ``pe``. Even columns hold ``sin(pos / 10000**(i / d_model))``
    and odd columns hold the matching ``cos`` value, where ``i`` takes the
    values ``0, 2, 4, ...``.

    Args:
        d_model: Number of values per token embedding (table width).
        max_len: Largest number of positions the table covers (table height).
    """

    def __init__(self, d_model, max_len):
        super().__init__()

        pe = torch.zeros(max_len, d_model)

        # Column vector of positions so that it broadcasts against `div`.
        pos = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        idx = torch.arange(start=0, end=d_model, step=2).float()

        div = 1 / torch.tensor(10000.0) ** (idx / d_model)
        angles = pos * div

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angles to the number of cosine slots actually available.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across devices and into the state dict
        # without being treated as a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Return ``word_embeddings`` with the position table added row by row.

        The first dimension of ``word_embeddings`` is treated as the position
        index: row ``i`` receives row ``i`` of the table.

        Raises:
            ValueError: If the input has more rows than the table has positions.
        """
        seq_len = word_embeddings.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )

        return word_embeddings + self.pe[:seq_len, :]
class Attention(nn.Module):
    """Single-head scaled dot-product attention over 2-D (row = token) inputs.

    NOTE: a second class named ``Attention`` is defined further down in this
    file; being defined later, that one is what the rest of the file ends up
    using. This definition is kept consistent with it.

    Args:
        d_model: Number of values per token; also the width of the query,
            key and value projections.
    """

    def __init__(self, d_model):
        super().__init__()

        # Inputs are indexed as (token, feature).
        self.row_dim = 0
        self.col_dim = 1

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Return the attention output for the given query/key/value encodings.

        Args:
            encodings_for_q: Encodings projected into queries.
            encodings_for_k: Encodings projected into keys.
            encodings_for_v: Encodings projected into values.
            mask: Optional boolean tensor; positions that are ``True`` are
                filled with ``-1e9`` before the softmax so they receive
                (practically) zero attention.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # These two steps must run whether or not a mask was supplied; when
        # they sat inside the `if`, unmasked calls failed at the return.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_score = torch.matmul(attention_percents, v)

        return attention_score
class Attention(nn.Module):
    """Single-head scaled dot-product attention for 2-D (token, feature) inputs.

    Queries, keys and values are produced by three bias-free linear layers of
    width ``d_model``.
    """

    def __init__(self, d_model):
        super().__init__()

        # Axis bookkeeping: dimension 0 indexes tokens, dimension 1 features.
        self.row_dim, self.col_dim = 0, 1

        # Projections are created in q, k, v order.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Attend from the query encodings over the key/value encodings.

        When ``mask`` is given, entries where it is ``True`` are replaced by
        ``-1e9`` ahead of the softmax, which drives their weight to ~0.
        """
        queries = self.W_q(encodings_for_q)
        keys = self.W_k(encodings_for_k)
        values = self.W_v(encodings_for_v)

        # Similarity of every query with every key, scaled by sqrt(key width).
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = torch.matmul(queries, keys.transpose(self.row_dim, self.col_dim)) / scale

        if mask is not None:
            scores = scores.masked_fill(mask=mask, value=-1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return torch.matmul(weights, values)

class Encoder(nn.Module):
    """Embed source tokens, add position encodings and apply self-attention.

    Args:
        num_encoder_token: Size of the embedding table (number of token ids).
        d_model: Number of values per token embedding.
        max_len: Number of positions covered by the position encoding.
    """

    def __init__(self, num_encoder_token, d_model, max_len):
        super().__init__()

        self.we = nn.Embedding(num_encoder_token, d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)

    def forward(self, encoder_token_ids):
        """Return the encoded representation of ``encoder_token_ids``."""
        encoded = self.pe(self.we(encoder_token_ids))

        # Every token attends over the whole input (no mask), and the result
        # is added back onto its own encoding as a residual connection.
        attended = self.self_attention(encoded, encoded, encoded)
        return encoded + attended

class Decoder(nn.Module):
    """Decode target tokens with masked self-attention and encoder-decoder attention.

    Args:
        num_decoder_token: Size of the embedding table and of the output layer
            (one logit per token id).
        d_model: Number of values per token embedding.
        max_len: Number of positions covered by the position encoding.
    """

    def __init__(self, num_decoder_token, d_model, max_len):
        super().__init__()

        self.we = nn.Embedding(num_embeddings=num_decoder_token, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.masked_self_attention = Attention(d_model=d_model)
        self.encoder_decoder_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_decoder_token)

    def forward(self, decoder_token_ids, encoder_k, encoder_v):
        """Return one row of logits per decoder token.

        Args:
            decoder_token_ids: 1-D tensor of token ids fed to the decoder.
            encoder_k: Encoder output used to build the keys.
            encoder_v: Encoder output used to build the values.
        """
        word_embeddings = self.we(decoder_token_ids)
        position_encoded = self.pe(word_embeddings)

        # Causal mask: True strictly above the diagonal, so a token cannot
        # attend to tokens that come after it. It is built on the same device
        # as the inputs so that masked_fill does not hit a device mismatch.
        seq_len = decoder_token_ids.size(dim=0)
        mask = torch.tril(
            torch.ones(seq_len, seq_len, device=decoder_token_ids.device)
        ).logical_not()

        masked_self_attention = self.masked_self_attention(
            position_encoded, position_encoded, position_encoded, mask=mask
        )
        # Queries come from the decoder; keys and values from the encoder.
        encoder_decoder_attention = self.encoder_decoder_attention(
            masked_self_attention, encoder_k, encoder_v
        )
        residual_connection_values = masked_self_attention + encoder_decoder_attention
        fc_layer_output = self.fc_layer(residual_connection_values)

        return fc_layer_output

class Encoder_Decoder_Transformer(L.LightningModule):
    """Encoder-decoder transformer operating on one flat token sequence.

    The input sequence holds the source tokens, then the end-of-sequence
    token, then any target tokens produced so far. Everything before the
    first EOS goes to the encoder; the EOS itself and everything after it go
    to the decoder.

    Args:
        num_token: Vocabulary size shared by encoder and decoder.
        d_model: Number of values per token embedding.
        max_len: Number of positions covered by the position encodings.
        eos_token_id: Id of the token that separates source from target.
            Defaults to 3, the value that used to be hard-coded.
    """

    def __init__(self, num_token, d_model, max_len, eos_token_id=3):
        super().__init__()
        self.learning_rate = 0.1
        self.eos_token_id = eos_token_id

        self.encoder = Encoder(num_token, d_model, max_len)
        self.decoder = Decoder(num_token, d_model, max_len)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Return decoder logits for a 1-D tensor of token ids.

        Raises:
            ValueError: If ``token_ids`` does not contain the EOS token.
        """
        eos_positions = (token_ids == self.eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() == 0:
            raise ValueError(
                f"token_ids must contain the EOS token id {self.eos_token_id}"
            )
        # Split at the first EOS; a later occurrence (e.g. one produced during
        # generation) belongs to the decoder side.
        split_idx = int(eos_positions[0])

        encoder_token_ids = token_ids[:split_idx]
        decoder_token_ids = token_ids[split_idx:]

        # The same encoder output serves as both keys and values.
        encoder_k = self.encoder(encoder_token_ids)
        encoder_v = encoder_k
        decoded_values = self.decoder(decoder_token_ids, encoder_k, encoder_v)

        return decoded_values

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        return Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for the first sample of ``batch``.

        Only index 0 of the inputs and labels is used, i.e. one sequence per
        step.
        """
        token_ids, labels = batch
        output = self(token_ids[0])
        loss = self.loss(output, labels[0])

        return loss

'''
Translator: English to Spanish

let's go
vamos

let's go <EOS> vamos <EOS>

'''
# Source-language (English) vocabulary. "<EOS>" must keep id 3: it is the
# token the model uses to split a sequence into encoder and decoder parts.
encoder_token_to_id = {
    "let's": 0,
    "to": 1,
    "go": 2,
    "<EOS>": 3,
}
encoder_id_to_token = {token_id: token for token, token_id in encoder_token_to_id.items()}

# Target-language (Spanish) vocabulary. Ids continue after the encoder ids so
# that both vocabularies can share one embedding table.
decoder_token_to_id = {
    "ir": 4,
    "vamos": 5,
    "y": 6,
    "<EOS>": 7,
}
decoder_id_to_token = {token_id: token for token, token_id in decoder_token_to_id.items()}

# One training example: source tokens, the separating <EOS>, then the target
# tokens fed to the decoder.
inputs = torch.tensor([
    [
        encoder_token_to_id["let's"],
        encoder_token_to_id["go"],
        encoder_token_to_id["<EOS>"],
        decoder_token_to_id["vamos"],
    ]
])
# Expected decoder output: one label per decoder input position.
labels = torch.tensor([
    [
        decoder_token_to_id["vamos"],
        decoder_token_to_id["<EOS>"],
    ]
])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

num_encoder_token = len(encoder_token_to_id)
num_decoder_token = len(decoder_token_to_id)
num_token = num_encoder_token + num_decoder_token
d_model = 2
max_len = 5

model = Encoder_Decoder_Transformer(num_token, d_model, max_len)
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

# Greedy decoding: start from the source sentence plus <EOS> and append the
# most likely next token until the decoder emits its own <EOS> or the
# sequence reaches max_len.
input_token_ids = torch.tensor([
    encoder_token_to_id["let's"],
    encoder_token_to_id["go"],
    encoder_token_to_id["<EOS>"],
])

predicted_id = None
# Integer dtype so the collected ids stay integers when concatenated and can
# be used directly as dictionary keys.
predicted_ids = torch.tensor([], dtype=torch.long)

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    while (predicted_id != decoder_token_to_id["<EOS>"]) and (len(input_token_ids) < max_len):
        predictions = model(input_token_ids)
        # The last row holds the logits for the next token.
        predicted_id = int(torch.argmax(predictions[-1, :]))
        next_token = torch.tensor([predicted_id])
        predicted_ids = torch.cat((predicted_ids, next_token))
        input_token_ids = torch.cat((input_token_ids, next_token))

print("Predicted Tokens: ", end='')
for token_id in predicted_ids:
    print(decoder_id_to_token[token_id.item()], "", end='')

Collecting lightning
  Downloading lightning-2.5.0.post0-py3-none-any.whl.metadata (40 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.5.0.post0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Downloading torchmetrics-1.6.1-py3-none-any

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | Encoder          | 28     | train
1 | decoder | Decoder          | 64     | train
2 | loss    | CrossEntropyLoss | 0      | train
-----------------------------------------------------
92        Trainable params
0         Non-trainable params
92        Total params
0.000     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------

Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


Predicted Tokens: vamos <EOS> 