
# [AI 504] Programming for AI, Fall 2021
# Practice 10: Transformers
-----

#### [Notifications]
- If you have any questions, feel free to ask.
- For additional questions, send an email to yeonsu.k@kaist.ac.kr
      

     
     
# Table of contents
1. [Prepare input](#1)
2. [Implement Transformer](#2)
3. [Train and Evaluate](#3)
4. [Visualize attention](#4)



# Prepare essential packages

In [None]:
%matplotlib inline
!git clone https://github.com/sjpark9503/attentionviz.git
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm


# I. Prepare input
<a id='1'></a>

In [None]:
!git clone --recursive https://github.com/multi30k/dataset.git multi30k-datase



In [None]:
!find multi30k-datase/ -name '*.gz' -exec gunzip {} \;

We covered how to preprocess text data in previous lectures.

For a more detailed explanation of translation datasets, see [torchtext](https://pytorch.org/text/), the [practice session, week 9](https://classum.com/main/course/7726/103), and the [PyTorch NMT tutorial](https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html).

In [None]:
import spacy
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import os
import io
from torch.nn.utils.rnn import pad_sequence

# Load spaCy models for tokenization
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

# Tokenization function for German and English
def tokenize_de(text):
    return [tok.text.lower() for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

# Define the custom Dataset class
class TranslationDataset(Dataset):
    def __init__(self, root_dir, split):
        self.root_dir = root_dir
        self.split = split

        self.data_files = {
            'train': ('train.de', 'train.en'),
            'valid': ('val.de', 'val.en'),
            'test': ('test_2016_flickr.de', 'test_2016_flickr.en')
        }

        self.de_file_path = os.path.join(self.root_dir, self.data_files[self.split][0])
        self.en_file_path = os.path.join(self.root_dir, self.data_files[self.split][1])

        with io.open(self.de_file_path, mode='r', encoding='utf-8') as de_file, \
             io.open(self.en_file_path, mode='r', encoding='utf-8') as en_file:
            self.de_sentences = de_file.readlines()
            self.en_sentences = en_file.readlines()

    def __len__(self):
        return len(self.de_sentences)

    def __getitem__(self, idx):
        de_sentence = tokenize_de(self.de_sentences[idx].strip())
        en_sentence = tokenize_en(self.en_sentences[idx].strip())
        return {'SRC': de_sentence, 'TRG': en_sentence}

# Define the Vocab class
class Vocab:
    def __init__(self, counter, min_freq):
        self.itos = ['<pad>', '<sos>', '<eos>', '<unk>']
        self.stoi = {token: i for i, token in enumerate(self.itos)}
        self.min_freq = min_freq
        self.build_vocab(counter)

    def build_vocab(self, counter):
        for word, freq in counter.items():
            if freq >= self.min_freq and word not in self.stoi:
                self.stoi[word] = len(self.itos)
                self.itos.append(word)

    def numericalize(self, tokens):
        return [self.stoi.get(token, self.stoi['<unk>']) for token in tokens]

# Build separate word counters for the source (German) and target (English) sides
def build_counters(dataset):
    src_counter, trg_counter = Counter(), Counter()
    for i in range(len(dataset)):
        example = dataset[i]
        src_counter.update(example['SRC'])
        trg_counter.update(example['TRG'])
    return src_counter, trg_counter

# Create the datasets
train_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='train')
valid_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='valid')
test_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='test')

# Build the vocabularies from the training split only, one per language
src_counter, trg_counter = build_counters(train_data)
SRC_vocab = Vocab(src_counter, min_freq=2)
TRG_vocab = Vocab(trg_counter, min_freq=2)

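# Collate function: numericalize, wrap each sentence in <sos> ... <eos>, and pad the batch.
# The shifted loss in train()/evaluate() (output[:, :-1] vs. trg[:, 1:]) assumes targets start with <sos>.
def collate_fn(batch):
    def encode(vocab, tokens):
        return torch.tensor([vocab.stoi['<sos>']] + vocab.numericalize(tokens) + [vocab.stoi['<eos>']])

    src_batch = [encode(SRC_vocab, item['SRC']) for item in batch]
    trg_batch = [encode(TRG_vocab, item['TRG']) for item in batch]

    # (batch_size, max_seq_len) tensors, right-padded with <pad>
    src_batch_padded = pad_sequence(src_batch, padding_value=SRC_vocab.stoi['<pad>'], batch_first=True)
    trg_batch_padded = pad_sequence(trg_batch, padding_value=TRG_vocab.stoi['<pad>'], batch_first=True)

    return {'SRC': src_batch_padded, 'TRG': trg_batch_padded}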


# Device used for the model and the batches
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define batch size
BATCH_SIZE = 128

# Shuffle only the training data; evaluation order does not affect the loss
train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_iterator = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
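
The batching step can be illustrated with a small self-contained sketch. The vocabulary `toy_itos` below is hypothetical (it only reuses the special-token order of the `Vocab` class), and wrapping each sentence in `<sos>`/`<eos>` is an assumption matching the shifted-target loss used later (`output[:, :-1]` vs. `trg[:, 1:]`). `pad_sequence(..., batch_first=True)` right-pads every sentence to the longest one in the batch, giving a `(batch_size, max_len)` tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical toy vocabulary; special tokens in the same order as the Vocab class
toy_itos = ['<pad>', '<sos>', '<eos>', '<unk>', 'a', 'dog', 'runs', 'fast']
toy_stoi = {tok: i for i, tok in enumerate(toy_itos)}

def encode(tokens):
    # Map tokens to indices (unknown words -> <unk>) and add sentence boundaries
    ids = [toy_stoi.get(tok, toy_stoi['<unk>']) for tok in tokens]
    return torch.tensor([toy_stoi['<sos>']] + ids + [toy_stoi['<eos>']])

batch = [['a', 'dog', 'runs'], ['a', 'cat']]   # 'cat' is out of vocabulary
padded = pad_sequence([encode(s) for s in batch],
                      batch_first=True, padding_value=toy_stoi['<pad>'])
print(padded)
# tensor([[1, 4, 5, 6, 2],
#         [1, 4, 3, 2, 0]])
```

The shorter sentence ends with index 0 (`<pad>`), which is exactly what the padding masks and `ignore_index` in the loss later filter out.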

# II. Implement Transformer
<a id='2'></a>
In Practice 10, we implement the Transformer model from __[Attention Is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) (Vaswani et al., 2017)__.

The overall architecture is as follows:
<div>
<img src="http://incredible.ai/assets/images/transformer-architecture.png" width="400"/>
</div>


## 1. Basic building blocks

In this section, we build the basic blocks of the Transformer: [Multi-head attention](#1a), [Position-wise feed-forward network](#1b) and [Positional encoding](#1c).

### a. Attention
<a id='1a'></a>
In this section, you will implement scaled dot-product attention and multi-head attention.

__Scaled dot product:__
![picture](http://incredible.ai/assets/images/transformer-scaled-dot-product.png)
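
The figure computes $\text{Attention}(Q, K, V) = \text{softmax}\left(QK^\top / \sqrt{d_k}\right) V$. A minimal PyTorch sketch of that formula (the function name `attention` and the random inputs are illustrative, not part of this notebook) checks the output shape and that every row of attention weights is a probability distribution over the keys.

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # q: (..., len_q, d_k), k: (..., len_k, d_k), v: (..., len_k, d_v)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # (..., len_q, len_k)
    weights = F.softmax(scores, dim=-1)                  # each row sums to 1
    return weights @ v, weights                          # (..., len_q, d_v)

torch.manual_seed(0)
q = torch.randn(2, 3, 8)   # batch of 2, 3 queries, d_k = 8
k = torch.randn(2, 5, 8)   # 5 keys
v = torch.randn(2, 5, 4)   # d_v = 4
out, w = attention(q, k, v)
print(out.shape, w.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3, 5])
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the vector size, which would otherwise push the softmax toward a one-hot distribution.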



<div>

__Multi-head attention:__
<img src="http://jalammar.github.io/images/t/transformer_multi-headed_self-attention-recap.png" width="650"/>
* Equation:
$$\begin{align} \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \\
\text{where } \text{head}_i &= \text{Attention} \left( Q W^Q_i, K W^K_i, V W^V_i \right)
\end{align}$$
</div>

<div>

__Query, Key and Value projection:__
<img src="http://jalammar.github.io/images/t/self-attention-matrix-calculation.png" width="400"/>
</div>
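
After the projections, multi-head attention splits the `emb_dim` features into `num_heads` chunks of size `head_dim = emb_dim // num_heads`, runs attention per head, and concatenates the heads again. The sketch below (the helpers `split_heads` and `merge_heads` are hypothetical names) performs the same `view`/`permute` reshapes as `MultiHeadAttention` further down, and checks that merging undoes splitting and that head 1 sees feature slice 16:32 when `emb_dim = 64` and `num_heads = 4`.

```python
import torch

def split_heads(x, num_heads):
    # (batch, seq_len, emb_dim) -> (batch, num_heads, seq_len, head_dim)
    b, t, d = x.shape
    return x.view(b, t, num_heads, d // num_heads).permute(0, 2, 1, 3)

def merge_heads(x):
    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, emb_dim), i.e. Concat(head_1, ..., head_h)
    b, h, t, hd = x.shape
    return x.permute(0, 2, 1, 3).contiguous().view(b, t, h * hd)

x = torch.randn(128, 27, 64)          # sizes taken from the shape comments in this notebook
heads = split_heads(x, num_heads=4)
print(heads.shape)                    # torch.Size([128, 4, 27, 16])
```

Because each head only reads its own slice of the projected features, the $h$ heads together cost about the same as one full-width attention.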


In [None]:
'''
  Purpose of self-attention: to measure how strongly the words of the input sentence relate to each other
  Q, K, V: (projections of) all word vectors of the input sentence
'''

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class MultiHeadAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        emb_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        encoder_decoder_attention=False,  # otherwise self_attention
        causal = False
    ):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = emb_dim // num_heads
        assert self.head_dim * num_heads == self.emb_dim, "emb_dim must be divisible by num_heads"

        self.encoder_decoder_attention = encoder_decoder_attention
        self.causal = causal
        self.q_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.k_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.v_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_heads,
            self.head_dim,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
        # This is equivalent to
        # return x.transpose(1,2)

    # Single-head reference version (not called by forward below).
    # Each query vector computes attention scores and a weight distribution over all key vectors,
    # then takes the weighted sum of all value vectors to obtain the attention value (context vector).
    # Expects query/key/value of shape (batch_size, seq_len, emb_dim) and a (batch_size, seq_len) padding mask.
    def scaled_dot_product(self,
                           query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                           attention_mask: torch.BoolTensor):
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.emb_dim) # QK^T/sqrt(d)

        # Padding mask: <pad> keys must be excluded from the similarity computation, so the masked positions
        # of the attention score matrix are filled with -inf, which softmax turns into a weight of 0.
        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1), float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)    # softmax(QK^T/sqrt(d))
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_probs, value)     # softmax(QK^T/sqrt(d))V: (batch_size, seq_len, emb_dim)

        return attn_output, attn_probs

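    # Multi-head version: attention is computed num_heads times in parallel, so that the model can gather
    # information from different representation subspaces. The per-head outputs are concatenated and
    # multiplied by the output weight matrix (out_proj), giving (batch_size, seq_len, d_model).
    # Attention scores: dot product of Q and K, scaled by sqrt(head_dim)
    # query, key, value: (batch_size, num_heads, seq_len, d_model/num_heads)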
    def MultiHead_scaled_dot_product(self,
                       query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                       attention_mask: torch.BoolTensor):
        # attn_weights : (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim) # QK^T/sqrt(d)


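        # Attention mask (True = masked position)
        if attention_mask is not None:
            if self.causal:
                # causal mask: (seq_len, seq_len) -> (1, 1, seq_len, seq_len), shared by all batches and heads
                attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(0).unsqueeze(1), float("-inf"))
            else:
                # padding mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len), masks key positions
                attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2), float("-inf"))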


        attn_weights = F.softmax(attn_weights, dim=-1)               # softmax(QK^T/sqrt(d))
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)

        # softmax(QK^T/sqrt(d))V → Attention Value
        attn_output = torch.matmul(attn_probs, value)                          # (batch_size, num_heads, seq_len, (d_model/num_heads)) : (128, 4, 27, 16)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()             # (batch_size, seq_len, num_heads, (d_model/num_heads)) : (128, 27, 4, 16)
        concat_attn_output_shape = attn_output.size()[:-2] + (self.emb_dim,)   #
        attn_output = attn_output.view(*concat_attn_output_shape)              # (batch_size, seq_len, emb_dim) : (128, 27, 64)
        attn_output = self.out_proj(attn_output)                               # (batch_size, seq_len, emb_dim) : (128, 27, 64)

        return attn_output, attn_weights

    def forward(self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None, ):
        # Project each token vector (d_model = 512 in the paper, emb_dim = 64 in this notebook) with the
        # d_model x d_model weight matrices q_proj, k_proj, v_proj; the results are then split into num_heads heads.
        q = self.q_proj(query)          # query : (batch_size, seq_len, d_model)

        # Enc-Dec attention
        if self.encoder_decoder_attention:
            k = self.k_proj(key)
            v = self.v_proj(key)
        # Self attention
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self.transpose_for_scores(q)  # q : (batch_size, num_heads, seq_len, (d_model/num_heads))
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        # q : (batch_size, num_heads, seq_len, (d_model/num_heads))
        attn_output, attn_weights = self.MultiHead_scaled_dot_product(q,k,v,attention_mask)
        return attn_output, attn_weights


### b. Position-wise feed-forward network
<a id='1b'></a>
In this section, we will implement position-wise feed forward network

$$\text{FFN}(x) = \max \left(0, x W_1 + b_1 \right) W_2 + b_2$$
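
"Position-wise" means the same two linear layers are applied to every position independently, because `nn.Linear` only acts on the last dimension. A small sketch (arbitrary sizes; dropout and the residual connection used in the class below are left out) compares running the FFN on the whole sequence with running it one position at a time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb_dim, d_ff = 8, 32
# max(0, x W1 + b1) W2 + b2
ffn = nn.Sequential(nn.Linear(emb_dim, d_ff), nn.ReLU(), nn.Linear(d_ff, emb_dim))

x = torch.randn(2, 5, emb_dim)                        # (batch, seq_len, emb_dim)
full = ffn(x)                                         # all positions at once
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(full, per_position, atol=1e-6))  # True
```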

In [None]:
'''
  Within one layer, the same FFN parameters are shared by every position (word),
  but each encoder/decoder layer has its own parameters.
'''
class PositionWiseFeedForward(nn.Module):

    def __init__(self, emb_dim: int, d_ff: int, dropout: float = 0.1):
        super(PositionWiseFeedForward, self).__init__()

        self.activation = nn.ReLU()
        self.w_1 = nn.Linear(emb_dim, d_ff)
        self.w_2 = nn.Linear(d_ff, emb_dim)
        self.dropout = dropout

    def forward(self, x):                     # x: (batch_size, seq_len, emb_dim)
        residual = x
        x = self.activation(self.w_1(x))      # (batch_size, seq_len, d_ff)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.w_2(x)                       # (batch_size, seq_len, emb_dim)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x + residual                   # residual connection (helps gradients flow); layer norm is applied by the caller

### c. Sinusoidal Positional Encoding
<a id='1c'></a>
In this section, we implement the sinusoidal positional encoding:

$$\begin{align}
PE(pos, 2i) &= \sin \left( pos / 10000^{2i / d_{model}} \right)  \\
PE(pos, 2i+1) &= \cos \left( pos / 10000^{2i / d_{model}} \right)  
\end{align}$$
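
A vectorized sketch of the two formulas (the name `sinusoidal_table` is hypothetical, and `d_model` is assumed even). Dimensions $2i$ and $2i+1$ share the frequency $1/10000^{2i/d_{model}}$ and differ only in sine versus cosine, so at position 0 every even dimension is 0 and every odd dimension is 1.

```python
import math
import torch

def sinusoidal_table(n_pos, d_model):
    # Assumes d_model is even so dimensions come in (sin, cos) pairs
    pos = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)   # (n_pos, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)      # 0, 2, 4, ..., d_model - 2
    angle = pos / (10000.0 ** (two_i / d_model))                  # (n_pos, d_model // 2)
    pe = torch.zeros(n_pos, d_model)
    pe[:, 0::2] = torch.sin(angle)   # PE(pos, 2i)
    pe[:, 1::2] = torch.cos(angle)   # PE(pos, 2i+1), same frequency as dimension 2i
    return pe

pe = sinusoidal_table(50, 64)
print(pe.shape)   # torch.Size([50, 64])
print(pe[0, :4])  # tensor([0., 1., 0., 1.])
```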

In [None]:
import numpy as np

# Since the Transformer contains no recurrence and no convolution,
# in order for the model to make use of the order of the sequence,
# we must inject some information about the relative or absolute position of the tokens in the sequence.
# To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks.
# There are many choices of positional encodings, both learned and fixed; here we use the fixed sinusoidal one.

'''
  Positional encoding: because the Transformer does not read the words sequentially,
  position information is added to each word's embedding vector before it enters the model.
'''

class SinusoidalPositionalEmbedding(nn.Embedding):
    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        super().__init__(num_positions, embedding_dim)  # torch.nn.Embedding(num_embeddings, embedding_dim)
        self.weight = self._init_weight(self.weight)    # replace the random table with the fixed sinusoidal one

    @staticmethod
    # out: embedding table of shape (n_pos, embed_dim), one row per position
    def _init_weight(out: nn.Parameter):
        n_pos, embed_dim = out.shape
        pe = nn.Parameter(torch.zeros(out.shape))
        # Even dimensions (2i) get the sine value and odd dimensions (2i+1) the cosine value;
        # both dimensions of a pair share the frequency 1 / 10000^(2i / embed_dim).
        for pos in range(n_pos):
            for i in range(0, embed_dim, 2):
                pe[pos, i].data.copy_( torch.tensor( np.sin(pos / (10000 ** (i / embed_dim)))) )
                pe[pos, i + 1].data.copy_( torch.tensor( np.cos(pos / (10000 ** (i / embed_dim)))) )
        pe.detach_()  # fixed table: requires_grad becomes False, so it is never trained

        return pe

    @torch.no_grad()
    def forward(self, input_ids):
        # Returns (seq_len, embedding_dim); it broadcasts over the batch when added to the token embeddings
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)


## 2. Transformer Encoder

Now we have all the basic building blocks needed to build the Transformer.

Let's implement it step by step.

### a. Encoder layer
In this section, we will implement a single layer of the Transformer encoder.
<div>
<img src="https://www.researchgate.net/publication/334288604/figure/fig1/AS:778232232148992@1562556431066/The-Transformer-encoder-structure.ppm" width="200"/>
</div>

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.ffn_dim = config.ffn_dim

        self.self_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,                 # emb_dim = 64
            num_heads=config.attention_heads,     # num_heads = 4
            dropout=config.attention_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.dropout = config.dropout
        self.activation_fn = nn.ReLU()

        self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, config.dropout)
        self.final_layer_norm = nn.LayerNorm(self.emb_dim)


    def forward(self, x, encoder_padding_mask):
        residual = x
        x, attn_weights = self.self_attn(query=x, key=x, attention_mask=encoder_padding_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        x = self.PositionWiseFeedForward(x)
        x = self.final_layer_norm(x)

        if torch.isinf(x).any() or torch.isnan(x).any():
            clamp_value = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)

        return x, attn_weights
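
Each sub-layer of `EncoderLayer` follows the post-LN pattern of the original paper, $\text{LayerNorm}(x + \text{Dropout}(\text{Sublayer}(x)))$; for the feed-forward sub-layer, the residual addition happens inside `PositionWiseFeedForward`. A generic sketch of that wrapper, with the hypothetical class name `PostLNSublayer` and an `nn.Linear` standing in for the sub-layer:

```python
import torch
import torch.nn as nn

class PostLNSublayer(nn.Module):
    """Hypothetical helper: computes LayerNorm(x + Dropout(sublayer(x)))."""
    def __init__(self, sublayer: nn.Module, emb_dim: int, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # residual connection first, then layer normalization (post-LN)
        return self.norm(x + self.dropout(self.sublayer(x)))

torch.manual_seed(0)
block = PostLNSublayer(nn.Linear(64, 64), emb_dim=64).eval()  # eval() disables dropout
y = block(torch.randn(2, 7, 64))                                # (batch, seq_len, emb_dim)
print(y.shape)  # torch.Size([2, 7, 64])
```

With LayerNorm's default affine parameters (scale 1, shift 0), every position of the output has mean close to 0 and variance close to 1, whatever the sub-layer produced.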


### b. Encoder

Stack encoder layers to build the full Transformer encoder.

In [None]:
class Encoder(nn.Module):
    # embed_tokens is passed in by Transformer as nn.Embedding(len(SRC_vocab.itos), config.emb_dim, padding_idx=SRC_vocab.stoi['<pad>']),
    # i.e. a lookup table of size len(SRC_vocab.itos) x config.emb_dim
    def __init__(self, config, embed_tokens):
        super().__init__()
        self.dropout = config.dropout

        emb_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        # Positional Encoding
        self.embed_tokens = embed_tokens
        self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, config.emb_dim, self.padding_idx
            )

        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])


    def forward(self, input_ids, attention_mask=None):  # input_ids : (batch_size, seq_len)
        ''' Embedding & Positional Encoding '''
        inputs_embeds = self.embed_tokens(input_ids)    # input_embeds : (batch_size, seq_len, emb_dim)
        embed_pos = self.embed_positions(input_ids)     # embed_pos : (seq_len, emb_dim)
        x = inputs_embeds + embed_pos                   # x : (batch_size, seq_len, emb_dim) (128, 23, 64)
        x = F.dropout(x, p=self.dropout, training=self.training)

        ''' Multi-head self-attention '''
        self_attn_scores = []
        for encoder_layer in self.layers:
            x, attn = encoder_layer(x, attention_mask)
            self_attn_scores.append(attn.detach())

        return x, self_attn_scores    # (batch_size, seq_len, emb_dim), (batch_size, num_heads, seq_len, seq_len)
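
The encoder's `attention_mask` is a `(batch_size, seq_len)` boolean tensor that is `True` at `<pad>` positions (`Transformer.generate_mask` builds it with `src.eq(...)`). Inside multi-head attention it is unsqueezed to `(batch_size, 1, 1, seq_len)` so it broadcasts over heads and query positions. A sketch with made-up token ids and random scores shows that the `-inf` fill gives padded keys exactly zero weight:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
PAD = 0                                             # <pad> index, as in the Vocab class
src = torch.tensor([[5, 7, 9, PAD, PAD],
                    [4, 6, 8, 3, 2]])               # (batch_size=2, seq_len=5), made-up ids
pad_mask = src.eq(PAD)                              # True where the token is <pad>

num_heads = 4
scores = torch.randn(2, num_heads, 5, 5)            # (batch, heads, query_len, key_len)
scores = scores.masked_fill(pad_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights[0, :, :, 3:].max().item())            # 0.0: padded keys receive no attention
```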

## 3. Transformer Decoder

### a. Decoder layer
In this section, we will implement a single layer of the Transformer decoder.
<div>
<img src="http://incredible.ai/assets/images/transformer-decoder.png" width="180"/>
</div>

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.ffn_dim = config.ffn_dim
        self.self_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            causal=True,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.encoder_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, config.dropout)
        self.final_layer_norm = nn.LayerNorm(self.emb_dim)


    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attention_mask=None,
        causal_mask=None,
    ):
        residual = x
        # Self Attention
        x, self_attn_weights = self.self_attn(
            query=x,
            key=x,  # self-attention: keys and values are computed from x itself
            attention_mask=causal_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Cross-Attention Block
        residual = x
        x, cross_attn_weights = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.encoder_attn_layer_norm(x)

        # Fully Connected
        x = self.PositionWiseFeedForward(x)
        x = self.final_layer_norm(x)

        return (
            x,
            self_attn_weights,
            cross_attn_weights,
        )

### b. Decoder

Stack decoder layers to build the full Transformer decoder.

Unlike the encoder, the decoder needs one more input: a causal (unidirectional) mask for its self-attention layers, so that each position can attend only to itself and earlier positions.
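
A causal mask is a `(trg_len, trg_len)` boolean matrix that is `True` strictly above the diagonal, so query position $i$ can attend only to key positions $j \le i$. The sketch below (arbitrary sizes) builds it with `torch.triu(..., diagonal=1)`, checks it against a position comparison like the `arange`-based one in `Transformer.generate_mask`, and confirms that no attention weight reaches a future position.

```python
import torch
import torch.nn.functional as F

trg_len = 5
causal = torch.triu(torch.ones(trg_len, trg_len, dtype=torch.bool), diagonal=1)
print(causal.long())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])

# Equivalent construction by comparing positions: entry (i, j) is masked when j > i
idx = torch.arange(trg_len)
same = idx.unsqueeze(0) > idx.unsqueeze(1)

# The (trg_len, trg_len) mask broadcasts over (batch, heads, trg_len, trg_len) scores
torch.manual_seed(0)
scores = torch.randn(2, 4, trg_len, trg_len)
weights = F.softmax(scores.masked_fill(causal, float('-inf')), dim=-1)
```

Each row still has at least one unmasked entry (the diagonal), so the softmax never sees a row made entirely of `-inf`.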

In [None]:
class Decoder(nn.Module):

    def __init__(self, config, embed_tokens: nn.Embedding):
        super().__init__()
        self.dropout = config.dropout
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        self.embed_positions = SinusoidalPositionalEmbedding(
            config.max_position_embeddings, config.emb_dim, self.padding_idx
        )
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.decoder_layers)])  # type: List[DecoderLayer]

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_attention_mask,
        decoder_causal_mask,
    ):

        # embed positions
        positions = self.embed_positions(input_ids)
        x = self.embed_tokens(input_ids)
        x += positions

        x = F.dropout(x, p=self.dropout, training=self.training)

        # decoder layers
        cross_attention_scores = []
        for idx, decoder_layer in enumerate(self.layers):
            x, layer_self_attn, layer_cross_attn = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                causal_mask=decoder_causal_mask,
            )
            cross_attention_scores.append(layer_cross_attn.detach())

        return x, cross_attention_scores

## 4. Transformer

Let's combine encoder and decoder in one place!

In [None]:
import torch
import torch.nn as nn

# Encoder and Decoder are defined in the cells above.

class Transformer(nn.Module):
    def __init__(self, SRC_vocab, TRG_vocab, config):
        super().__init__()

        self.SRC_vocab = SRC_vocab
        self.TRG_vocab = TRG_vocab

        self.enc_embedding = nn.Embedding(len(SRC_vocab.itos), config.emb_dim, padding_idx=SRC_vocab.stoi['<pad>'])
        self.dec_embedding = nn.Embedding(len(TRG_vocab.itos), config.emb_dim, padding_idx=TRG_vocab.stoi['<pad>'])

        self.encoder = Encoder(config, self.enc_embedding)
        self.decoder = Decoder(config, self.dec_embedding)

        self.prediction_head = nn.Linear(config.emb_dim, len(TRG_vocab.itos))

        self.init_weights()

    def generate_mask(self, src, trg):
        # Padding mask for encoder self-attention and encoder-decoder attention: True where src is <pad>
        enc_attention_mask = src.eq(self.SRC_vocab.stoi['<pad>'])
        # Causal mask for decoder self-attention: entry (i, j) is True (masked) when j > i, i.e. a future position
        tmp = torch.ones(trg.size(1), trg.size(1), dtype=torch.bool, device=trg.device)
        mask = torch.arange(tmp.size(-1), device=trg.device)
        dec_attention_mask = tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), False)

        return enc_attention_mask, dec_attention_mask

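    def init_weights(self):
        for name, param in self.named_parameters():
            # Skip frozen parameters such as the fixed sinusoidal position tables
            if not param.requires_grad:
                continue
            if param.dim() > 1:
                # Weight matrices of Linear and Embedding layers
                nn.init.normal_(param.data, mean=0, std=0.01)
            elif 'bias' in name:
                nn.init.constant_(param.data, 0)
            # 1-D LayerNorm weights keep their default value of 1; re-initializing them near 0 would
            # shrink every normalized output toward zero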

    def forward(self, src, trg):
        enc_attention_mask, dec_causal_mask = self.generate_mask(src, trg)
        encoder_output, encoder_attention_scores = self.encoder(
            input_ids=src,
            attention_mask=enc_attention_mask
        )

        decoder_output, decoder_attention_scores = self.decoder(
            trg,
            encoder_output,
            encoder_attention_mask=enc_attention_mask,
            decoder_causal_mask=dec_causal_mask,
        )
        decoder_output = self.prediction_head(decoder_output)

        return decoder_output, encoder_attention_scores, decoder_attention_scores
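
During training the whole reference sentence is fed to the decoder (teacher forcing), and thanks to the causal mask the logits at position $t$ are scored against the token at $t+1$: `output[:, :-1]` against `trg[:, 1:]`. A toy sketch with made-up token ids (assuming `<sos>`/`<eos>`-wrapped targets and the special-token indices of the `Vocab` class) shows which target each prediction is compared with, and that `ignore_index` drops the padded position from the mean.

```python
import torch
import torch.nn as nn

PAD, SOS, EOS = 0, 1, 2                               # special-token indices of the Vocab class
trg = torch.tensor([[SOS, 10, 11, 12, EOS],
                    [SOS, 13, 14, EOS, PAD]])         # (batch_size=2, trg_len=5), made-up ids
vocab_size = 20
torch.manual_seed(0)
logits = torch.randn(2, 5, vocab_size)                # stand-in for the decoder output

pred = logits[:, :-1, :].reshape(-1, vocab_size)      # predictions made at positions 0..3
gold = trg[:, 1:].reshape(-1)                         # the token that follows each of them
print(gold.tolist())  # [10, 11, 12, 2, 13, 14, 2, 0]

criterion = nn.CrossEntropyLoss(ignore_index=PAD)     # the final <pad> target is skipped
loss = criterion(pred, gold)
```

The perplexity printed in the training loop is `math.exp` of this mean cross-entropy.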



# III. Train & Evaluate
<a id='3'></a>
This section is very similar to week 9, so please refer to it for a detailed description.

## 1. Configuration

In [None]:
import easydict
import torch.nn as nn
import torch.optim as optim

# Create the configuration for the transformer model
config = easydict.EasyDict({
    "emb_dim": 64,
    "ffn_dim": 256,
    "attention_heads": 4,
    "attention_dropout": 0.0,
    "dropout": 0.2,
    "max_position_embeddings": 512,
    "encoder_layers": 3,
    "decoder_layers": 3,
})

# Constants for training
N_EPOCHS = 100
learning_rate = 5e-4
CLIP = 1

# Index of the <pad> token in the target vocabulary; padded target positions are ignored by the loss
PAD_IDX = TRG_vocab.stoi['<pad>']

# Instantiate the model
model = Transformer(SRC_vocab, TRG_vocab, config)
model.to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Define the loss function, ignoring the index of the padding token
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Initialize the best validation loss
best_valid_loss = float('inf')

## 2. Train & Eval

In [None]:
import math
import time
from tqdm import tqdm

def train(model: nn.Module,
          iterator: DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()
    epoch_loss = 0

    for batch in iterator:
        src = batch['SRC'].to(device)
        trg = batch['TRG'].to(device)

        # src, trg: (batch_size, seq_len) LongTensors, already numericalized and padded by collate_fn

        optimizer.zero_grad()

        output, enc_attention_scores, _ = model(src, trg)

        # Flatten the output and target tensors to compute the loss
        output = output[:,:-1,:].reshape(-1, output.shape[-1])
        trg = trg[:,1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: DataLoader,
             criterion: nn.Module):

    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src = batch['SRC'].to(device)
            trg = batch['TRG'].to(device)

            # src, trg: (batch_size, seq_len) LongTensors, already numericalized and padded by collate_fn

            output, attention_score, _ = model(src, trg)

            # Flatten the output and target tensors to compute the loss
            output = output[:,:-1,:].reshape(-1, output.shape[-1])
            trg = trg[:,1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

# Training loop with early stopping: stop as soon as the validation loss stops improving
for epoch in tqdm(range(N_EPOCHS), total=N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    else:  # early stopping condition
        break

# Evaluation on test set
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')



# IV. Visualization
<a id='4'></a>

## 1. Positional embedding visualization

In [None]:
import matplotlib.pyplot as plt
# Visualization
fig, ax = plt.subplots(figsize=(15, 9))
cax = ax.matshow(model.encoder.embed_positions.weight.data.cpu().numpy(), aspect='auto',cmap=plt.cm.YlOrRd)
fig.colorbar(cax)
ax.set_title('Positional Embedding Matrix', fontsize=18)
ax.set_xlabel('Embedding Dimension', fontsize=14)
ax.set_ylabel('Position', fontsize=14)

## 2. Attention visualization

In [None]:
from attentionviz import head_view

BATCH_SIZE = 1

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

model.eval()

In [None]:
import sys
if 'attentionviz' not in sys.path:
    sys.path += ['attentionviz']
!pip install regex

def call_html():
  import IPython
  display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

In [None]:
SAMPLE_IDX = 131

sample = test_data[SAMPLE_IDX]

# Numericalize one sentence pair and wrap it in <sos> ... <eos>, like the training batches
src_ids = [SRC_vocab.stoi['<sos>']] + SRC_vocab.numericalize(sample['SRC']) + [SRC_vocab.stoi['<eos>']]
trg_ids = [TRG_vocab.stoi['<sos>']] + TRG_vocab.numericalize(sample['TRG']) + [TRG_vocab.stoi['<eos>']]
src_numericalized = torch.LongTensor([src_ids]).to(device)
trg_numericalized = torch.LongTensor([trg_ids]).to(device)

with torch.no_grad():
    # Teacher forcing: the full reference translation is fed to the decoder
    output, enc_attention_score, dec_attention_score = model(src_numericalized, trg_numericalized)
    attention_score = {'self': enc_attention_score, 'cross': dec_attention_score}

src_tok = [SRC_vocab.itos[x] for x in src_numericalized.squeeze(0).tolist()]
trg_tok = [TRG_vocab.itos[x] for x in trg_numericalized.squeeze(0).tolist()]

call_html()
head_view(attention_score, src_tok, trg_tok)