In [1]:
# Environment sanity check: display the path of the Python interpreter that
# backs this kernel, to confirm the notebook runs inside the intended venv.
import sys
sys.executable

'/Users/b/ml/seq2seq/venv/bin/python'

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import spacy
import datasets
import tqdm
import evaluate
import matplotlib.pyplot as pyplot
import matplotlib.ticker as ticker

In [2]:
# Fix every random number generator the notebook relies on so that repeated
# runs produce the same shuffling, initialisation and teacher-forcing draws.
seed = 1234

# Seed Python's `random`, NumPy, and PyTorch (CPU, then CUDA) in turn.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the corpus published as "bentrevett/multi30k" and bind its three
# splits to separate names.
dataset = datasets.load_dataset("bentrevett/multi30k")
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [4]:
# spaCy pipelines for English and German; only their `.tokenizer` attribute is
# used below (see tokenize_example).
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [5]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize the English/German sentence pair of one dataset example.

    Args:
        example: mapping holding the raw sentences under the "en" and "de" keys.
        en_nlp: object whose ``tokenizer`` attribute, when called on a string,
            yields items exposing a ``.text`` attribute (here a spaCy pipeline).
        de_nlp: same as ``en_nlp`` but used for the German sentence.
        max_length: maximum number of tokens kept per sentence. The start and
            end markers are added afterwards and do not count towards it.
        lower: when true, every token is lowercased.
        sos_token: marker prepended to both token lists.
        eos_token: marker appended to both token lists.

    Returns:
        dict with the keys "en_tokens" and "de_tokens", each a list of strings.
    """
    # Truncate before adding the markers so <eos> is never cut off.
    en_tokens = [token.text for token in en_nlp.tokenizer(example["en"])][:max_length]
    de_tokens = [token.text for token in de_nlp.tokenizer(example["de"])][:max_length]
    if lower:
        en_tokens = [token.lower() for token in en_tokens]
        de_tokens = [token.lower() for token in de_tokens]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    de_tokens = [sos_token] + de_tokens + [eos_token]
    return {"en_tokens": en_tokens, "de_tokens": de_tokens}
        

In [6]:
# Tokenization settings.
max_length = 1000
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

# Keyword arguments handed to tokenize_example through `fn_kwargs`.
fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    max_length=max_length,
    lower=lower,
    sos_token=sos_token,
    eos_token=eos_token,
)

# Add the "en_tokens" / "de_tokens" columns to each split.
train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)

In [7]:
from collections import Counter, defaultdict

class Vocab:
    """Two-way mapping between tokens and integer indices.

    Lookups never fail: an unknown token maps to ``unk_index`` and an unknown
    index maps to ``unk_token``.
    """

    def __init__(self, token_to_index, unk_token="<unk>"):
        """Build the vocabulary from an existing token -> index mapping.

        Args:
            token_to_index: dict mapping each token to its index. It is stored
                by reference, not copied.
            unk_token: token standing in for anything not in the vocabulary.

        Raises:
            KeyError: if ``unk_token`` is not a key of ``token_to_index``.
        """
        self.token_to_index = token_to_index
        # Reverse mapping, derived once from token_to_index.
        self.index_to_token = {idx: token for token, idx in token_to_index.items()}
        self.unk_token = unk_token
        # Index returned for tokens that are not in the vocabulary.
        self.unk_index = token_to_index[unk_token]

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.token_to_index)

    def __getitem__(self, token):
        """Return the index of ``token``, or ``unk_index`` when it is unknown."""
        return self.token_to_index.get(token, self.unk_index)

    def token_to_idx(self, token):
        """Return the index of ``token`` (same result as ``vocab[token]``)."""
        return self.__getitem__(token)

    def idx_to_token(self, index):
        """Return the token stored at ``index``, or ``unk_token`` when unknown."""
        return self.index_to_token.get(index, self.unk_token)

    def set_default_token(self, index):
        """Make ``index`` the index used for unknown tokens.

        When ``index`` currently belongs to another token, that token and the
        unknown token trade indices in both mappings. When ``index`` is not in
        the vocabulary, only the fallback index changes.
        """
        if index == self.unk_index:
            return

        current_token = self.index_to_token.get(index, None)

        if self.unk_token in self.token_to_index and current_token:
            self.token_to_index[self.unk_token], self.token_to_index[current_token] = index, self.unk_index

        if self.unk_index in self.index_to_token and index in self.index_to_token:
            self.index_to_token[self.unk_index], self.index_to_token[index] = self.index_to_token[index], self.index_to_token[self.unk_index]

        self.unk_index = index

    def set_default_index(self, index):
        """Alias of ``set_default_token``.

        The notebook calls the method under this name, so it is provided
        alongside the original one.
        """
        self.set_default_token(index)

    def get_stoi(self):
        """Return the token -> index dict (the live object, not a copy)."""
        return self.token_to_index

    def get_itos(self):
        """Return the index -> token dict (the live object, not a copy)."""
        return self.index_to_token

    def lookup_tokens(self, indices):
        """Convert an iterable (or tensor) of indices to a list of tokens."""
        # `torch.is_tensor` is the public predicate; `torch.istensor` does not
        # exist and raised AttributeError on every call.
        if torch.is_tensor(indices):
            indices = indices.tolist()

        return [self.idx_to_token(index) for index in indices]

    def lookup_indices(self, tokens):
        """Convert an iterable of tokens to a list of indices."""
        return [self.token_to_idx(token) for token in tokens]





In [8]:
def build_vocab_from_iterator(iterator, min_freq=1, specials=None):
    """Create a ``Vocab`` from an iterable of token sequences.

    Args:
        iterator: iterable whose items are each passed to ``Counter.update``
            (in this notebook, lists of token strings).
        min_freq: minimum count a token needs to receive an index.
        specials: optional sequence of special tokens. They take the lowest
            indices in the given order, and the first one is used as the
            unknown token. Without specials the unknown token is "<unk>".

    Returns:
        A ``Vocab`` over the specials plus every sufficiently frequent token,
        indexed in order of first appearance.
    """
    frequencies = Counter()
    for sequence in iterator:
        frequencies.update(sequence)

    # Specials go first so their indices are stable regardless of the data.
    if specials:
        mapping = {special: position for position, special in enumerate(specials)}
    else:
        mapping = {}

    for candidate, count in frequencies.items():
        if count < min_freq or candidate in mapping:
            continue
        mapping[candidate] = len(mapping)

    if specials:
        fallback = specials[0]
    else:
        fallback = "<unk>"
    # Guarantee the fallback token has an index even when no specials were given.
    mapping.setdefault(fallback, len(mapping))

    return Vocab(mapping, unk_token=fallback)


# Small smoke test of build_vocab_from_iterator on two hand-written sentences.
tokens = [["death", "destruction", "domination", "weapons", "graves"], ["all", "you", "care", "about", "is", "love"]]
min_freq = 1
specials = ["<unk>", "<pad>", "<sos>", "<eos>"]
vocab = build_vocab_from_iterator(tokens, min_freq, specials)
# The four specials take indices 0-3, so the data tokens start at index 4.
print(vocab.get_stoi())
print(vocab.idx_to_token(6))

    

{'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3, 'death': 4, 'destruction': 5, 'domination': 6, 'weapons': 7, 'graves': 8, 'all': 9, 'you': 10, 'care': 11, 'about': 12, 'is': 13, 'love': 14}
domination


In [9]:
# Tokens seen fewer than `min_freq` times in the training split get no index
# of their own and therefore resolve to the unknown token at lookup time.
min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"

# Order matters: build_vocab_from_iterator treats the first special as the
# unknown token.
special_tokens = [unk_token, pad_token, sos_token, eos_token]

en_vocab = build_vocab_from_iterator(train_data["en_tokens"], min_freq=min_freq, specials=special_tokens)
de_vocab = build_vocab_from_iterator(train_data["de_tokens"], min_freq=min_freq, specials=special_tokens)

In [10]:
# Both vocabularies must agree on the unknown and padding indices, because a
# single `unk_index` / `pad_index` is used for both languages below.
assert en_vocab[unk_token] == de_vocab[unk_token]
assert de_vocab[pad_token] == en_vocab[pad_token]

unk_index = en_vocab[unk_token]
pad_index = en_vocab[pad_token]

In [11]:
# Vocab exposes this operation as `set_default_token`; the previous
# `set_default_index` spelling is not defined on the class.
# Out-of-vocabulary tokens resolve to `unk_index` in both vocabularies.
en_vocab.set_default_token(unk_index)
de_vocab.set_default_token(unk_index)

In [12]:
def numericalize_example(example, en_vocab, de_vocab):
    """Convert the token lists of one example into lists of vocabulary indices.

    Args:
        example: mapping holding the token lists under the "en_tokens" and
            "de_tokens" keys (as produced by tokenize_example).
        en_vocab: object with a ``lookup_indices(tokens)`` method for English.
        de_vocab: object with a ``lookup_indices(tokens)`` method for German.

    Returns:
        dict with the keys "en_ids" and "de_ids", each a list of indices.
    """
    # Look up the token lists, not the raw "en"/"de" strings: iterating a
    # string would numericalize it character by character.
    en_ids = en_vocab.lookup_indices(example["en_tokens"])
    de_ids = de_vocab.lookup_indices(example["de_tokens"])
    return {"en_ids": en_ids, "de_ids": de_ids}

In [13]:
# Keyword arguments handed to numericalize_example through `fn_kwargs`.
fn_kwargs = dict(en_vocab=en_vocab, de_vocab=de_vocab)

# Add the "en_ids" / "de_ids" columns to each split.
train_data = train_data.map(numericalize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(numericalize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(numericalize_example, fn_kwargs=fn_kwargs)

In [14]:
data_type = "torch"
format_columns = ["en_ids", "de_ids"]

# One shared argument set keeps the three splits formatted identically.
format_kwargs = dict(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)

train_data = train_data.with_format(**format_kwargs)
valid_data = valid_data.with_format(**format_kwargs)
test_data = test_data.with_format(**format_kwargs)


In [15]:
def get_collate_fn(pad_index):
    """Return a collate function that pads every batch with ``pad_index``.

    The returned function takes a list of examples, each carrying 1-D tensors
    under "en_ids" and "de_ids", and returns a dict with the same two keys
    whose values are the sequences padded to a common length and stacked by
    ``pad_sequence`` (sequence dimension first).
    """

    def collate_fn(batch):
        padded = {}
        for column in ("en_ids", "de_ids"):
            sequences = [example[column] for example in batch]
            padded[column] = nn.utils.rnn.pad_sequence(sequences, padding_value=pad_index)
        return padded

    return collate_fn

In [16]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    """Wrap ``dataset`` in a DataLoader whose batches are padded with ``pad_index``.

    Args:
        dataset: dataset passed straight to ``torch.utils.data.DataLoader``.
        batch_size: number of examples per batch.
        pad_index: padding value handed to ``get_collate_fn``.
        shuffle: forwarded to the DataLoader; off by default.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=get_collate_fn(pad_index),
        shuffle=shuffle,
    )

In [17]:
batch_size = 128

# Only the training loader is shuffled.
train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
# Named `valid_data_loader` (previously misspelled `vaild_data_loader`) so it
# matches the name the training loop passes to evaluate_fn.
valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)
test_data_loader = get_data_loader(test_data, batch_size, pad_index)

In [18]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also produces the decoder's first hidden state."""

    def __init__(
        self, input_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout
    ):
        """
        Args:
            input_dim: size of the source vocabulary (number of embedding rows).
            embedding_dim: width of the token embeddings.
            encoder_hidden_dim: hidden size of each GRU direction.
            decoder_hidden_dim: size of the hidden state returned for the decoder.
            dropout: dropout probability applied to the embeddings.
        """
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, encoder_hidden_dim, bidirectional=True)
        # Both directions are concatenated, hence twice the encoder hidden size.
        self.fc = nn.Linear(2 * encoder_hidden_dim, decoder_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: index tensor of shape [src length, batch size].

        Returns:
            outputs: [src length, batch size, encoder_hidden_dim * 2]
            hidden: [batch size, decoder_hidden_dim]
        """
        token_vectors = self.embedding(src)
        embedded = self.dropout(token_vectors)
        # embedded: [src length, batch size, embedding dim]
        outputs, final_states = self.rnn(embedded)
        # final_states: [n layers * n directions, batch size, encoder_hidden_dim];
        # its last two entries are joined along the feature dimension.
        merged = torch.cat([final_states[-2], final_states[-1]], dim=1)
        hidden = torch.tanh(self.fc(merged))
        return outputs, hidden




In [1]:
class Attention(nn.Module):
    """Additive attention scoring each encoder position against the decoder state."""

    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        """
        Args:
            encoder_hidden_dim: hidden size of one encoder direction; the
                encoder outputs are expected to be twice this wide.
            decoder_hidden_dim: size of the decoder hidden state.
        """
        super().__init__()
        self.attn_fc = nn.Linear(
            (encoder_hidden_dim * 2) + decoder_hidden_dim, decoder_hidden_dim
        )
        self.v_fc = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights over the source positions.

        Args:
            hidden: decoder state, [batch size, decoder_hidden_dim].
            encoder_outputs: [src length, batch size, encoder_hidden_dim * 2].

        Returns:
            Tensor of shape [batch size, src length] whose rows sum to 1.
        """
        src_length = encoder_outputs.shape[0]
        # Repeat the decoder state once per source position.
        hidden = hidden.unsqueeze(1).repeat(1, src_length, 1)
        # Put the batch dimension first so it lines up with `hidden`.
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn_fc(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v_fc(energy).squeeze(2)
        # `softmax` was an undefined bare name; use the torch function.
        return torch.softmax(attention, dim=1)


NameError: name 'nn' is not defined

In [6]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs."""

    def __init__(
        self,
        output_dim,
        embedding_dim,
        encoder_hidden_dim,
        decoder_hidden_dim,
        attention,
        dropout
    ):
        """
        Args:
            output_dim: size of the target vocabulary.
            embedding_dim: width of the target token embeddings.
            encoder_hidden_dim: hidden size of one encoder direction.
            decoder_hidden_dim: hidden size of the decoder GRU.
            attention: module called as ``attention(hidden, encoder_outputs)``
                that returns weights of shape [batch size, src length].
            dropout: dropout probability applied to the embeddings.
        """
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        # forward() feeds the GRU the embedding concatenated with the attention
        # context, so its input size must include embedding_dim.
        self.rnn = nn.GRU((encoder_hidden_dim * 2) + embedding_dim, decoder_hidden_dim)
        # fc_out sees the GRU output, the context vector and the embedding.
        self.fc_out = nn.Linear(
            (encoder_hidden_dim * 2) + decoder_hidden_dim + embedding_dim, output_dim
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one target token for every sequence in the batch.

        Args:
            input: current target token indices, [batch size].
            hidden: previous decoder state, [batch size, decoder_hidden_dim].
            encoder_outputs: [src length, batch size, encoder_hidden_dim * 2].

        Returns:
            predictions: scores over the target vocabulary, [batch size, output_dim].
            hidden: updated decoder state, [batch size, decoder_hidden_dim].
            attention weights: [batch size, src length].
        """
        # Add a length-1 sequence dimension: [1, batch size].
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        a = self.attention(hidden, encoder_outputs)
        a = a.unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # Attention-weighted sum of encoder outputs: [batch size, 1, enc dim * 2].
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        hidden = hidden.unsqueeze(0)
        output, hidden = self.rnn(torch.cat((embedded, weighted), dim=2), hidden)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        embedded = embedded.squeeze(0)
        predictions = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
        return predictions, hidden.squeeze(0), a.squeeze(1)
        


NameError: name 'nn' is not defined

In [21]:
# again
class Seq2Seq(nn.Module):
    """Runs the encoder once, then the decoder step by step over the target."""

    def __init__(self, encoder, decoder, device):
        """
        Args:
            encoder: module returning ``(encoder_outputs, hidden)`` for a source batch.
            decoder: module with an ``output_dim`` attribute, called as
                ``decoder(input, hidden, encoder_outputs)``.
            device: device on which the output buffer is allocated.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio):
        """Produce decoder scores for every target position.

        Args:
            src: source indices, [src length, batch size].
            trg: target indices, [trg length, batch size].
            teacher_forcing_ratio: probability, drawn per step, of feeding the
                ground-truth token instead of the model's own prediction.

        Returns:
            Tensor [trg length, batch size, decoder.output_dim]. Position 0 is
            left as zeros because decoding starts from ``trg[0]``.
        """
        trg_length, batch_size = trg.shape[0], src.shape[1]
        vocab_size = self.decoder.output_dim
        outputs = torch.zeros(trg_length, batch_size, vocab_size).to(self.device)

        encoder_outputs, hidden = self.encoder(src)

        # The first decoder input is the first target row.
        decoder_input = trg[0, :]
        for step in range(1, trg_length):
            step_scores, hidden, _ = self.decoder(decoder_input, hidden, encoder_outputs)
            outputs[step] = step_scores
            use_ground_truth = random.random() < teacher_forcing_ratio
            greedy_choice = step_scores.argmax(1)
            if use_ground_truth:
                decoder_input = trg[step]
            else:
                decoder_input = greedy_choice
        return outputs
            

In [22]:
# German is the source language and English the target (train_fn feeds
# "de_ids" as src and "en_ids" as trg).
input_dim = len(de_vocab)
output_dim = len(en_vocab)
encoder_embedding_dim = 256
decoder_embedding_dim = 256
encoder_hidden_dim = 512
decoder_hidden_dim = 512
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

attention = Attention(encoder_hidden_dim, decoder_hidden_dim)

encoder = Encoder(
    input_dim,
    encoder_embedding_dim,
    encoder_hidden_dim,
    decoder_hidden_dim,
    encoder_dropout,
)

# Decoder's signature ends with (attention, dropout). Passing these two by
# keyword avoids the previous positional mix-up, which handed the dropout
# rate in as the attention module and vice versa.
decoder = Decoder(
    output_dim,
    decoder_embedding_dim,
    encoder_hidden_dim,
    decoder_hidden_dim,
    attention=attention,
    dropout=decoder_dropout,
)

model = Seq2Seq(encoder, decoder, device).to(device)


In [23]:
def init_weights(m):
    """Re-initialise every parameter of module ``m`` in place.

    Parameters whose name contains "weight" are drawn from a normal
    distribution with mean 0 and standard deviation 0.01; all others are set
    to zero.
    """
    for param_name, tensor in m.named_parameters():
        if "weight" not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)

# Run init_weights over the model; as the cell's last expression, the call's
# return value (the model itself) is displayed as the cell output.
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7853, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn_fc): Linear(in_features=1536, out_features=512, bias=True)
      (v_fc): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(5893, 256)
    (rnn): GRU(1280, 512)
    (fc_out): Linear(in_features=1792, out_features=5893, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [24]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable parameter count with thousands separators.
print(f"the model has {count_parameters(model):,} trainable parameters.")

the model has 20,518,405 trainable parameters.


In [25]:
# Adam with its default hyperparameters over all model parameters.
optimizer = optim.Adam(model.parameters())

In [26]:
# Target positions equal to pad_index are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

In [30]:
# again
# NOTE(review): "again" presumably marks a cell that was re-run - confirm or remove.
import logging

# Configure the root logger at DEBUG level so the per-batch logging.debug
# messages emitted by train_fn are shown.
logging.basicConfig(level=logging.DEBUG)

def train_fn(
    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: called as ``model(src, trg, teacher_forcing_ratio)``; returns
            scores whose last dimension is the target vocabulary size.
        data_loader: sized iterable of batches with "de_ids" (source) and
            "en_ids" (target) tensors.
        optimizer: optimizer stepped once per batch.
        criterion: loss taking (flattened scores, flattened targets).
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        teacher_forcing_ratio: forwarded to the model.
        device: device the batch tensors are moved to.

    Returns:
        The mean per-batch loss over the pass.
    """
    model.train()
    running_loss = 0
    for step, batch in enumerate(data_loader):
        logging.debug(f"Processing batch {step}")
        src = batch["de_ids"].to(device)
        trg = batch["en_ids"].to(device)

        optimizer.zero_grad()
        scores = model(src, trg, teacher_forcing_ratio)

        # Drop the first position of both scores and targets, then flatten
        # the remaining positions into one long batch for the loss.
        vocab_size = scores.shape[-1]
        flat_scores = scores[1:].view(-1, vocab_size)
        flat_targets = trg[1:].view(-1)
        loss = criterion(flat_scores, flat_targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(data_loader)


In [31]:
def evaluate_fn(model, data_loader, criterion, device):
    """Compute the mean per-batch loss of ``model`` over ``data_loader``.

    The model is put in eval mode, gradients are disabled, and the model is
    called with a teacher forcing ratio of 0.

    Args:
        model: called as ``model(src, trg, 0)``.
        data_loader: sized iterable of batches with "de_ids" (source) and
            "en_ids" (target) tensors.
        criterion: loss taking (flattened scores, flattened targets).
        device: device the batch tensors are moved to.

    Returns:
        The mean per-batch loss.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in data_loader:
            src = batch["de_ids"].to(device)
            trg = batch["en_ids"].to(device)
            scores = model(src, trg, 0)
            # Drop the first position of both tensors before flattening.
            flat_scores = scores[1:].view(-1, scores.shape[-1])
            flat_targets = trg[1:].view(-1)
            batch_losses.append(criterion(flat_scores, flat_targets).item())
    return sum(batch_losses) / len(data_loader)



In [None]:
# Training hyperparameters.
n_epochs = 10
clip = 1.0
teacher_forcing_ratio = 0.5

best_valid_loss = float("inf")

for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        model, train_data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
    )
    valid_loss = evaluate_fn(model, valid_data_loader, criterion, device)

    # Checkpoint only when the validation loss improves on the best so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "tut3.1-model.pt")

    # Perplexity is reported as exp(loss).
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")



#print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
#print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")



  0%|                                                    | 0/10 [00:00<?, ?it/s]DEBUG:root:Processing batch 0
DEBUG:root:Processing batch 1
DEBUG:root:Processing batch 2
DEBUG:root:Processing batch 3
DEBUG:root:Processing batch 4
DEBUG:root:Processing batch 5
DEBUG:root:Processing batch 6
DEBUG:root:Processing batch 7
DEBUG:root:Processing batch 8
DEBUG:root:Processing batch 9
DEBUG:root:Processing batch 10
DEBUG:root:Processing batch 11
DEBUG:root:Processing batch 12
DEBUG:root:Processing batch 13
DEBUG:root:Processing batch 14
DEBUG:root:Processing batch 15
DEBUG:root:Processing batch 16
DEBUG:root:Processing batch 17
DEBUG:root:Processing batch 18
DEBUG:root:Processing batch 19
DEBUG:root:Processing batch 20
DEBUG:root:Processing batch 21
DEBUG:root:Processing batch 22
DEBUG:root:Processing batch 23
DEBUG:root:Processing batch 24
DEBUG:root:Processing batch 25
DEBUG:root:Processing batch 26
DEBUG:root:Processing batch 27
DEBUG:root:Processing batch 28
DEBUG:root:Processing batch 29


In [None]:
# Restore the checkpoint written by the training loop. map_location=device
# remaps the saved tensors onto the device in use, so a checkpoint written on
# a GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load("tut3.1-model.pt", map_location=device))

test_loss = evaluate_fn(model, test_data_loader, criterion, device)

# print(f"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |")


In [None]:
# NOTE(review): this cell repeats the previous one - confirm whether both are needed.
# map_location=device remaps the saved tensors onto the device in use, so a
# checkpoint written on a GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load("tut3.1-model.pt", map_location=device))

test_loss = evaluate_fn(model, test_data_loader, criterion, device)