## Additive attention
Paper: [Neural machine translation by jointly learning to align and translate](https://arxiv.org/pdf/1409.0473) - Bahdanau et. al 2015

Dataset: [Multi30K English to Deutsche dataset](https://huggingface.co/datasets/bentrevett/multi30k)

Model: Use LSTM as encoder and decoder

#### Model variations
- Stacked LSTM encoder decoder
- BiLSTM encoder + LSTM decoder



In [None]:
from datasets import load_dataset

train_dataset = load_dataset("bentrevett/multi30k", split="train")
print(len(train_dataset))
print(train_dataset[0])
print(train_dataset.column_names)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

train.jsonl: 0.00B [00:00, ?B/s]

val.jsonl: 0.00B [00:00, ?B/s]

test.jsonl: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/29000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1014 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

29000
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
['en', 'de']


In [None]:
import re

TOKEN_RE = re.compile(r"\w+|[^\w\s]")
def word_tokenize(text):
    text = text.lower().strip()
    return TOKEN_RE.findall(text)

In [None]:
word_tokenize("Hello, world!")

['hello', ',', 'world', '!']

In [None]:
from collections import Counter

PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"

SPECIAL_TOKENS = [PAD, BOS, EOS, UNK]

def build_vocab(tokenized_texts, max_vocab_size=10000, min_freq=3):
    counter = Counter(token for text in tokenized_texts for token in text)
    vocab = SPECIAL_TOKENS.copy()

    for token, freq in counter.most_common():
        if freq < min_freq:
            break
        if len(vocab) >= max_vocab_size:
            break
        if token not in vocab:
            vocab.append(token)

    vocab_to_index = {token: index for index, token in enumerate(vocab)}
    index_to_vocab = {index: token for token, index in vocab_to_index.items()}

    return vocab_to_index, index_to_vocab

In [None]:
def add_special_tokens(tokens):
    return [BOS] + tokens + [EOS]

In [None]:
def remove_special_tokens(tokens):
    return [token for token in tokens if token not in [PAD, BOS, EOS]]

def encode(token_to_index, text):
    return [token_to_index.get(token, token_to_index[UNK]) for token in text]

def decode(index_to_token, indices):
    return " ".join(remove_special_tokens([index_to_token.get(index, UNK) for index in indices]))


In [None]:
# Load data and build vocab

tokenized_train_dataset = train_dataset.map(lambda x: {"en": word_tokenize(x["en"]), "de": word_tokenize(x["de"])}, batched=False)

en_vocab_to_index, en_index_to_vocab = build_vocab(
    [item["en"] for item in tokenized_train_dataset]
)
de_vocab_to_index, de_index_to_vocab = build_vocab(
    [item["de"] for item in tokenized_train_dataset]
)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

In [None]:
print(f"English Vocab size: {len(en_vocab_to_index)}")
print(f"German Vocab size: {len(de_vocab_to_index)}")
print(tokenized_train_dataset[0])

English Vocab size: 4560
German Vocab size: 5422
{'en': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.'], 'de': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.']}


In [None]:
def preprocess(batch, source_lang, target_lang, source_vocab_to_index, target_vocab_to_index):
    source_encodings = [encode(source_vocab_to_index, add_special_tokens(word_tokenize(text))) for text in batch[source_lang]]
    target_encodings = [encode(target_vocab_to_index, add_special_tokens(word_tokenize(text))) for text in batch[target_lang]]

    return {"source": source_encodings, "target": target_encodings}


In [None]:
preprocessed_train_dataset = train_dataset.map(
    lambda x: preprocess(x, "de", "en", de_vocab_to_index, en_vocab_to_index),
    batched=True,
    remove_columns=["en", "de"]
)
preprocessed_train_dataset.set_format(type="torch", columns=["source", "target"])

print(preprocessed_train_dataset[0])
print(f"German: {decode(de_index_to_vocab, preprocessed_train_dataset[0]['source'].tolist())}")
print(f"English: {decode(en_index_to_vocab, preprocessed_train_dataset[0]['target'].tolist())}")

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
German: zwei junge weiße männer sind im freien in der nähe <unk> büsche .
English: two young , white males are outside near many bushes .


In [None]:
print(f"type: {type(preprocessed_train_dataset)}")
print(f"length: {len(preprocessed_train_dataset)}")
print(preprocessed_train_dataset[0])
print(preprocessed_train_dataset[1])
print(preprocessed_train_dataset[2])
print(preprocessed_train_dataset[3])

type: <class 'datasets.arrow_dataset.Dataset'>
length: 29000
{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
{'source': tensor([   1,   77,   31,   11,  831, 2082,    5,    3,    4,    2]), 'target': tensor([   1,  113,   30,    6,  325,  280,   17, 1180,    4,  712, 3814, 2644,
           5,    2])}
{'source': tensor([  1,   5,  67,  26, 226,   7,   5,   3,  58, 492,   4,   2]), 'target': tensor([   1,    4,   53,   33,  231,   69,    4,  248, 3815,    5,    2])}
{'source': tensor([  1,   5,  13,   7,   6,  47,  41,  30,  12,  14, 546,  10, 684,   5,
        250,   4,   2]), 'target': tensor([  1,   4,   9,   6,   4,  29,  23,  10,  36,   8,   4, 574, 575,   4,
        240,   5,   2])}


In [None]:
import torch

def pad_batch(sequences, pad_idx=0):
    dtype = sequences[0].dtype

    lengths = torch.tensor([len(seq) for seq in sequences])
    max_length = int(lengths.max().item())

    padded_batch = torch.full((len(sequences), max_length), pad_idx, dtype=dtype)
    for i, seq in enumerate(sequences):
        end = lengths[i]
        padded_batch[i, :end] = seq
    return padded_batch, lengths


In [None]:
def collate_fn(batch):
    source = [item["source"] for item in batch]
    target = [item["target"] for item in batch]

    source, source_lengths = pad_batch(source) # defer padding till batching
    target, target_lengths = pad_batch(target)

    return {"source": source, "source_lengths": source_lengths, "target": target, "target_lengths": target_lengths}


In [None]:
from torch.utils.data import DataLoader
loader = DataLoader(
    preprocessed_train_dataset,
    batch_size=3,
    shuffle=True,
    collate_fn=collate_fn,
)

batch = next(iter(loader))
print(batch["source"].shape)
print(batch["source_lengths"].shape)
print(batch["target"].shape)
print(batch["target_lengths"].shape)

torch.Size([3, 14])
torch.Size([3])
torch.Size([3, 15])
torch.Size([3])


In [None]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, device, padding_idx=0):
        super().__init__()
        self.directions = 2
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True, bidirectional=True)
        self.hidden_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.cell_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.device = device

    def forward(self, source_encodings, source_lengths): # source_encodings: (batch_size, max_length), source_lengths: (batch_size)
        B, T = source_encodings.size()
        h_0 = torch.zeros(self.num_layers * self.directions, B, self.hidden_dim, device=self.device)
        c_0 = torch.zeros(self.num_layers * self.directions, B, self.hidden_dim, device=self.device)

        embedded = self.embedding(source_encodings) # (B, T, embedding_dim)

        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, source_lengths.to('cpu'), batch_first=True, enforce_sorted=False)
        packed_outputs, (last_hidden, last_cell) = self.lstm(packed_embedded, (h_0, c_0)) # hidden: (num_layers * directions, B, hidden_dim)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True) # outputs: (B, T, hidden_dim * directions)

        h_f = last_hidden[0::2] # (num_layers, B, hidden_dim)
        h_b = last_hidden[1::2] # (num_layers, B, hidden_dim)
        h_cat = torch.cat((h_f, h_b), dim=2) # (num_layers, B, hidden_dim * 2) concatenate along the hidden dimension

        c_f = last_cell[0::2] # (num_layers, B, hidden_dim)
        c_b = last_cell[1::2] # (num_layers, B, hidden_dim)
        c_cat = torch.cat((c_f, c_b), dim=2) # (num_layers, B, hidden_dim * 2) concatenate along the hidden dimension

        last_hidden = torch.tanh(self.hidden_projection(h_cat)) # (num_layers, B, hidden_dim)
        last_cell = torch.tanh(self.cell_projection(c_cat)) # (num_layers, B, hidden_dim)

        return outputs, last_hidden, last_cell


In [None]:
encoder = Encoder(5, 20, 32, 2, 0.0, 'cpu')

random_source_encodings = torch.randint(0, 5, (3, 5))
random_source_lengths = torch.tensor([5, 4, 3])

outputs, last_hidden, last_cell = encoder(random_source_encodings, random_source_lengths)

print(f"outputs (B, T, hidden_dim * directions): {outputs.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

print(f"last_hidden: {last_hidden}")


outputs (B, T, hidden_dim * directions): torch.Size([3, 5, 64])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 32])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 32])
last_hidden: tensor([[[ 3.6009e-02, -6.1706e-02,  3.7128e-02,  1.0430e-02, -1.1734e-01,
           8.3199e-02, -1.0247e-01,  5.4959e-02,  1.0333e-01,  7.5510e-02,
          -8.7594e-02,  2.2755e-02,  6.8861e-02,  1.6432e-02,  7.1154e-02,
          -6.2368e-02, -2.7067e-02,  9.2731e-02, -2.6508e-02,  7.7278e-02,
          -7.7838e-02, -7.6362e-02, -1.4391e-01, -8.7677e-02, -1.4580e-02,
           8.0418e-02,  8.8847e-02, -1.3989e-02,  1.3771e-01, -1.6716e-01,
          -7.0500e-02, -3.2835e-02],
         [ 7.9195e-02, -1.5790e-01,  1.0455e-01, -5.9135e-02, -1.7228e-01,
           3.7186e-02, -8.3966e-02,  1.0417e-01,  1.3150e-01,  6.0524e-02,
          -1.5324e-01,  2.5763e-02, -1.7280e-04,  2.7339e-02,  1.8096e-02,
           4.6769e-02, -3.8586e-02,  7.8259e-02, -4.1275e-0

In [None]:
encoder = Encoder(len(de_vocab_to_index), 120, 256, 2, 0.0, 'cpu')

outputs, last_hidden, last_cell = encoder(batch["source"], batch["source_lengths"])

print(f"source lengths: {batch['source_lengths']}")
print(f"outputs (B, T, hidden_dim * directions): {outputs.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

print(outputs[0][10]) # hidden state for padded token after re padding

source lengths: tensor([14, 10, 14])
outputs (B, T, hidden_dim * directions): torch.Size([3, 14, 512])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 256])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 256])
tensor([-9.7102e-03, -2.7296e-02, -1.3589e-02, -8.5130e-02, -3.7030e-02,
        -1.0311e-01,  5.2920e-02, -6.1724e-02,  2.3207e-02, -3.3429e-02,
        -4.8263e-02, -3.7588e-02, -4.7728e-02,  2.6776e-02,  6.9178e-02,
         8.7372e-02,  4.1226e-03,  7.3281e-02,  6.9871e-02,  2.4047e-02,
        -1.0090e-02,  2.1739e-02, -3.4908e-02,  1.1700e-02, -2.6054e-02,
        -3.5538e-02, -5.5508e-02,  2.8765e-02,  2.0604e-03,  5.7185e-02,
        -1.5269e-02, -5.4489e-02,  4.9306e-02,  3.4461e-02,  3.1712e-02,
        -5.0002e-02, -6.4453e-02, -9.1815e-02, -1.6711e-03, -1.7521e-02,
         6.4264e-02, -2.3993e-02,  4.8036e-03, -8.7409e-02, -9.6848e-03,
         2.0487e-02, -3.7851e-02, -2.2955e-02,  3.6448e-03,  3.1409e-02,
         5.55

In [None]:
class Decoder(nn.Module):
    """
    Decoder for the Seq2Seq model. This works on a batch of single step targets (B, 1).
    It doesn't take the entire target sequence, because the decision to teacher force and what kind of search strategy to use
    is done at the sequence level, not the step level.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=False,
        )
        self.encoder_hidden_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.encoder_cell_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, target, encoder_last_hidden, encoder_last_cell):
        # Normalize target shape to (B, 1)
        if target.dim() == 0:
            target = target.view(1, 1)
        else:
            target = target.unsqueeze(1)

        target = target.long()

        embedded = self.embedding(target)  # (B, T, embedding_dim)
        # use dropout on the embedded input later
        outputs, (hidden, cell) = self.lstm(embedded, (encoder_last_hidden, encoder_last_cell)) # (B, 1, hidden_dim)
        logits = self.output_projection(outputs) # (B, 1, vocab_size)

        return logits, hidden, cell

In [None]:
decoder = Decoder(len(en_vocab_to_index), 120, 256, 2, 0.0)

# One decoding step for the whole batch (B,)
step0 = batch["target"][:, 0]
print(step0.shape, step0)

logits, hidden, cell = decoder(step0, last_hidden, last_cell)
print(f"logits (B, 1, vocab_size): {logits.shape}")
print(f"hidden (num_layers, B, hidden_dim): {hidden.shape}")
print(f"cell (num_layers, B, hidden_dim): {cell.shape}")
print(logits)


torch.Size([3]) tensor([1, 1, 1])
logits (B, 1, vocab_size): torch.Size([3, 1, 4560])
hidden (num_layers, B, hidden_dim): torch.Size([2, 3, 256])
cell (num_layers, B, hidden_dim): torch.Size([2, 3, 256])
tensor([[[ 0.0242,  0.0007, -0.0329,  ...,  0.0088,  0.0172, -0.0255]],

        [[ 0.0272,  0.0004, -0.0318,  ...,  0.0049,  0.0145, -0.0269]],

        [[ 0.0270,  0.0007, -0.0363,  ...,  0.0142,  0.0215, -0.0310]]],
       grad_fn=<ViewBackward0>)


In [None]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, bos_idx, eos_idx, max_target_length, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.max_target_length = max_target_length
        self.device = device

    def forward(self, source, source_lengths, target=None):
        B = source.shape[0]


        encoder_outputs, hidden, cell = self.encoder(source, source_lengths)

        if target is not None:
            steps = target.shape[1] - 1
            inputs = target[:, 0]
        else:
            steps = self.max_target_length
            inputs = torch.full((B,), self.bos_idx, dtype=torch.long)

        logits_all = torch.zeros(B, steps, self.decoder.vocab_size, device=self.device)
        preds_all = torch.zeros(B, steps, device=self.device)

        for t in range(steps):
            logits, hidden, cell = self.decoder(inputs, hidden, cell)
            step_logits = logits.squeeze(1) # (B, 1, vocab_size) -> (B, vocab_size)
            step_preds = step_logits.argmax(dim=1)

            logits_all[:, t, :] = step_logits
            preds_all[:, t] = step_preds

            # Teacher forcing
            if target is not None:
                inputs = target[:, t + 1]
            else:
                inputs = step_preds
                if step_preds.eq(self.eos_idx).all():
                    break

        return logits_all, preds_all


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=de_vocab_to_index[PAD])

In [None]:
model = Seq2Seq(encoder, decoder, de_vocab_to_index[BOS], de_vocab_to_index[EOS], 30, 'cpu')

logits_all, preds_all = model(batch["source"], batch["source_lengths"], batch["target"])
target = batch["target"][:, 1:].reshape(-1)
loss = criterion(logits_all.reshape(-1, logits_all.shape[-1]), target)


print(f"logits_all.shape: {logits_all.shape}")
print(f"preds_all.shape: {preds_all.shape}")
print(f"batch['target'].shape: {batch['target'].shape}")
print(f"loss: {loss}\n")

for i in range(len(batch.items())-1):
    print(f"Source: {decode(de_index_to_vocab, batch['source'][i].tolist())}")
    print(f"Target: {decode(en_index_to_vocab, batch['target'][i].tolist())}")
    print(f"Pred: {decode(en_index_to_vocab, preds_all[i].tolist())}")
    print("-"*100)

logits_all.shape: torch.Size([3, 14, 4560])
preds_all.shape: torch.Size([3, 14])
batch['target'].shape: torch.Size([3, 15])
loss: 8.43736743927002

Source: eine gruppe von kindern und zwei erwachsene sitzen auf einer bank .
Target: a group of children and two adults are sitting on a bench .
Pred: bamboo muscular cookie slide slide muscular guarding rollerskates cookie cookie cookie cookie chase muscular
----------------------------------------------------------------------------------------------------
Source: zwei männer an einer ampel schneiden grimassen .
Target: two men on a traffic light making faces .
Pred: muscular tied flour flour guarding salmon salmon slide slide salmon crawls crawls slide slide
----------------------------------------------------------------------------------------------------
Source: ein mann fährt auf einem steilen geländer neben einer treppe skateboard .
Target: a man skateboards down a steep railing next to some steps .
Pred: bamboo muscular cookie enter

In [None]:
tgt_vocab_size = len(en_vocab_to_index) # 4560
src_vocab_size = len(de_vocab_to_index) # 5422
emb_dim = 512
hidden_dim = 512
enc_num_layers = 2
dec_num_layers = 2
enc_dropout = 0.0
dec_dropout = 0.0
padding_idx = de_vocab_to_index[PAD]
bos_idx = de_vocab_to_index[BOS]
eos_idx = de_vocab_to_index[EOS]
max_target_length = 30

batch_size = 64
learning_rate = 0.0005
epochs = 10

In [None]:
encoder = Encoder(src_vocab_size, emb_dim, hidden_dim, enc_num_layers, enc_dropout, device, padding_idx)
decoder = Decoder(tgt_vocab_size, emb_dim, hidden_dim, dec_num_layers, dec_dropout, padding_idx)

model = Seq2Seq(encoder, decoder, bos_idx, eos_idx, max_target_length, device)
model.to(device)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(5422, 512, padding_idx=0)
    (lstm): LSTM(512, 512, num_layers=2, batch_first=True, bidirectional=True)
    (hidden_projection): Linear(in_features=1024, out_features=512, bias=True)
    (cell_projection): Linear(in_features=1024, out_features=512, bias=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(4560, 512, padding_idx=0)
    (lstm): LSTM(512, 512, num_layers=2, batch_first=True)
    (encoder_hidden_projection): Linear(in_features=1024, out_features=512, bias=True)
    (encoder_cell_projection): Linear(in_features=1024, out_features=512, bias=True)
    (output_projection): Linear(in_features=512, out_features=4560, bias=True)
  )
)

### Model parameter calculation

#### Encoder
- source embedding: vocab * emb_dim + bias
- 2 layers BiLSTM:  
    - direction * gates * ( W_x + W_h + b_x + b_h)
    - direction * gates * ( W_h + W_h + b_h + b_h)
- hidden projection: 2 * (2 * hidden_dim * hidden_dim + b_h)

#### Decoder
- target embedding: vocab * emb_dim + bias
- 2 layers LSTM:
    - gates * ( W_x + W_h + b_x + b_h)
    - gates * ( W_x + W_h + b_x + b_h)
- output projection: hidden_dim * target_vocab + b_v

In [None]:
p_src_emd = src_vocab_size * emb_dim + src_vocab_size
p_enc_lstm_1 = 2 * 4 * (emb_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_enc_lstm_2 = 2 * 4 * (hidden_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_enc_hidden_proj = 2 * (2 * hidden_dim * hidden_dim + hidden_dim)
p_enc_cell_proj = 2 * (2 * hidden_dim * hidden_dim + hidden_dim)

p_encoder = p_src_emd + p_enc_lstm_1 + p_enc_lstm_2 + p_enc_hidden_proj + p_enc_cell_proj

print(f"Encoder parameters: {p_encoder:,}")

p_tgt_emd = tgt_vocab_size * emb_dim + tgt_vocab_size
p_dec_lstm_1 = 4 * (emb_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_dec_lstm_2 = 4 * (hidden_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_dec_output_proj = hidden_dim * tgt_vocab_size + tgt_vocab_size

p_decoder = p_tgt_emd + p_dec_lstm_1 + p_dec_lstm_2 + p_dec_output_proj

print(f"Decoder parameters: {p_decoder:,}")

p_seq2seq = p_encoder + p_decoder

print(f"Seq2Seq parameters: {p_seq2seq:,}")



Encoder parameters: 13,334,830
Decoder parameters: 8,905,632
Seq2Seq parameters: 22,240,462


In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The Seq2Seq model has {count_parameters(model):,} trainable parameters')
print(f"Encoder parameters: {count_parameters(encoder):,}")
print(f"Decoder parameters: {count_parameters(decoder):,}")


The Seq2Seq model has 24,253,904 trainable parameters
Encoder parameters: 14,327,808
Decoder parameters: 9,926,096


### Rough estimate of memory needed
- long or float32 = 4 bytes
- memory for model for Adam optimizer = 4 * mode parameters * 4 bytes
- activations of LSTM = 6 * hidden_dim * num_layers * batch_size * seq_len * 4 bytes
- activations embedding = batch_size * seq_len * embed_dim * 2
- hidden projection activations = batch_size * seq_len * hidden_dim * 2
- output project activation = batch_size * seq_len * tgt_vocab

total = sum of above + 30%

In [None]:
model_memory = 4 * 4 * p_seq2seq
activations_memory = 6 * hidden_dim * (enc_num_layers + dec_num_layers) * batch_size * max_target_length * 4
embedding_memory = batch_size * max_target_length * emb_dim * 2
hidden_projection_memory = batch_size * max_target_length * hidden_dim * 2
output_projection_memory = batch_size * max_target_length * tgt_vocab_size

total_activations_memory = activations_memory + embedding_memory + hidden_projection_memory + output_projection_memory
total_memory = model_memory + total_activations_memory


print(f"Total memory: {total_memory:,} bytes")
print(f"Model memory: {model_memory:,} bytes")
print(f"Activations memory: {activations_memory + embedding_memory + hidden_projection_memory + output_projection_memory:,} bytes")
print(f"Total required memory: {total_memory * 1.3:,} bytes")

Total memory: 462,906,592 bytes
Model memory: 355,847,392 bytes
Activations memory: 107,059,200 bytes
Total required memory: 601,778,569.6 bytes


In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def batch_to_device(batch, device):
    return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
     for k, v in batch.items()}

In [None]:
def train(model, loader, criterion, optimizer, epochs):
    model.train()
    epoch_loss = 0
    for batch in loader:
        batch = batch_to_device(batch, device)
        optimizer.zero_grad()

        logits_all, preds_all = model(batch["source"], batch["source_lengths"], batch["target"])
        target = batch["target"][:, 1:].reshape(-1)

        loss = criterion(logits_all.reshape(-1, logits_all.shape[-1]), target)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(loader)

In [None]:
@torch.no_grad()
def validate(model, loader, criterion):
    model.eval()
    epoch_loss = 0
    for batch in loader:
        batch = batch_to_device(batch, device)
        logits_all, preds_all = model(batch["source"], batch["source_lengths"], batch["target"])
        target = batch["target"][:, 1:].reshape(-1)

        loss = criterion(logits_all.reshape(-1, logits_all.shape[-1]), target)
        epoch_loss += loss.item()

    return epoch_loss / len(loader)

In [None]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    minutes = int(elapsed_time / 60)
    seconds = int(elapsed_time % 60)
    return minutes, seconds

In [None]:
from google.colab import drive
drive.mount('/content/drive')

model_path = '/content/drive/My Drive/ML study/Attentions/additive attention'

Mounted at /content/drive


In [None]:
train_dataset = load_dataset("bentrevett/multi30k", split="train")
val_dataset = load_dataset("bentrevett/multi30k", split="validation")

final_train_dataset = train_dataset.map(
    lambda x: preprocess(x, "de", "en", de_vocab_to_index, en_vocab_to_index),
    batched=True,
    remove_columns=["en", "de"]
)
final_train_dataset.set_format(type="torch", columns=["source", "target"])

final_val_dataset = val_dataset.map(
    lambda x: preprocess(x, "de", "en", de_vocab_to_index, en_vocab_to_index),
    batched=True,
    remove_columns=["en", "de"]
)
final_val_dataset.set_format(type="torch", columns=["source", "target"])

pin_memory = True if device == "cuda" else False

train_loader = DataLoader(final_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory, persistent_workers=True)
val_loader = DataLoader(final_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory, persistent_workers=True)

In [None]:
import time

best_valid_loss = float('inf')

for e in range(epochs):
    start_time = time.time()
    train_loss = train(model, train_loader, criterion, optimizer, e)
    val_loss = validate(model, val_loader, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(model.state_dict(), f'{model_path}additive-attention-en-de-translator-model-512-epoch-{e+1}.pt')

    print(f"Epoch {e+1}: Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Epoch time: {epoch_mins}m {epoch_secs}s")

Epoch 1: Train loss: 3.4143, Val loss: 3.1912, Epoch time: 0m 41s
Epoch 2: Train loss: 3.0845, Val loss: 2.9535, Epoch time: 0m 44s
Epoch 3: Train loss: 2.8305, Val loss: 2.7915, Epoch time: 0m 41s
Epoch 4: Train loss: 2.6187, Val loss: 2.6703, Epoch time: 0m 41s
Epoch 5: Train loss: 2.4271, Val loss: 2.5842, Epoch time: 0m 41s
Epoch 6: Train loss: 2.2448, Val loss: 2.5368, Epoch time: 0m 41s
Epoch 7: Train loss: 2.0729, Val loss: 2.5136, Epoch time: 0m 41s
Epoch 8: Train loss: 1.9057, Val loss: 2.4924, Epoch time: 0m 42s
Epoch 9: Train loss: 1.7443, Val loss: 2.5200, Epoch time: 0m 42s
Epoch 10: Train loss: 1.5916, Val loss: 2.5361, Epoch time: 0m 41s
