In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

import numpy as np
import random
from tqdm import tqdm

In [2]:
# Multi30k English/German parallel corpus; the output below lists its
# train / validation / test splits with 'en' and 'de' columns.
dataset = load_dataset(path="bentrevett/multi30k")
dataset

Downloading readme:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.60M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/164k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/156k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['en', 'de'],
        num_rows: 29000
    })
    validation: Dataset({
        features: ['en', 'de'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['en', 'de'],
        num_rows: 1000
    })
})

In [3]:
# Carve a validation set out of the training split: the first 28,000 pairs are
# used for training and the remaining pairs of the 29,000-row split are held out.
# NOTE(review): the corpus also ships its own 1,014-row validation split, which
# is not used here; confirm that is intentional.
train_dataset = Dataset.from_dict(dataset["train"][:28000])
validation_dataset = Dataset.from_dict(dataset["train"][28000:])
# The official test split is used as-is.
test_dataset = dataset["test"]

In [4]:
# One tokenizer (and therefore one vocabulary) serves both the English source
# and the German target.
# NOTE(review): this is an English, uncased vocabulary; it is also applied to
# German text — confirm that is intended.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)  # embedding/output size for both encoder and decoder

In [5]:
def tokenize(tokenizer, batch):
    """Tokenize a batch of English/German sentence pairs.

    The English column ("en") becomes the source and the German column ("de")
    the target. Both are encoded with ``padding=True, truncation=True``.

    Args:
        tokenizer: callable returning a mapping with "input_ids" and
            "attention_mask" for a list of strings.
        batch: mapping with "en" and "de" lists of strings.

    Returns:
        dict with keys src_input_ids, src_attention_mask, tgt_input_ids,
        tgt_attention_mask.
    """
    encoded = {
        side: tokenizer(batch[column], padding=True, truncation=True)
        for side, column in (("src", "en"), ("tgt", "de"))
    }

    features = {}
    for side, encoding in encoded.items():
        for field in ("input_ids", "attention_mask"):
            features[f"{side}_{field}"] = encoding[field]
    return features

In [6]:
# batch_size=None hands each split to `tokenize` as a single batch, so
# padding=True pads every example to the longest sequence in that split.
train_dataset_tokenized = train_dataset.map(
    lambda batch: tokenize(tokenizer, batch),
    batched=True,
    batch_size=None,
)
validation_dataset_tokenized = validation_dataset.map(
    lambda batch: tokenize(tokenizer, batch),
    batched=True,
    batch_size=None,
)
test_dataset_tokenized = test_dataset.map(
    lambda batch: tokenize(tokenizer, batch),
    batched=True,
    batch_size=None,
)

Map:   0%|          | 0/28000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [7]:
def collate_fn(batch):
    """Collate tokenized examples into (source, target) id tensors.

    Each returned tensor has shape (1, batch, seq_len): the id matrix gets an
    extra leading singleton dimension, which the training loop later removes
    with ``squeeze(0)``. All examples in ``batch`` must share the same
    sequence length (they are padded during tokenization).
    """
    source_ids = torch.tensor([example["src_input_ids"] for example in batch]).unsqueeze(0)
    target_ids = torch.tensor([example["tgt_input_ids"] for example in batch]).unsqueeze(0)
    return source_ids, target_ids

In [8]:
BATCH_SIZE = 128

# Only the training data is shuffled. The evaluation loaders keep a fixed order
# so that, when the validation loop scores only the first few batches
# (VALIDATION_STEPS), it sees the same examples every epoch and the losses are
# comparable across epochs.
train_dataloader = DataLoader(train_dataset_tokenized, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
validation_dataloader = DataLoader(validation_dataset_tokenized, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset_tokenized, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [9]:
class Encoder(nn.Module):
    """Single-layer bidirectional LSTM encoder.

    The LSTM is built without ``batch_first``, so ``x`` is expected as
    (seq_len, batch) token ids. The final forward and backward states are
    concatenated and projected back to ``hidden_size`` (followed by ReLU) so
    they can seed a unidirectional decoder.

    Returns:
        encoder_states: (seq_len, batch, 2 * hidden_size) per-step LSTM outputs.
        hidden, cell: (batch, hidden_size) projected final states.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        both_directions = 2 * hidden_size
        self.fc_hidden = nn.Linear(both_directions, hidden_size)
        self.fc_cell = nn.Linear(both_directions, hidden_size)

    def forward(self, x):
        embedded = self.embedding(x)
        encoder_states, (h_n, c_n) = self.rnn(embedded)

        # With one layer, index 0 is the forward direction and 1 the backward one.
        h_both = torch.cat((h_n[0], h_n[1]), dim=1)
        c_both = torch.cat((c_n[0], c_n[1]), dim=1)

        hidden = torch.relu(self.fc_hidden(h_both))
        cell = torch.relu(self.fc_cell(c_both))
        return encoder_states, hidden, cell

In [10]:
class Attention(nn.Module):
    """Dot-product attention over linearly projected encoder states.

    ``fc`` maps encoder states from ``2 * hidden_size`` down to ``hidden_size``;
    the decoder hidden state is then scored against every projected state.
    The context vector is a weighted sum of the *projected* states.

    Shapes (as used by this notebook's Decoder/Seq2Seq — assumed, not checked):
        encoder_states: (src_len, batch, 2 * hidden_size)
        hidden: (1, batch, hidden_size)

    Returns:
        context_vector: (batch, 1, hidden_size)
        attention_weights: (batch, src_len, 1), softmax-normalised over src_len.
    """

    def __init__(self, hidden_size):
        super().__init__()

        self.fc = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, encoder_states, hidden):
        keys = self.fc(encoder_states).transpose(0, 1)   # (batch, src_len, H)
        query = hidden.permute(1, 2, 0)                  # (batch, H, 1)

        scores = keys.bmm(query)                         # (batch, src_len, 1)
        weights = scores.softmax(dim=1)                  # normalise over source positions

        context = weights.transpose(1, 2).bmm(keys)      # (batch, 1, H)
        return context, weights

In [11]:
class Decoder(nn.Module):
    """One-step LSTM decoder with attention.

    At each call the previous token is embedded and concatenated with an
    attention context over the encoder states. This is fed through a
    unidirectional LSTM and projected to ``output_size`` logits.

    Shapes (as driven by Seq2Seq in this notebook — assumed):
        x: (1, batch) token ids
        hidden, cell: (1, batch, hidden_size)

    Returns:
        outputs: (1, batch, output_size) logits, plus the new hidden and cell.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = Attention(hidden_size)
        rnn_input_size = embedding_dim + hidden_size
        self.rnn = nn.LSTM(rnn_input_size, hidden_size, num_layers=1, bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, encoder_states, hidden, cell):
        embedded = self.embedding(x)
        context, _weights = self.attention(encoder_states, hidden)
        # Attention returns (batch, 1, H); the LSTM wants (1, batch, H).
        step_input = torch.cat((embedded, context.transpose(0, 1)), dim=2)

        rnn_out, (next_hidden, next_cell) = self.rnn(step_input, (hidden, cell))
        return self.fc(rnn_out), next_hidden, next_cell

In [12]:
class Seq2Seq(nn.Module):
    """Attention encoder-decoder that decodes the target one token at a time.

    Args:
        vocab_size: size of the token vocabulary shared by encoder and decoder
            embeddings.
        embedding_dim: embedding width for both encoder and decoder.
        hidden_size: LSTM hidden width.
        output_size: number of logits the decoder produces per step. It is also
            the last dimension of the returned tensor.
        device: recorded for backward compatibility. The output buffer is now
            allocated on the device of the input ``x``, so it stays correct if
            the model is moved with ``.to()`` after construction.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, device):
        super().__init__()

        self.encoder = Encoder(vocab_size, embedding_dim, hidden_size)
        self.decoder = Decoder(vocab_size, embedding_dim, hidden_size, output_size)

        self.vocab_size = vocab_size
        self.output_size = output_size
        self.device = device

    def forward(self, x, labels, teacher_p):
        """Run the encoder once, then decode ``labels.shape[0] - 1`` steps.

        Args:
            x: (src_len, batch) source token ids.
            labels: (tgt_len, batch) target token ids. ``labels[0]`` is the
                first decoder input.
            teacher_p: per-step probability of feeding the ground-truth token
                instead of the model's own argmax prediction.

        Returns:
            (tgt_len, batch, output_size) logits. Row 0 is never written and
            stays zero, because position 0 is an input, not a prediction.
        """
        batch_size = x.shape[1]
        seq_length = labels.shape[0]

        encoder_states, hidden, cell = self.encoder(x)
        # The decoder LSTM expects (num_layers=1, batch, hidden).
        hidden = hidden.unsqueeze(0)
        cell = cell.unsqueeze(0)

        # Size by the decoder's output width (not vocab_size) and allocate
        # directly on the input's device instead of CPU + copy.
        outputs = torch.zeros((seq_length, batch_size, self.output_size), device=x.device)

        prev_token = labels[0].unsqueeze(0)
        for t in range(1, seq_length):
            preds, hidden, cell = self.decoder(prev_token, encoder_states, hidden, cell)
            outputs[t] = preds.squeeze(0)

            use_teacher = random.random() < teacher_p
            prev_token = labels[t].unsqueeze(0) if use_teacher else preds.argmax(2)

        return outputs


In [13]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device {device}.")

Using device cpu.


In [14]:
model = Seq2Seq(
    vocab_size=vocab_size,
    embedding_dim=256,
    hidden_size=512,
    output_size=vocab_size,  # logits over the same vocabulary the decoder embeds
    device=device,
).to(device)

In [15]:
class Trainer:
    """Train/validate loop for a Seq2Seq-style model that records mean epoch losses.

    Batches are ``(inputs, labels)`` pairs whose tensors carry a leading
    singleton dimension (``collate_fn`` wraps the (batch, seq_len) id matrix in
    ``torch.stack([...])``). That dimension is squeezed off, and the result is
    transposed to the (seq_len, batch) layout the model consumes.

    Args:
        model: called as ``model(inputs, labels, teacher_p)``. It must return
            logits of shape (seq_len, batch, output_size).
        device: device every batch is moved to.
        epochs: number of epochs run by :meth:`train`.
        optimizer: optimizer stepped once per training batch.
        criterion: loss over flattened logits (N, output_size) and targets (N,).
        train_dataloader: iterable of training batches.
        train_steps: stop a training epoch after this many batches
            (``None`` = full pass).
        validation_dataloader: iterable of validation batches.
        validation_steps: same as ``train_steps``, for validation.
        clip: max total gradient norm passed to ``clip_grad_norm_``.
        train_teacher_p: teacher-forcing probability used while training.
        val_teacher_p: teacher-forcing probability used while validating.
            Defaults to 0.0, so validation scores free-running decoding and does
            not depend on a random coin flip.

    Attributes:
        loss: ``{"train": [...], "val": [...]}`` with the mean loss per epoch
            (``nan`` if an epoch saw no batches).
    """

    def __init__(self, model, device, epochs, optimizer, criterion, train_dataloader, train_steps,
                 validation_dataloader, validation_steps, clip, train_teacher_p=0.5, val_teacher_p=0.0):
        self.model = model
        self.device = device
        self.epochs = epochs
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_dataloader = train_dataloader
        self.train_steps = train_steps
        self.validation_dataloader = validation_dataloader
        self.validation_steps = validation_steps
        self.clip = clip
        self.train_teacher_p = train_teacher_p
        self.val_teacher_p = val_teacher_p

        self.loss = {"train": [], "val": []}

    def train(self):
        """Alternate one training and one validation epoch, printing both losses."""
        for epoch in range(self.epochs):
            self.train_epoch()
            self.validate_epoch()
            print(f"Epoch: {epoch + 1} train loss: {self.loss['train'][-1]} validation loss: {self.loss['val'][-1]}")

    def _batch_loss(self, batch_data, teacher_p):
        """Forward one batch and return the scalar loss tensor."""
        inputs = batch_data[0].squeeze(0).permute(1, 0).to(self.device)
        labels = batch_data[1].squeeze(0).permute(1, 0).to(self.device)

        outputs = self.model(inputs, labels, teacher_p)
        output_size = outputs.shape[2]
        # Position 0 is the decoder's first input, not a prediction; drop it.
        outputs = outputs[1:].reshape(-1, output_size)
        labels = labels[1:].reshape(-1)

        return self.criterion(outputs, labels)

    @staticmethod
    def _mean(losses):
        """Mean of the collected batch losses, or nan (without a warning) if empty."""
        return np.mean(losses) if losses else float("nan")

    def train_epoch(self):
        """Run one training epoch (at most ``train_steps`` batches)."""
        self.model.train()
        running_loss = []

        for i, batch_data in enumerate(tqdm(self.train_dataloader, desc="train"), 1):
            self.optimizer.zero_grad()

            loss = self._batch_loss(batch_data, self.train_teacher_p)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
            # Step the optimizer this Trainer was given, not a notebook global.
            self.optimizer.step()

            running_loss.append(loss.item())

            if i == self.train_steps:
                break

        self.loss["train"].append(self._mean(running_loss))

    def validate_epoch(self):
        """Run one validation epoch (at most ``validation_steps`` batches) without gradients."""
        self.model.eval()
        running_loss = []

        with torch.no_grad():
            for i, batch_data in enumerate(tqdm(self.validation_dataloader, desc="validation"), 1):
                loss = self._batch_loss(batch_data, self.val_teacher_p)
                running_loss.append(loss.item())

                if i == self.validation_steps:
                    break

        self.loss["val"].append(self._mean(running_loss))
        

In [16]:
EPOCHS = 10
TRAIN_STEPS = 1        # batches per training epoch; None = full pass over train_dataloader
VALIDATION_STEPS = 1   # batches per validation epoch; None = full pass over validation_dataloader
CLIP = 1               # max total gradient norm

optimizer = optim.Adam(model.parameters())
# Targets are padded to a common length (padding=True during tokenization).
# Without ignore_index, every [PAD] position would be a training target; the
# padding tokens are excluded from the loss instead.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

trainer = Trainer(model, device, EPOCHS, optimizer, criterion, train_dataloader, TRAIN_STEPS, validation_dataloader, VALIDATION_STEPS, CLIP)

In [17]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator before training.
# NOTE(review): this run selected the CPU device, so this is presumably a no-op here.
torch.cuda.empty_cache()

In [18]:
def init_weights(m, std=0.01):
    """Initialise the parameters owned directly by module ``m``.

    Parameters whose name contains "weight" are drawn from N(0, std**2); all
    others (biases) are set to zero. Only ``m``'s own parameters are touched
    (``recurse=False``). ``Module.apply`` already calls this on every
    submodule, so recursing here would re-initialise each parameter once per
    ancestor module.

    Args:
        m: module passed in by ``Module.apply``.
        std: standard deviation for weight initialisation.
    """
    for name, param in m.named_parameters(recurse=False):
        if "weight" in name:
            nn.init.normal_(param, mean=0.0, std=std)
        else:
            nn.init.zeros_(param)
            
# Module.apply calls init_weights on every submodule as well as the model itself,
# then returns the model, so the cell displays the module structure below.
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(30522, 256)
    (rnn): LSTM(256, 512, bidirectional=True)
    (fc_hidden): Linear(in_features=1024, out_features=512, bias=True)
    (fc_cell): Linear(in_features=1024, out_features=512, bias=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(30522, 256)
    (attention): Attention(
      (fc): Linear(in_features=1024, out_features=512, bias=True)
    )
    (rnn): LSTM(768, 512)
    (fc): Linear(in_features=512, out_features=30522, bias=True)
  )
)

In [None]:
# Runs EPOCHS rounds of train_epoch() + validate_epoch(), printing the mean losses after each.
trainer.train()

4it [04:05, 62.05s/it]

In [None]:
# Save only the learned parameters (the state_dict), not the pickled module object.
torch.save(model.state_dict(), 'model_params.pth')