In [1]:
%%capture
!pip install datasets
!pip install tokenizers
!pip install mlflow

In [2]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm import tqdm
import math
import mlflow
import mlflow.pytorch

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [3]:
config = {
    'datasource': 'iwslt2017',
    'input_lang': 'en',
    'output_lang': 'fr',
    'tokenizer_file': 'tokenizer_{0}.json',
    'seq_len': 150,
    'batch_size': 64,
    'embed_size': 512,
    'dropout': 0.1,
    'num_heads': 8,
    'hidden_size': 2048,
    'num_layers': 6,
    'num_epochs': 10,
    'learning_rate': 3e-4,
    'model_folder': 'model',
    'model_filename': '',
    'preload': 'latest',
    'model_basename': 'lang2lang',
}

# Import and configure dataset

In [4]:
dataset = load_dataset(config['datasource'], config['datasource']+'-'+config['input_lang']+'-'+config['output_lang'])

Downloading builder script:   0%|          | 0.00/8.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/18.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/232825 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/8597 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/890 [00:00<?, ? examples/s]

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 232825
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 8597
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 890
    })
})

In [6]:
dataset['train']

Dataset({
    features: ['translation'],
    num_rows: 232825
})

In [7]:
dataset['train'][0]

{'translation': {'en': "Thank you so much, Chris. And it's truly a great honor to have the opportunity to come to this stage twice; I'm extremely grateful.",
  'fr': "Merci beaucoup, Chris. C'est vraiment un honneur de pouvoir venir sur cette scène une deuxième fois. Je suis très reconnaissant."}}

In [8]:
def yield_sample(dataset, lang):
    for sample in dataset:
        yield sample['translation'][lang]

def build_tokenizer(config, dataset, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if Path.exists(tokenizer_path):
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    else:
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(yield_sample(dataset, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    return tokenizer

In [9]:
tokenizer = build_tokenizer(config, dataset['train'], config['input_lang'])

In [10]:
ids = tokenizer.encode("Hello world").ids
ids

[3574, 94]

In [11]:
tokenizer.decode(ids)

'Hello world'

In [12]:
def find_max_min_seq(dataset):
  max_input_length = 0
  max_output_length = 0
  for sample in dataset:
      in_ids = tokenizer.encode(sample['translation'][config['input_lang']]).ids
      out_ids = tokenizer.encode(sample['translation'][config['output_lang']]).ids
      max_input_length = max(max_input_length, len(in_ids))
      max_output_length = max(max_output_length, len(out_ids))

  print("Max input sequence length: ", max_input_length)
  print("Max output sequence length: ", max_output_length)

find_max_min_seq(dataset['train'])
find_max_min_seq(dataset['validation'])

Max input sequence length:  133
Max output sequence length:  122
Max input sequence length:  90
Max output sequence length:  97


In [13]:
class TranslationDataset(Dataset):
    def __init__(self, raw_dataset, input_lang, output_lang, input_tokenizer, output_tokenizer, seq_len):
        self.raw_dataset = raw_dataset
        self.input_lang = input_lang
        self.output_lang = output_lang
        self.input_tokenizer = input_tokenizer
        self.output_tokenizer = output_tokenizer
        self.seq_len = seq_len

        self.sos_token = torch.tensor([input_tokenizer.token_to_id("[SOS]")], dtype=torch.long)
        self.eos_token = torch.tensor([input_tokenizer.token_to_id("[EOS]")], dtype=torch.long)
        self.pad_token = torch.tensor([input_tokenizer.token_to_id("[PAD]")], dtype=torch.long)

    def __getitem__(self, idx):
        sample = self.raw_dataset[idx]
        input_ids = self.input_tokenizer.encode(sample['translation'][self.input_lang]).ids
        output_ids = self.output_tokenizer.encode(sample['translation'][self.output_lang]).ids

        if len(input_ids) > self.seq_len:
            input_ids = input_ids[:self.seq_len]

        num_of_padding_tokens_input = self.seq_len - len(input_ids) - 2 # -2 for sos and eos
        num_of_padding_tokens_output = self.seq_len - len(output_ids) - 1 # -1 for sos

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(input_ids, dtype=torch.long),
                self.eos_token,
                torch.tensor([self.pad_token] * num_of_padding_tokens_input, dtype=torch.long),
            ],
            dim=0,
        )

        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(output_ids, dtype=torch.long),
                torch.tensor([self.pad_token] * num_of_padding_tokens_output, dtype=torch.long),
            ],
            dim=0,
        )

        label = torch.cat(
            [
                torch.tensor(output_ids, dtype=torch.long),
                self.eos_token,
                torch.tensor([self.pad_token] * num_of_padding_tokens_output, dtype=torch.long),
            ],
            dim=0,
        )

        padding_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        lookahead_mask = (decoder_input != self.pad_token).unsqueeze(0).int() & self.causal_mask(decoder_input.size(0))

        return encoder_input, decoder_input, label, padding_mask, lookahead_mask


    def __len__(self):
        return len(self.raw_dataset)

    def causal_mask(self, size):
      mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
      return mask == 0

In [14]:
data = TranslationDataset(dataset['train'], config['input_lang'], config['output_lang'], tokenizer, tokenizer, config['seq_len'])

In [15]:
encoder_input, decoder_input, label, padding_mask, lookahead_mask = data[0]

In [16]:
(decoder_input != data.pad_token).unsqueeze(0).int()

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int32)

In [17]:
(torch.tril(torch.ones((1, 150, 150)), diagonal=0) == 1).int()

tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]]], dtype=torch.int32)

In [18]:
(decoder_input != data.pad_token).unsqueeze(0).int() & (torch.tril(torch.ones((1, 150, 150)), diagonal=0) == 1).int()

tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]], dtype=torch.int32)

In [19]:
decoder_input.shape

torch.Size([150])

In [20]:
lookahead_mask.shape

torch.Size([1, 150, 150])

In [21]:
encoder_input.shape

torch.Size([150])

In [22]:
padding_mask.shape

torch.Size([1, 1, 150])

In [23]:
def build_dataloader_and_tokenizers(config):
    raw_dataset = load_dataset(config['datasource'], config['datasource']+'-'+config['input_lang']+'-'+config['output_lang'])

    input_tokenizer = build_tokenizer(config, raw_dataset['train'], config['input_lang'])
    output_tokenizer = build_tokenizer(config, raw_dataset['train'], config['output_lang'])

    train = TranslationDataset(raw_dataset['train'], config['input_lang'], config['output_lang'], input_tokenizer, output_tokenizer, config['seq_len'])
    validation = TranslationDataset(raw_dataset['validation'], config['input_lang'], config['output_lang'], input_tokenizer, output_tokenizer, config['seq_len'])

    train_dataloader = DataLoader(train, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(validation, batch_size=config['batch_size'], shuffle=True)

    return train_dataloader, val_dataloader, input_tokenizer, output_tokenizer

In [24]:
train_dataloader, val_dataloader, input_tokenizer, output_tokenizer = build_dataloader_and_tokenizers(config)

In [25]:
encoder_input, decoder_input, label, padding_mask, lookahead_mask = train_dataloader.__iter__().__next__()

In [26]:
encoder_input.shape # (batch_size, seq_len)

torch.Size([64, 150])

In [27]:
decoder_input.shape # (batch_size, seq_len)

torch.Size([64, 150])

In [28]:
label.shape

torch.Size([64, 150])

In [29]:
padding_mask.shape

torch.Size([64, 1, 1, 150])

In [30]:
lookahead_mask.shape

torch.Size([64, 1, 150, 150])

# Build model

In [31]:
train_dataloader, val_dataloader, input_tokenizer, output_tokenizer = build_dataloader_and_tokenizers(config)
encoder_input, decoder_input, label, padding_mask, lookahead_mask = train_dataloader.__iter__().__next__()

In [32]:
input_vocab_size = input_tokenizer.get_vocab_size()
output_vocab_size = output_tokenizer.get_vocab_size()

In [33]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, embed_size)
        return self.embedding(x) * math.sqrt(self.embed_size)

In [34]:
embeds = InputEmbeddings(input_vocab_size, config['embed_size'])
x = embeds(encoder_input)

In [35]:
x.shape

torch.Size([64, 150, 512])

In [36]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, seq_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(seq_len).unsqueeze(1) # (seq_len, 1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0) / embed_size)) # (embed_size / 2)
        pe = torch.zeros(1, seq_len, embed_size) # (1, seq_len, embed_size)

        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # (batch, seq_len, embed_size)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [37]:
pos_embeds = PositionalEncoding(config['embed_size'], config['seq_len'], config['dropout'])
x = pos_embeds(x)

In [38]:
x.shape # (batch_size, seq_len, embed_size)

torch.Size([64, 150, 512])

In [39]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, dropout):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        self.Wq = nn.Linear(embed_size, embed_size)
        self.Wk = nn.Linear(embed_size, embed_size)
        self.Wv = nn.Linear(embed_size, embed_size)
        self.Wo = nn.Linear(embed_size, embed_size)

        self.head_size = embed_size // num_heads

    def attention(self, q, k, v, mask):
        # q @ k then scale
        # (batch, num_heads, seq_len, head_size) --> (batch, num_heads, seq_len, seq_len)
        attention = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)

        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)

        attention = F.softmax(attention, dim=-1)

        # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_size) --> (batch, num_heads, seq_len, head_size)
        return self.dropout(attention @ v)

    def forward(self, q, k, v, mask):
        # (batch, seq_len, embed_size) --> (batch, seq_len, embed_size)
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        # (batch, seq_len, embed_size) --> (batch, seq_len, num_heads, head_size) --> (batch, num_heads, seq_len, head_size)
        q = q.reshape(q.shape[0], q.shape[1], self.num_heads, self.head_size).transpose(1, 2)
        k = k.reshape(k.shape[0], k.shape[1], self.num_heads, self.head_size).transpose(1, 2)
        v = v.reshape(v.shape[0], v.shape[1], self.num_heads, self.head_size).transpose(1, 2)

        # (batch, num_heads, seq_len, head_size)
        x = self.attention(q, k, v, mask)

        # (batch, num_heads, seq_len, head_size) --> (batch, seq_len, embed_size)
        x = x.transpose(1, 2).contiguous().reshape(x.shape[0], x.shape[2], self.head_size*self.num_heads)

        return self.Wo(x)

In [40]:
multi_head_attention = MultiHeadAttention(config['embed_size'], config['num_heads'], config['dropout'])
x = multi_head_attention(x, x, x, padding_mask)

In [41]:
x.shape

torch.Size([64, 150, 512])

In [42]:
layer_norm = nn.LayerNorm(config['embed_size'])
x = layer_norm(x)

In [43]:
x.shape

torch.Size([64, 150, 512])

In [44]:
class FeedForwardBlock(nn.Module):
    def __init__(self, embed_size, hidden_size, dropout):
        super().__init__()
        self.lin_1 = nn.Linear(embed_size, hidden_size)
        self.lin_2 = nn.Linear(hidden_size, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.lin_2(self.dropout(torch.relu(self.lin_1(x))))

In [45]:
ffn = FeedForwardBlock(config['embed_size'], config['hidden_size'], config['dropout'])
x = ffn(x)

In [46]:
x.shape

torch.Size([64, 150, 512])

In [47]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, hidden_size, dropout):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(embed_size, num_heads, dropout)
        self.feed_forward = FeedForwardBlock(embed_size, hidden_size, dropout)
        self.layer_norm = nn.ModuleList([nn.LayerNorm(embed_size) for _ in range(2)])

    def forward(self, x, mask):
        x = self.layer_norm[0](x + self.multi_head_attention(x, x, x, mask))
        x = self.layer_norm[1](x + self.feed_forward(x))
        return x

In [48]:
encoder_layer = EncoderLayer(config['embed_size'], config['num_heads'], config['hidden_size'], config['dropout'])
x = encoder_layer(x, padding_mask)

In [49]:
x.shape

torch.Size([64, 150, 512])

In [50]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, hidden_size, dropout):
        super().__init__()
        self.masked_multi_head_attention = MultiHeadAttention(embed_size, num_heads, dropout)
        self.multi_head_attention = MultiHeadAttention(embed_size, num_heads, dropout)
        self.feed_forward = FeedForwardBlock(embed_size, hidden_size, dropout)
        self.layer_norm = nn.ModuleList([nn.LayerNorm(embed_size) for _ in range(3)])

    def forward(self, x, encoder_output, padding_mask, lookahead_mask):
        x = self.layer_norm[0](x + self.masked_multi_head_attention(x, x, x, lookahead_mask))
        x = self.layer_norm[1](x + self.multi_head_attention(x, encoder_output, encoder_output, padding_mask))
        x = self.layer_norm[2](x + self.feed_forward(x))
        return x


In [51]:
decoder_layer = DecoderLayer(config['embed_size'], config['num_heads'], config['hidden_size'], config['dropout'])
x = decoder_layer(x, x, padding_mask, lookahead_mask)

In [52]:
x.shape

torch.Size([64, 150, 512])

In [53]:
class Transformer(nn.Module):
    def __init__(self, input_vocab_size, output_vocab_size, embed_size, seq_len, num_heads, hidden_size, dropout, num_layers):
        super().__init__()
        self.input_embedding = nn.Embedding(input_vocab_size, embed_size)
        self.output_embedding = nn.Embedding(output_vocab_size, embed_size)
        self.input_positional_encoding = PositionalEncoding(embed_size, seq_len, dropout)
        self.output_positional_encoding = PositionalEncoding(embed_size, seq_len, dropout)
        self.encoder = nn.ModuleList([EncoderLayer(embed_size, num_heads, hidden_size, dropout) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(embed_size, num_heads, hidden_size, dropout) for _ in range(num_layers)])
        self.linear = nn.Linear(embed_size, output_vocab_size)

    def encode(self, x, padding_mask):
        x = self.input_embedding(x)
        x = self.input_positional_encoding(x)

        for encoder_layer in self.encoder:
            x = encoder_layer(x, padding_mask)

        return x

    def decode(self, x, encoder_output, padding_mask, lookahead_mask):
        x = self.output_embedding(x)
        x = self.output_positional_encoding(x)
        for decoder_layer in self.decoder:
            x = decoder_layer(x, encoder_output, padding_mask, lookahead_mask)

        return x

    def predict(self, x):
        return self.linear(x)

In [54]:
#train_dataloader, val_dataloader, input_tokenizer, output_tokenizer = build_dataloader_and_tokenizers(config)
#encoder_input, decoder_input, label, padding_mask, lookahead_mask = train_dataloader.__iter__().__next__()

In [55]:
transformer = Transformer(input_tokenizer.get_vocab_size(), output_tokenizer.get_vocab_size(), config['embed_size'], config['seq_len'], config['num_heads'], config['hidden_size'], config['dropout'], config['num_layers'])

In [56]:
encoder_output = transformer.encode(encoder_input, padding_mask)

In [57]:
encoder_output.shape

torch.Size([64, 150, 512])

In [58]:
decoder_output = transformer.decode(decoder_input, encoder_output, padding_mask, lookahead_mask)

In [59]:
decoder_output.shape

torch.Size([64, 150, 512])

In [60]:
out = transformer.predict(decoder_output[:,-1])

In [61]:
out.shape

torch.Size([64, 30000])

In [62]:
_, next_word = torch.max(out, dim=1)

In [63]:
next_word.shape

torch.Size([64])

In [64]:
next_word

tensor([ 7835,  2630,  7835,  7835,  7835,  7835,  7835,  7835,  7835,  7835,
         7835,  5858,  5858,  7835,  2630,  7835,  7835,  7835,  7835,  7835,
         5858,  7835,  7835,  7835, 29732,  7835,  6035,  7835, 25608,  7835,
         6035,  7835,  7835,  7835,  7835, 12020,  9819,  7835,  7835,  6035,
         2630,  7835,  7835,  7835, 20777,  6035,  6035,  7835,  7835,  7835,
         7835,  7835, 19831,  7835,  3638,  7835, 29732,  2630, 23136,  5858,
         7835,  5793,  7835,  7835])

In [65]:
output_tokenizer.decode(next_word.tolist())

'instable furent instable instable instable instable instable instable instable instable instable moustiquaires moustiquaires instable furent instable instable instable instable instable moustiquaires instable instable instable inversés instable franchir instable désarroi instable franchir instable instable instable instable Kaboul sévères instable instable franchir furent instable instable instable braves franchir franchir instable instable instable instable instable ping instable Ghana instable inversés furent gaspillons moustiquaires instable stockage instable instable'

# Training

In [66]:
# code from https://github.com/hkproj/pytorch-transformer/blob/main/config.py
def get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])

In [67]:
def train(config):
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
    print("Using device: ", device)

    mlflow.start_run()

    train_dataloader, val_dataloader, input_tokenizer, output_tokenizer = build_dataloader_and_tokenizers(config)

    model = Transformer(input_tokenizer.get_vocab_size(), output_tokenizer.get_vocab_size(), config['embed_size'], config['seq_len'], config['num_heads'], config['hidden_size'], config['dropout'], config['num_layers'])
    model.to(device)

    # apply xavier initialization
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
    loss_fn = nn.CrossEntropyLoss(ignore_index=input_tokenizer.token_to_id('[PAD]')).to(device)

    initial_epoch = 0
    global_step = 0
    best_epoch = 0
    prev_val_loss = float('inf')
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    print(model_filename)
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
        best_epoch = state['best_epoch']
        prev_val_loss = state['prev_val_loss']
    else:
        print('No model to preload, starting from scratch')

    mlflow.log_param("learning_rate", config['learning_rate'])
    mlflow.log_param("batch_size", config['batch_size'])
    mlflow.log_param("epochs", config['num_epochs'])
    mlflow.log_param("embed_size", config['embed_size'])
    mlflow.log_param("hidden_size", config['hidden_size'])
    mlflow.log_param("num_heads", config['num_heads'])
    mlflow.log_param("num_layers", config['num_layers'])
    mlflow.log_param("dropout", config['dropout'])
    mlflow.log_param("seq_len", config['seq_len'])
    mlflow.log_param("datasource", config['datasource'])

    mlflow.pytorch.log_model(model, "models")

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        train_loss = 0
        idx = 0
        best_epoch = 0
        for encoder_input, decoder_input, label, padding_mask, lookahead_mask in tqdm(train_dataloader, desc=f"Processing Epoch {epoch} for training"):
            encoder_input = encoder_input.to(device) # (batch_size, seq_len)
            decoder_input = decoder_input.to(device) # (batch_size, seq_len)
            label = label.to(device) # (batch_size, seq_len)
            padding_mask = padding_mask.to(device) # (batch_size, 1, 1, seq_len)
            lookahead_mask = lookahead_mask.to(device) # (batch_size, 1, seq_len, seq_len)

            encoder_output = model.encode(encoder_input, padding_mask)
            decoder_output = model.decode(decoder_input, encoder_output, padding_mask, lookahead_mask)
            output = model.predict(decoder_output)

            loss = loss_fn(output.reshape(-1, output.shape[-1]), label.reshape(-1))
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

            idx += 1
            if idx % 100 == 0:
              print("Train loss: ", train_loss / idx)
              mlflow.log_metric("Train loss", train_loss / idx)


        print(f"Epoch Training Loss { train_loss / idx}")
        mlflow.log_metric("Epoch Training Loss", train_loss / idx)

        # validation
        model.eval()
        with torch.no_grad():
            val_loss = 0
            idx = 0
            for encoder_input, decoder_input, label, padding_mask, lookahead_mask in tqdm(val_dataloader, desc=f"Processing Epoch {epoch} for validation"):
                encoder_input = encoder_input.to(device) # (batch_size, seq_len)
                decoder_input = decoder_input.to(device) # (batch_size, seq_len)
                label = label.to(device) # (batch_size, seq_len)
                padding_mask = padding_mask.to(device) # (batch_size, 1, 1, seq_len)
                lookahead_mask = lookahead_mask.to(device) # (batch_size, 1, seq_len, seq_len)

                encoder_output = model.encode(encoder_input, padding_mask)
                decoder_output = model.decode(decoder_input, encoder_output, padding_mask, lookahead_mask)
                output = model.predict(decoder_output)

                loss = loss_fn(output.reshape(-1, output.shape[-1]), label.reshape(-1))
                val_loss += loss.item()

                idx += 1
                if idx % 100 == 0:
                  print("Validation loss: ", val_loss / idx)
                  mlflow.log_metric("Validation loss", val_loss / idx)

        print(f"Epoch Validation Loss { val_loss / idx}")
        mlflow.log_metric("Epoch Validation Loss", val_loss / idx)

        if val_loss / len(val_dataloader) < prev_val_loss:
            best_epoch = epoch
            prev_val_loss = val_loss / len(val_dataloader)

        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'best_epoch': best_epoch,
            'prev_val_loss': prev_val_loss
        }, model_filename)
        mlflow.end_run()


In [68]:
train(config)

Using device:  cuda
None
No model to preload, starting from scratch


Processing Epoch 0 for training:   3%|▎         | 100/3638 [01:56<1:07:38,  1.15s/it]

Train loss:  6.923396215438843


Processing Epoch 0 for training:   5%|▌         | 200/3638 [03:52<1:06:24,  1.16s/it]

Train loss:  6.672008047103882


Processing Epoch 0 for training:   8%|▊         | 300/3638 [05:47<1:04:33,  1.16s/it]

Train loss:  6.586062849362691


Processing Epoch 0 for training:  11%|█         | 400/3638 [07:43<1:02:23,  1.16s/it]

Train loss:  6.5459400594234465


Processing Epoch 0 for training:  14%|█▎        | 500/3638 [09:38<1:00:42,  1.16s/it]

Train loss:  6.521741988182068


Processing Epoch 0 for training:  16%|█▋        | 600/3638 [11:34<58:30,  1.16s/it]

Train loss:  6.502556925614675


Processing Epoch 0 for training:  19%|█▉        | 700/3638 [13:29<56:41,  1.16s/it]

Train loss:  6.4905577639171055


Processing Epoch 0 for training:  22%|██▏       | 800/3638 [15:25<54:31,  1.15s/it]

Train loss:  6.481595509648323


Processing Epoch 0 for training:  25%|██▍       | 900/3638 [17:20<52:52,  1.16s/it]

Train loss:  6.472768846617805


Processing Epoch 0 for training:  27%|██▋       | 1000/3638 [19:16<50:47,  1.16s/it]

Train loss:  6.465804897785187


Processing Epoch 0 for training:  30%|███       | 1100/3638 [21:11<48:56,  1.16s/it]

Train loss:  6.459820019982078


Processing Epoch 0 for training:  33%|███▎      | 1200/3638 [23:06<46:52,  1.15s/it]

Train loss:  6.455407412052154


Processing Epoch 0 for training:  36%|███▌      | 1300/3638 [25:02<44:59,  1.15s/it]

Train loss:  6.452039743203383


Processing Epoch 0 for training:  38%|███▊      | 1374/3638 [26:27<43:36,  1.16s/it]


KeyboardInterrupt: ignored