<a href="https://colab.research.google.com/github/mgp87/Jupyter_Notebooks_Collection/blob/main/transformerCourse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np

In [None]:
class ScaledDotProdAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dropout=0.1):
        super(ScaledDotProdAttention, self).__init__()

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key`` / ``value``.

        Args:
            query: tensor whose last two dims are ``(query_len, d_k)``.
            key: tensor whose last two dims are ``(key_len, d_k)``.
            value: tensor whose last two dims are ``(key_len, d_v)``.
            mask: optional tensor broadcastable to the score shape
                ``(..., query_len, key_len)``; positions where ``mask == 0``
                (or ``False``) are excluded from attention.

        Returns:
            ``(output, attention)``: ``output`` has last two dims
            ``(query_len, d_v)``; ``attention`` holds the attention weights
            after dropout.
        """
        scale = math.sqrt(query.size(-1))
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # Most negative value representable in the score dtype: a fixed
            # -1e9 does not fit in float16 (max ~6.5e4) and masked_fill fails
            # under mixed precision.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention = F.softmax(attention_scores, dim=-1)

        attention = self.dropout(attention)

        return torch.matmul(attention, value), attention


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaledDotProdAttention.

    Follows the calling convention of ``nn.MultiheadAttention`` so it can sit
    inside layers stacked by ``nn.TransformerEncoder`` / ``nn.TransformerDecoder``:

    * inputs are sequence-first ``(seq, batch, d_model)`` unless
      ``batch_first=True``, in which case they are ``(batch, seq, d_model)``;
    * ``forward`` returns ``(output, attention_weights)``.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dropout: dropout probability for the attention weights and the output.
        batch_first: layout of inputs and outputs (see above).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
        super(MultiHeadAttention, self).__init__()

        if d_model % nhead != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.d_v = d_model // nhead
        # Same attribute name as nn.MultiheadAttention.batch_first.
        # NOTE(review): recent nn.TransformerEncoder/Decoder releases read
        # `layer.self_attn.batch_first`; confirm against the installed torch.
        self.batch_first = batch_first

        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        self.scaled_dot_prod_attention = ScaledDotProdAttention(dropout)

        self.linear_layer = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, key_padding_mask=None):
        """Project, split into heads, attend, and merge heads.

        Args:
            query: ``(L, N, d_model)`` (or ``(N, L, d_model)`` if batch_first).
            key: ``(S, N, d_model)`` (or ``(N, S, d_model)`` if batch_first).
            value: same layout and length as ``key``.
            mask: optional attention mask of shape ``(L, S)`` or
                ``(N * nhead, L, S)``. Boolean / integer masks: non-zero
                (``True``) marks a position that may NOT be attended.
                Floating masks: ``-inf`` marks a blocked position; every other
                value is treated as allowed (finite additive biases are not
                applied).
            key_padding_mask: optional ``(N, S)`` mask of keys to ignore, using
                the same bool / float conventions as ``mask``.

        Returns:
            ``(output, attention_weights)``: ``output`` is in the input layout
            with width ``d_model``; ``attention_weights`` has shape
            ``(N, nhead, L, S)`` (per head, after dropout).
        """
        if not self.batch_first:
            # Work batch-first internally: (seq, N, d) -> (N, seq, d).
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))

        batch_size = query.size(0)

        # Each role gets its own projection, then (N, seq, d_model) ->
        # (N, nhead, seq, head_dim).
        q = self._split_heads(self.linear_q(query), self.d_k)
        k = self._split_heads(self.linear_k(key), self.d_k)
        v = self._split_heads(self.linear_v(value), self.d_v)

        keep = self._keep_mask(mask, key_padding_mask, batch_size)

        output, attention_weights = self.scaled_dot_prod_attention(q, k, v, mask=keep)

        # Merge heads: (N, nhead, L, d_v) -> (N, L, d_model).
        output_concat = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        output_concat = self.dropout(self.linear_layer(output_concat))

        if not self.batch_first:
            output_concat = output_concat.transpose(0, 1)

        return output_concat, attention_weights

    def _split_heads(self, x, head_dim):
        """Reshape ``(N, seq, d_model)`` into ``(N, nhead, seq, head_dim)``."""
        return x.reshape(x.size(0), -1, self.nhead, head_dim).transpose(1, 2)

    def _keep_mask(self, mask, key_padding_mask, batch_size):
        """Merge the optional masks into one bool "may attend" mask, or None.

        The result broadcasts against the ``(N, nhead, L, S)`` score tensor and
        is True where attention is allowed, which is what
        ScaledDotProdAttention expects (it blocks positions where mask == 0).
        """
        keep = None
        if mask is not None:
            keep = ~self._blocked(mask)
            if keep.dim() == 3:
                # (N * nhead, L, S) -> (N, nhead, L, S)
                keep = keep.reshape(batch_size, self.nhead, keep.size(-2), keep.size(-1))
        if key_padding_mask is not None:
            pad_keep = ~self._blocked(key_padding_mask)
            pad_keep = pad_keep.reshape(batch_size, 1, 1, -1)
            keep = pad_keep if keep is None else keep & pad_keep
        return keep

    @staticmethod
    def _blocked(mask):
        """Return a bool tensor that is True where ``mask`` blocks attention."""
        if mask.is_floating_point():
            return torch.isneginf(mask)
        return mask.bool()

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    The table is stored as a ``(max_len, 1, d_model)`` buffer named ``pe``.
    ``forward`` therefore expects ``x`` shaped ``(seq_len, batch, d_model)`` and
    adds row ``t`` of the table to every batch element at position ``t``.

    Args:
        d_model: embedding width (even or odd).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence the table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to fit; otherwise the assignment shape-mismatches.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Raises:
            ValueError: if ``x.size(0)`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        x = x + self.pe[:seq_len, :]
        return self.dropout(x)

In [None]:
class PoswiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Acts on the last dimension only, mapping ``d_model -> d_mlp -> d_model``,
    so every position is transformed independently with shared weights.

    Args:
        d_model: input and output width.
        d_mlp: hidden width.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_mlp=1024, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_mlp, d_model)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.linear_1(x)))
        return self.linear_2(hidden)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable affine terms.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` using the biased
    (population) variance. This matches ``nn.LayerNorm`` / ``F.layer_norm``.
    The previous unbiased ``std`` returned NaN for a last dimension of size 1
    and placed ``eps`` outside the square root.

    Args:
        d_model: size of the normalised (last) dimension.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=-1, keepdim=True)

        x = centered / torch.sqrt(var + self.eps)
        return self.gamma * x + self.beta

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a position-wise feed-forward net.

    Each sub-layer is combined with the running input as
    ``x = x + Dropout(LayerNorm(sublayer(x)))``: the normalisation is applied to
    the sub-layer output before the residual add, not to the sum.

    Its ``forward`` signature (``src_mask``, ``src_key_padding_mask``,
    ``is_causal``) is the one ``nn.TransformerEncoder`` uses when calling each
    stacked layer.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        d_mlp: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model, nhead, d_mlp, dropout=0.1):
        super(EncoderBlock, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)

        self.feed_forward = PoswiseFeedForward(d_model, d_mlp, dropout)

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    @property
    def self_attn(self):
        """Alias for ``multi_head_attention`` under nn.TransformerEncoderLayer's name.

        NOTE(review): recent PyTorch releases of nn.TransformerEncoder read
        ``layers[0].self_attn.batch_first`` to locate the sequence dimension.
        A property (rather than a second submodule attribute) avoids
        registering the module twice in ``state_dict``. Confirm against the
        installed torch version.
        """
        return self.multi_head_attention

    def forward(self, x, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Apply the two residual sub-layers to ``x``.

        ``is_causal`` is accepted for nn.TransformerEncoder compatibility but
        is not used; causality must come from ``src_mask``.
        """
        # NOTE(review): `[0]` assumes the attention module returns an
        # (output, weights) tuple like nn.MultiheadAttention; verify against
        # MultiHeadAttention.forward.
        x2 = self.multi_head_attention(x, x, x, mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        x = x + self.dropout1(self.layer_norm1(x2))

        x2 = self.feed_forward(x)
        x = x + self.dropout2(self.layer_norm2(x2))

        return x



In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is combined with the running target stream as
    ``tgt = tgt + Dropout(LayerNorm(sublayer(...)))``.

    Its ``forward`` signature is the one ``nn.TransformerDecoder`` uses when
    calling each stacked layer.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        d_mlp: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model, nhead, d_mlp, dropout=0.1):
        super(DecoderBlock, self).__init__()

        self.masked_multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)
        self.multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)

        self.feed_forward = PoswiseFeedForward(d_model, d_mlp, dropout)

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    @property
    def self_attn(self):
        """Alias for the self-attention module under nn.TransformerDecoderLayer's name.

        NOTE(review): recent PyTorch releases of nn.TransformerDecoder read
        ``layers[0].self_attn.batch_first``. A property avoids registering the
        module twice in ``state_dict``. Confirm against the installed torch
        version.
        """
        return self.masked_multi_head_attention

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False):
        """Apply the three residual sub-layers to ``tgt`` given encoder ``memory``.

        ``tgt_is_causal`` / ``memory_is_causal`` are accepted because newer
        nn.TransformerDecoder versions pass them to every layer. They are not
        used here; causality must come from ``tgt_mask``.
        """
        # NOTE(review): `[0]` assumes the attention module returns an
        # (output, weights) tuple like nn.MultiheadAttention; verify against
        # MultiHeadAttention.forward.
        tgt2 = self.masked_multi_head_attention(tgt, tgt, tgt, mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(self.layer_norm1(tgt2))

        # Cross-attention queries come from the running target stream `tgt`
        # (residual included). Previously the self-attention sub-layer output
        # `tgt2` alone was used, which discarded the residual path.
        tgt2 = self.multi_head_attention(tgt, memory, memory, mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(self.layer_norm2(tgt2))

        tgt2 = self.feed_forward(tgt)
        tgt = tgt + self.dropout3(self.layer_norm3(tgt2))

        return tgt



In [None]:
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer mapping token ids to next-token logits.

    Token tensors are expected sequence-first, ``(seq_len, batch)``:
    PositionalEncoding adds its table along dim 0 of the embedded input.

    Args:
        d_model: embedding / model width.
        nhead: attention heads per layer.
        n_encoder: number of stacked EncoderBlocks.
        n_decoder: number of stacked DecoderBlocks.
        d_mlp: hidden width of the feed-forward sub-layers.
        max_len: longest sequence the positional encoding covers.
        vocab_size: number of token ids (input and output share it).
        pad_idx: ``padding_idx`` for both embeddings (``None`` for no padding).
        dropout: dropout probability used throughout.
    """

    def __init__(self, d_model, nhead, n_encoder, n_decoder, d_mlp, max_len, vocab_size, pad_idx, dropout=0.1):
        super(TransformerModel, self).__init__()

        self.d_model = d_model

        # Encoder
        encoder_layer = EncoderBlock(d_model, nhead, d_mlp, dropout)
        encoder_norm = LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, n_encoder, encoder_norm)

        # Decoder
        decoder_layer = DecoderBlock(d_model, nhead, d_mlp, dropout)
        decoder_norm = LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, n_decoder, decoder_norm)

        # Positional encoding (shared by encoder and decoder inputs)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)

        # Embedding layers for input and output
        self.input_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.output_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)

        # Final linear layer
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, src, output, src_mask=None, tgt_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, is_causal=False,
                memory_mask=None):
        """Encode ``src``, decode ``output`` against it, and project to logits.

        Args:
            src: source token ids, ``(src_len, batch)``.
            output: decoder input token ids, ``(tgt_len, batch)``.
            src_mask: encoder self-attention mask, ``(src_len, src_len)``.
            tgt_mask: decoder self-attention mask, ``(tgt_len, tgt_len)``
                (e.g. a causal mask).
            src_key_padding_mask: ``(batch, src_len)`` padding mask for the encoder.
            tgt_key_padding_mask: ``(batch, tgt_len)`` padding mask for the decoder.
            memory_key_padding_mask: ``(batch, src_len)`` padding mask applied
                to the encoder output in cross-attention.
            is_causal: accepted for compatibility; not used. Pass an explicit
                ``tgt_mask`` instead.
            memory_mask: cross-attention mask, ``(tgt_len, src_len)``.

        Returns:
            Logits of shape ``(tgt_len, batch, vocab_size)``.
        """
        # Standard embedding scaling by sqrt(d_model) before adding positions.
        scale = math.sqrt(self.d_model)

        src = self.pos_encoder(self.input_embedding(src) * scale)

        encoder_outputs = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        tgt = self.pos_encoder(self.output_embedding(output) * scale)

        # Cross-attention gets its own (tgt_len, src_len) mask. The encoder's
        # (src_len, src_len) src_mask does not fit it and used to be reused here.
        decoder_outputs = self.decoder(
            tgt,
            encoder_outputs,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

        return self.linear(decoder_outputs)

In [None]:
d_model = 512
nhead = 1
n_encoder_layers = 1
n_decoder_layers = 1
d_mlp = 1024
max_len = 6
vocab_size = len("abcdefghijklmnop")
# ReverseDataset maps every index 0..15 to a real letter and all sequences
# have the same length, so there is no padding token. pad_idx = 0 would make
# nn.Embedding treat 'a' (index 0) as padding: its embedding row would stay
# at zero and never receive gradient.
pad_idx = None
dropout = 0.1

In [None]:
# Build the model from the hyper-parameters above. The training-setup cell
# later rebuilds `model` and moves it to the selected device.
model = TransformerModel(
    d_model=d_model,
    nhead=nhead,
    n_encoder=n_encoder_layers,
    n_decoder=n_decoder_layers,
    d_mlp=d_mlp,
    max_len=max_len,
    vocab_size=vocab_size,
    pad_idx=pad_idx,
    dropout=dropout,
)

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
class ReverseDataset(Dataset):
    """Synthetic pairs ``(sequence, reversed sequence)`` over the letters a-p.

    Every ``__getitem__`` call draws a fresh random sequence with
    ``torch.randint``, ignoring ``idx``. ``length`` only sets how many items one
    pass over the dataset yields.

    Args:
        length: value reported by ``len()``.
        seq_len: number of tokens per sequence.
    """

    def __init__(self, length=10000, seq_len=10):
        self.length = length
        self.seq_len = seq_len

        self.vocab = list("abcdefghijklmnop")
        self.vocab_size = len(self.vocab)
        # Token id <-> character lookup tables.
        self.idx_to_char = dict(enumerate(self.vocab))
        self.char_to_idx = {ch: i for i, ch in self.idx_to_char.items()}

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        src = torch.randint(high=self.vocab_size, size=(self.seq_len,))
        reversed_src = src.flip(0)
        return src, reversed_src


In [None]:
# Random sequences of length max_len, served in shuffled mini-batches of 5.
dataset = ReverseDataset(seq_len=max_len)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)

In [None]:
# Use the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fresh model for training, placed on the selected device.
model = TransformerModel(
    d_model, nhead, n_encoder_layers, n_decoder_layers,
    d_mlp, max_len, vocab_size, pad_idx, dropout,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def tokens_to_text(tokens, dataset):
    """Turn a sequence of token ids into a string via ``dataset.idx_to_char``.

    Args:
        tokens: iterable of 0-d tensors (e.g. a 1-D LongTensor); each element
            is converted with ``.item()`` before the lookup.
        dataset: object exposing an ``idx_to_char`` dict.

    Returns:
        The looked-up characters concatenated into one ``str``.
    """
    characters = [dataset.idx_to_char[tok.item()] for tok in tokens]
    return "".join(characters)

In [None]:
# Draw one shuffled batch and show sample #4 next to its reversed target.
inputs, targets = next(iter(dataloader))
print("Input: ", tokens_to_text(inputs[4], dataset))
target_text = tokens_to_text(targets[4], dataset)
print("Target: ", target_text)


Input:  kmkbpb
Target:  bpbkmk


In [None]:
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    for i, (src_batch, tgt_batch) in enumerate(dataloader):

        # The DataLoader yields (batch, seq) token ids; the model is
        # sequence-first, so transpose to (seq, batch).
        src_batch = src_batch.T.to(device)
        tgt_batch = tgt_batch.T.to(device)

        # Teacher forcing with a one-step shift: the decoder reads positions
        # 0..T-2 and is scored on positions 1..T-1. Feeding target_real itself
        # would let the decoder copy the tokens it is graded on. There is no
        # BOS token, so the first target token is never predicted.
        target_input = tgt_batch[:-1, :]
        target_real = tgt_batch[1:, :]

        # Causal mask (bool, True = blocked, the nn.MultiheadAttention
        # convention) so decoder position t cannot look at later inputs.
        # NOTE(review): only effective if the decoder attention applies tgt_mask.
        tgt_len = target_input.size(0)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=device), diagonal=1
        )

        output = model(src_batch, target_input, tgt_mask=tgt_mask)

        loss = criterion(output.reshape(-1, vocab_size), target_real.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")


Epoch: 0, Iteration: 0, Loss: 2.484734296798706
Epoch: 0, Iteration: 100, Loss: 0.4349328875541687
Epoch: 0, Iteration: 200, Loss: 0.14970877766609192
Epoch: 0, Iteration: 300, Loss: 0.06762726604938507
Epoch: 0, Iteration: 400, Loss: 0.03968893364071846
Epoch: 0, Iteration: 500, Loss: 0.029801957309246063
Epoch: 0, Iteration: 600, Loss: 0.022669054567813873
Epoch: 0, Iteration: 700, Loss: 0.01714584417641163


KeyboardInterrupt: ignored

In [None]:
def output_to_text(output, dataset):
    """Greedy-decode per-position scores into text.

    Args:
        output: tensor whose last dimension holds vocabulary scores (logits
            or probabilities), e.g. ``(seq_len, vocab_size)``; one character
            is produced per index of the remaining dimension.
        dataset: object exposing an ``idx_to_char`` dict.

    Returns:
        The characters of the highest-scoring token at each position.
    """
    # softmax is monotonic, so argmax over raw scores picks the same token;
    # the previous softmax pass was redundant work.
    token_ids = torch.argmax(output, dim=-1)

    return ''.join(dataset.idx_to_char[token.item()] for token in token_ids)

In [None]:
inputs, targets = next(iter(dataloader))

index = 1

print("Input: ", tokens_to_text(inputs[index], dataset))
print("Target: ", tokens_to_text(targets[index], dataset))

# inputs[index] / targets[index] are 1-D (seq_len,) token tensors. The former
# `.T` was a no-op on them and is deprecated for non-2-D tensors.
input = inputs[index].to(device)
target = targets[index].to(device)
print(target)

Input:  hneobc
Target:  cboenh
tensor([ 2,  1, 14,  4, 13,  7], device='cuda:0')


In [None]:
# Greedy decoding for every sample in the batch drawn above. Training never
# predicts the first target token (there is no BOS token), so decoding is
# seeded with it. The remaining tokens are generated one at a time, and the
# reversed sequence has the same length as its source.
was_training = model.training
model.eval()
with torch.no_grad():
    src = inputs.T.to(device)            # (seq, batch)
    decoded = targets.T[:1].to(device)   # (1, batch): seed token
    # Per-position scores. The seed is stored as a one-hot row so argmax
    # reproduces it when the scores are turned back into text.
    step_scores = [F.one_hot(decoded[0], vocab_size).float()]
    for _ in range(src.size(0) - 1):
        cur_len = decoded.size(0)
        causal_mask = torch.triu(
            torch.ones(cur_len, cur_len, dtype=torch.bool, device=device), diagonal=1
        )
        next_scores = model(src, decoded, tgt_mask=causal_mask)[-1]  # (batch, vocab)
        step_scores.append(next_scores)
        decoded = torch.cat([decoded, next_scores.argmax(dim=-1).unsqueeze(0)], dim=0)
model.train(was_training)

# (batch, seq, vocab): output[i] holds the per-position scores for sample i.
output = torch.stack(step_scores, dim=1)

In [None]:
# Show the source sequence and the model's prediction for sample `index`.
print("Input: ", tokens_to_text(inputs[index], dataset))
predicted_text = output_to_text(output[index], dataset)
print("Prediction: ", predicted_text)

Input:  hneobc
Prediction:  cboenh
