In [None]:
pip install -U datasets huggingface_hub fsspec

In [2]:
from datasets import get_dataset_config_names
# Display the configuration names (language pairs such as 'en-it') available for 'opus_books'.
get_dataset_config_names('opus_books')

README.md: 0.00B [00:00, ?B/s]

['ca-de',
 'ca-en',
 'ca-hu',
 'ca-nl',
 'de-en',
 'de-eo',
 'de-es',
 'de-fr',
 'de-hu',
 'de-it',
 'de-nl',
 'de-pt',
 'de-ru',
 'el-en',
 'el-es',
 'el-fr',
 'el-hu',
 'en-eo',
 'en-es',
 'en-fi',
 'en-fr',
 'en-hu',
 'en-it',
 'en-nl',
 'en-no',
 'en-pl',
 'en-pt',
 'en-ru',
 'en-sv',
 'eo-es',
 'eo-fr',
 'eo-hu',
 'eo-it',
 'eo-pt',
 'es-fi',
 'es-fr',
 'es-hu',
 'es-it',
 'es-nl',
 'es-no',
 'es-pt',
 'es-ru',
 'fi-fr',
 'fi-hu',
 'fi-no',
 'fi-pl',
 'fr-hu',
 'fr-it',
 'fr-nl',
 'fr-no',
 'fr-pl',
 'fr-pt',
 'fr-ru',
 'fr-sv',
 'hu-it',
 'hu-nl',
 'hu-no',
 'hu-pl',
 'hu-pt',
 'hu-ru',
 'it-nl',
 'it-pt',
 'it-ru',
 'it-sv']

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import math
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from tqdm import tqdm
import warnings
class InputEmbeddings(nn.Module):
  """Lookup table turning token ids into vectors, scaled by sqrt(embedding_dim)."""

  def __init__(self, embedding_dim, vocab_size):
    """
    Args:
      embedding_dim: width of every embedding vector (the model's d_model).
      vocab_size: number of distinct token ids the table can hold.
    """
    super().__init__()
    self.embedding_dim, self.vocab_size = embedding_dim, vocab_size
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

  def forward(self, x):
    """Return the embeddings of the ids in `x`, multiplied by sqrt(embedding_dim)."""
    scale = math.sqrt(self.embedding_dim)
    vectors = self.embedding(x)
    return vectors * scale
class PositionalEncoding(nn.Module):
  """Adds a fixed sine/cosine position table to its input, then applies dropout."""

  def __init__(self, embedding_dim, sequence_len, dropout):
    """
    Args:
      embedding_dim: width of the vectors the table is added to.
      sequence_len: number of positions pre-computed in the table.
      dropout: dropout probability applied after the addition.
    """
    super().__init__()
    self.embedding_dim = embedding_dim
    self.sequence_len = sequence_len
    self.dropout = nn.Dropout(dropout)

    # Positions 0..sequence_len-1 as a column, shape (sequence_len, 1).
    position = torch.arange(0, sequence_len, dtype = torch.float).unsqueeze(1)
    # One frequency per even column index i: exp(-i * ln(10000) / embedding_dim).
    even_columns = torch.arange(0, embedding_dim, step = 2).float()
    denominator_term = torch.exp(even_columns * (-math.log(10000.0) / embedding_dim))
    angles = position * denominator_term

    # Even columns hold the sine, odd columns the cosine of the same angle.
    table = torch.zeros(sequence_len, embedding_dim)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)

    # Registered as a buffer (with a leading batch axis of size 1): it is part
    # of the module's state but not a trainable parameter.
    self.register_buffer('PE', table.unsqueeze(0))

  def forward(self, x):
    """Add the first `x.shape[1]` rows of the table to `x` and apply dropout."""
    positions = self.PE[:, :x.shape[1], :].requires_grad_(False)
    return self.dropout(x + positions)

class MultiHeadAttentionBlock(nn.Module):
  """Scaled dot-product attention computed over `h` heads in parallel."""

  def __init__(self, embedding_dim, h, dropout):
    """
    Args:
      embedding_dim: width of the input/output vectors; must be divisible by `h`.
      h: number of attention heads.
      dropout: dropout probability applied to the attention weights.
    """
    super().__init__()
    self.embedding_dim = embedding_dim
    self.h = h

    assert embedding_dim % h == 0, "embedding_dim is not divisible by h"

    # Width handled by each individual head.
    self.d_k = embedding_dim // h

    # Query / key / value projections followed by the output projection.
    self.w_q = nn.Linear(embedding_dim, embedding_dim)
    self.w_k = nn.Linear(embedding_dim, embedding_dim)
    self.w_v = nn.Linear(embedding_dim, embedding_dim)
    self.w_o = nn.Linear(embedding_dim, embedding_dim)

    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(query, key, value, mask, dropout):
    """Return `(weights @ value, weights)` for scaled dot-product attention.

    Positions where `mask == 0` receive a score of -1e9 before the softmax.
    `mask` and `dropout` may each be None, in which case that step is skipped.
    """
    scale = math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
      scores.masked_fill_(mask == 0, -1e9)
    weights = scores.softmax(dim = -1)
    if dropout is not None:
      weights = dropout(weights)
    return torch.matmul(weights, value), weights

  def _split_heads(self, projected):
    """Reshape (batch, seq, h * d_k) into (batch, h, seq, d_k)."""
    batch, seq = projected.shape[0], projected.shape[1]
    return projected.view(batch, seq, self.h, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask):
    """Project q/k/v, attend per head, merge the heads and project the result.

    The attention weights of the last call are kept on `self.attention_score`.
    """
    query = self._split_heads(self.w_q(q))
    key = self._split_heads(self.w_k(k))
    value = self._split_heads(self.w_v(v))

    context, self.attention_score = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

    # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
    batch = context.shape[0]
    merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
    return self.w_o(merged)

class LayerNormalization(nn.Module):
  """Normalizes the last dimension, then applies a learned scalar scale and shift."""

  def __init__(self, eps = 10**-6):
    """
    Args:
      eps: small constant added to the standard deviation to avoid division by zero.
    """
    super().__init__()
    self.eps = eps

    # A single multiplicative and a single additive parameter (shape (1,) each).
    self.gamma = nn.Parameter(torch.ones(1))
    self.beta = nn.Parameter(torch.zeros(1))

  def forward(self, x):
    """Return gamma * (x - mean) / (std + eps) + beta, statistics taken over dim -1."""
    centered = x - x.mean(dim=-1, keepdim=True)
    spread = x.std(dim=-1, keepdim=True) + self.eps
    return (self.gamma * (centered / spread)) + self.beta

class FeedForwardBlock(nn.Module):
  """Two linear layers with a ReLU and dropout in between."""

  def __init__(self, d_model, d_ff, dropout):
    """
    Args:
      d_model: input and output width.
      d_ff: width of the hidden layer.
      dropout: dropout probability applied after the ReLU.
    """
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.dropout = nn.Dropout(dropout)
    self.linear_2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    """Compute linear_2(dropout(relu(linear_1(x))))."""
    hidden = torch.relu(self.linear_1(x))
    hidden = self.dropout(hidden)
    return self.linear_2(hidden)

class ResidualConnection(nn.Module):
  """Skip connection around a sublayer: x + dropout(sublayer(norm(x)))."""

  def __init__(self, dropout):
    """
    Args:
      dropout: dropout probability applied to the sublayer's output.
    """
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNormalization()

  def forward(self, x, sublayer):
    """Normalize `x`, run `sublayer` on it, apply dropout and add the result to `x`."""
    normalized = self.norm(x)
    update = self.dropout(sublayer(normalized))
    return x + update

class EncoderBlock(nn.Module):
  """One encoder layer: self-attention then feed-forward, each inside a residual connection."""

  def __init__(self, self_attention_block, feed_forward_block, dropout):
    """
    Args:
      self_attention_block: module called as block(q, k, v, mask).
      feed_forward_block: module called as block(x).
      dropout: dropout probability handed to both residual connections.
    """
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block

    # Index 0 wraps the attention sublayer, index 1 the feed-forward sublayer.
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

  def forward(self, x, src_mask):
    """Run `x` through self-attention (masked by `src_mask`) and the feed-forward block."""
    def self_attend(normed):
      # Query, key and value are all the same (normalized) input.
      return self.self_attention_block(normed, normed, normed, src_mask)

    attended = self.residual_connections[0](x, self_attend)
    return self.residual_connections[1](attended, self.feed_forward_block)

class Encoder(nn.Module):
  """Applies a sequence of encoder layers followed by a final normalization."""

  def __init__(self, layers):
    """
    Args:
      layers: iterable of modules, each called as layer(x, mask).
    """
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization()

  def forward(self, x, mask):
    """Feed `x` through every layer in order (same `mask` each time), then normalize."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    return self.norm(hidden)

class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention and feed-forward, each with a residual connection."""

  def __init__(self, self_attention_block, cross_attention_block, feed_forward_block, dropout):
    """
    Args:
      self_attention_block: module called as block(q, k, v, mask) on the decoder stream.
      cross_attention_block: module called as block(q, k, v, mask) with encoder output as k and v.
      feed_forward_block: module called as block(x).
      dropout: dropout probability handed to the three residual connections.
    """
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block

    # Index 0: self-attention, 1: cross-attention, 2: feed-forward.
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

  def forward(self, x, encoder_output, src_mask, target_mask):
    """Apply the three sublayers in order and return the updated decoder stream."""
    def self_attend(normed):
      # Decoder positions attend to each other under the target mask.
      return self.self_attention_block(normed, normed, normed, target_mask)

    def cross_attend(normed):
      # Queries come from the decoder; keys and values from the encoder output.
      return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

    hidden = self.residual_connections[0](x, self_attend)
    hidden = self.residual_connections[1](hidden, cross_attend)
    return self.residual_connections[2](hidden, self.feed_forward_block)

class Decoder(nn.Module):
  """Applies a sequence of decoder layers followed by a final normalization."""

  def __init__(self, layers):
    """
    Args:
      layers: iterable of modules, each called as
        layer(x, encoder_output, src_mask, target_mask).
    """
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization()

  def forward(self, x, encoder_output, src_mask, target_mask):
    """Feed `x` through every layer in order, then normalize the result."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, encoder_output, src_mask, target_mask)
    return self.norm(hidden)

class ProjectionLayer(nn.Module):
  """Linear map from model width to vocabulary size, returning log-probabilities."""

  def __init__(self, embedding_dim, vocab_size):
    """
    Args:
      embedding_dim: width of the incoming vectors.
      vocab_size: number of output classes.
    """
    super().__init__()
    self.projection = nn.Linear(embedding_dim, vocab_size)

  def forward(self, x):
    """Return log_softmax over the last dimension of the projected input."""
    logits = self.projection(x)
    return logits.log_softmax(dim=-1)

class Transformer(nn.Module):
  """Container wiring embeddings, positional encodings, encoder, decoder and projection together.

  The three stages (`encode`, `decode`, `project`) are exposed separately so that
  callers can reuse the encoder output across several decoding steps.
  """

  def __init__(self, encoder, decoder, src_embedding, target_embedding, src_positional_encoding, target_positional_encoding, projection_layer):
    super().__init__()
    # Core stacks.
    self.encoder = encoder
    self.decoder = decoder
    # Source-side and target-side input pipelines.
    self.src_embedding = src_embedding
    self.target_embedding = target_embedding
    self.src_positional_encoding = src_positional_encoding
    self.target_positional_encoding = target_positional_encoding
    # Output head.
    self.projection_layer = projection_layer

  def encode(self, src, src_mask):
    """Embed `src`, add positional information and run the encoder."""
    embedded = self.src_positional_encoding(self.src_embedding(src))
    return self.encoder(embedded, src_mask)

  def decode(self, encoder_output, src_mask, target, target_mask):
    """Embed `target`, add positional information and run the decoder against `encoder_output`."""
    embedded = self.target_positional_encoding(self.target_embedding(target))
    return self.decoder(embedded, encoder_output, src_mask, target_mask)

  def project(self, x):
    """Apply the projection layer to `x`."""
    return self.projection_layer(x)
def build_transformer(src_vocab_size, target_vocab_size, src_seq_len, target_seq_len, embedding_dim = 512, N = 6, h = 8, dropout = 0.1, d_ff = 2048):
  """Assemble a full encoder-decoder Transformer and initialize its weights.

  Args:
    src_vocab_size / target_vocab_size: sizes of the two embedding tables
      (the target size is also the width of the projection output).
    src_seq_len / target_seq_len: number of positions in each positional table.
    embedding_dim: model width shared by every sublayer.
    N: number of encoder blocks and of decoder blocks.
    h: attention heads per attention block.
    dropout: dropout probability used throughout.
    d_ff: hidden width of the feed-forward blocks.

  Returns:
    A `Transformer` whose parameters with more than one dimension have been
    re-initialized with Xavier-uniform values.
  """
  # Input pipelines (embedding + positional encoding) for both sides.
  src_embedding = InputEmbeddings(embedding_dim, src_vocab_size)
  target_embedding = InputEmbeddings(embedding_dim, target_vocab_size)
  src_positional_encoding = PositionalEncoding(embedding_dim, src_seq_len, dropout)
  target_positional_encoding = PositionalEncoding(embedding_dim, target_seq_len, dropout)

  # N independent encoder blocks: self-attention + feed-forward.
  encoder_blocks = [
    EncoderBlock(
      MultiHeadAttentionBlock(embedding_dim, h, dropout),
      FeedForwardBlock(embedding_dim, d_ff, dropout),
      dropout,
    )
    for _ in range(N)
  ]

  # N independent decoder blocks: self-attention + cross-attention + feed-forward.
  decoder_blocks = [
    DecoderBlock(
      MultiHeadAttentionBlock(embedding_dim, h, dropout),
      MultiHeadAttentionBlock(embedding_dim, h, dropout),
      FeedForwardBlock(embedding_dim, d_ff, dropout),
      dropout,
    )
    for _ in range(N)
  ]

  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))
  projection_layer = ProjectionLayer(embedding_dim, target_vocab_size)

  transformer = Transformer(encoder, decoder, src_embedding, target_embedding, src_positional_encoding, target_positional_encoding, projection_layer)

  # Only matrices (and higher-rank tensors) get Xavier init; 1-D parameters
  # such as biases keep their constructor defaults.
  for param in transformer.parameters():
    if param.dim() > 1:
      nn.init.xavier_uniform_(param)

  return transformer

def get_all_sentences(ds, lang):
    """Lazily yield the `lang` side of every item's 'translation' entry in `ds`."""
    yield from (item['translation'][lang] for item in ds)

class BilingualDataset(Dataset):
    """Wraps a translation dataset and emits fixed-length tensors for training.

    Every item is tokenized, framed with [SOS]/[EOS] and padded with [PAD] up
    to `seq_len`. The ids of the three special tokens are read from
    `tokenizer_tgt`.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()

        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # One-element int64 tensors so they can be concatenated directly.
        self.sos_token = self._special("[SOS]")
        self.eos_token = self._special("[EOS]")
        self.pad_token = self._special("[PAD]")

    def _special(self, token):
        """Return the target-tokenizer id of `token` as a one-element int64 tensor."""
        return torch.tensor([self.tokenizer_tgt.token_to_id(token)], dtype=torch.int64)

    def _padding(self, count):
        """Return `count` copies of the [PAD] id as an int64 tensor."""
        return torch.tensor([self.pad_token] * count, dtype = torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index) :
        """Build the model inputs, masks and label for item `index`.

        Raises:
            ValueError: if either sentence does not fit into `seq_len` once the
                special tokens are accounted for.
        """
        pair = self.ds[index]
        src_text = pair['translation'][self.src_lang]
        tgt_text = pair['translation'][self.tgt_lang]

        src_ids = self.tokenizer_src.encode(src_text).ids
        tgt_ids = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder sequence carries [SOS] and [EOS] (2 extra slots); the
        # decoder input carries only [SOS] and the label only [EOS] (1 extra).
        src_pad_count = self.seq_len - len(src_ids) - 2
        tgt_pad_count = self.seq_len - len(tgt_ids) - 1
        if min(src_pad_count, tgt_pad_count) < 0:
            raise ValueError('Sentence is too long')

        src_tensor = torch.tensor(src_ids, dtype = torch.int64)
        tgt_tensor = torch.tensor(tgt_ids, dtype = torch.int64)

        # [SOS] source... [EOS] [PAD]...
        encoder_input = torch.cat([self.sos_token, src_tensor, self.eos_token, self._padding(src_pad_count)])
        # [SOS] target... [PAD]...
        decoder_input = torch.cat([self.sos_token, tgt_tensor, self._padding(tgt_pad_count)])
        # target... [EOS] [PAD]...   (the decoder input shifted by one position)
        label = torch.cat([tgt_tensor, self.eos_token, self._padding(tgt_pad_count)])

        for built in (encoder_input, decoder_input, label):
            assert built.size(0) == self.seq_len

        # Padding masks with two leading singleton axes; the decoder mask is
        # additionally combined with the look-ahead mask.
        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        decoder_padding_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        decoder_mask = decoder_padding_mask & casual_mask(decoder_input.size(0))

        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input,
            'encoder_mask': encoder_mask,
            'decoder_mask': decoder_mask,
            'label': label,
            'src_text': src_text,
            'tgt_text': tgt_text
        }

def get_ds(config):
    """Load the parallel corpus and return data loaders plus both tokenizers.

    Reads `lang_src`, `lang_tgt`, `seq_len` and `batch_size` from `config`
    (and whatever `build_tokenizer` reads from it).

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt).
        The validation loader uses batch size 1.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']

    ds_raw = load_dataset('opus_books', f'{lang_src}-{lang_tgt}', split = 'train')

    tokenizer_src = build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = build_tokenizer(config, ds_raw, lang_tgt)

    # 90% of the pairs go to training, the remainder to validation.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

    # Report the longest tokenized sentence on each side so `seq_len` can be
    # sanity-checked against the data.
    max_len_src = 0
    max_len_tgt = 0
    for pair in ds_raw:
        src_ids = tokenizer_src.encode(pair['translation'][lang_src]).ids
        # Target sentences must be measured with the *target* tokenizer; using
        # the source tokenizer here reported a length unrelated to what the
        # dataset actually feeds the decoder.
        tgt_ids = tokenizer_tgt.encode(pair['translation'][lang_tgt]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size = config['batch_size'], shuffle = True)
    val_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def casual_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Row i is True for columns 0..i, i.e. a position may only see itself and
    earlier positions.
    """
    above_diagonal = torch.ones(1, size, size).triu(diagonal = 1).type(torch.int)
    return above_diagonal == 0

def build_tokenizer(config, ds, lang):
    """Return the word-level tokenizer for `lang`, training and saving it on first use.

    The file location is `config['tokenizer_file']` formatted with `lang`. When
    that file already exists it is loaded as-is; otherwise a new tokenizer is
    trained on the `lang` sentences of `ds` and written there.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Whitespace-split word-level model; unseen words map to [UNK].
    tokenizer = Tokenizer(WordLevel(unk_token = '[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()

    # Words must occur at least twice to enter the vocabulary.
    special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
    trainer = WordLevelTrainer(special_tokens = special_tokens, min_frequency = 2)

    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer = trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate target ids for a single source sequence, one token at a time.

    Starting from [SOS], the most likely next token is appended repeatedly
    until [EOS] is produced or the sequence holds `max_len` tokens. The encoder
    runs once and its output is reused at every step.

    Returns:
        1-D tensor of generated ids, starting with [SOS].
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    encoder_output = model.encode(source, source_mask)

    # Sequence generated so far, shape (1, current_length).
    decoder_input = torch.empty(1,1).fill_(sos_idx).type_as(source).to(device)

    while decoder_input.size(1) != max_len:
        step_mask = casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, step_mask)

        # Only the last position's output is needed to pick the next token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)

        next_column = torch.empty(1,1).type_as(source).fill_(next_word.item()).to(device)
        decoder_input = torch.cat([decoder_input, next_column], dim=1)

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_state, writer, num_examples=2):
    """Greedy-decode a few validation batches and print source / target / prediction.

    The model is switched to eval mode and left there. `print_msg` is the
    callable used for all output. `global_state` and `writer` are accepted but
    not used by this function.
    """
    model.eval()

    separator = '-' * 80  # printed before every example

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) ==  1, 'Batch size must be 1 for validation.'
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            for line in (
                separator,
                f'SOURCE: {source_text}',
                f'TARGET: {target_text}',
                f'PREDICTED: {model_out_text}',
            ):
                print_msg(line)

            # Stop once the requested number of examples has been shown.
            if count == num_examples:
                break

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a transformer sized from `config` and the two vocabulary sizes.

    `config['seq_len']` is used for both the source and the target sequence
    length, and `config['d_model']` as the embedding width.
    """
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, config['d_model'])

def get_config():
    """Return a fresh dict holding the default training configuration."""
    return dict(
        # Optimization
        batch_size=8,
        num_epochs=2,
        lr=10**-4,
        # Model / data dimensions
        seq_len=350,
        d_model=512,
        # Translation direction
        lang_src='en',
        lang_tgt='it',
        # Checkpoints; `preload` names the checkpoint suffix to resume from (None = start fresh)
        model_folder='weights',
        model_basename='tmodel_',
        preload=None,
        # Tokenizer file template ({0} is replaced by the language) and log directory
        tokenizer_file='tokenizer_{0}.json',
        experiment_name='runs/tmodel',
    )

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path '<model_folder>/<model_basename><epoch>.pt' as a string."""
    filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.').joinpath(config['model_folder'], filename))

def train_model(config):
    """Train the translation model described by `config`, checkpointing every epoch.

    Per epoch: one pass over the training loader, a short qualitative
    validation (greedy-decoded examples printed through the progress bar) and a
    checkpoint containing the epoch, model weights, optimizer state and global
    step. When `config['preload']` is set, training resumes from that
    checkpoint.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device {device}")

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)

    model = get_model(config,tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps = 1e-9)

    initial_epoch = 0
    global_step = 0

    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        state = torch.load(model_filename, map_location=device)

        # Restore the weights too. Previously only the optimizer, epoch and
        # step were restored, so a "resumed" run continued from randomly
        # initialized weights.
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    # Labels are target-language ids, so the padding id to ignore must come
    # from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_tgt.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        # run_validation() leaves the model in eval mode at the end of every
        # epoch, so training mode is re-enabled here.
        model.train()

        batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')

        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            label = batch['label'].to(device)

            # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,) for the loss.
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1

        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

    # Release the event-file handle once training is finished.
    writer.close()

In [5]:
# Entry point: silence library warnings, then train with the default configuration.
if __name__ == '__main__':
    warnings.filterwarnings(action='ignore')
    config = get_config()
    train_model(config)

Using device cuda


train-00000-of-00001.parquet:   0%|          | 0.00/5.73M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/32332 [00:00<?, ? examples/s]

Max length of source sentence: 309
Max length of target sentence: 274


Processing epoch 00: 100%|██████████| 3638/3638 [15:28<00:00,  3.92it/s, loss=4.880]


--------------------------------------------------------------------------------
SOURCE: I could not unlove him, because I felt sure he would soon marry this very lady--because I read daily in her a proud security in his intentions respecting her--because I witnessed hourly in him a style of courtship which, if careless and choosing rather to be sought than to seek, was yet, in its very carelessness, captivating, and in its very pride, irresistible.
TARGET: Non potevo cessar di amarlo perché capivo che avrebbe sposato presto quella ragazza; perché leggevo nel contegno della signorina Ingram l'altera sicurezza del trionfo, perché infine a ogni istante scoprivo nel signor Rochester una specie di cortesia, che nonostante fosse imposta, più che data, era irresistibile nella sua noncuranza e nel suo orgoglio.
PREDICTED: " Ma non mi , ma non mi , e non mi , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e a un ' altra .
---------------------------------------------------

Processing epoch 01: 100%|██████████| 3638/3638 [15:28<00:00,  3.92it/s, loss=4.297]


--------------------------------------------------------------------------------
SOURCE: "Pooh! you can't be silly enough to wish to leave such a splendid place?"
TARGET: — Non siete, spero, tanto stupida da desiderare di andarvene.
PREDICTED: — Non è vero , non è stata più più più di nuovo ?
--------------------------------------------------------------------------------
SOURCE: I was obliged to...'
TARGET: Ho dovuto....
PREDICTED: Io mi .
