In [1]:
from transformer import Transformer # this is the transformer.py file
import torch
import numpy as np

In [2]:
english_file = './data/small_vocab_en.txt'
french_file = './data/small_vocab_fr.txt'
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(french_file, 'r') as file:
    french_sentences = file.readlines()

TOTAL_SENTENCE = 200000
french_sentences = french_sentences[:TOTAL_SENTENCE]
english_sentences = english_sentences[:TOTAL_SENTENCE]
english_sentences = [sentence.rstrip('\n') for sentence in english_sentences]
french_sentences = [sentence.rstrip('\n') for sentence in french_sentences]

french_sentences[:10]

["new jersey est parfois calme pendant l' automne , et il est neigeux en avril .",
 'les états-unis est généralement froid en juillet , et il gèle habituellement en novembre .',
 'california est généralement calme en mars , et il est généralement chaud en juin .',
 'les états-unis est parfois légère en juin , et il fait froid en septembre .',
 'votre moins aimé fruit est le raisin , mais mon moins aimé est la pomme .',
 "son fruit préféré est l'orange , mais mon préféré est le raisin .",
 'paris est relaxant en décembre , mais il est généralement froid en juillet .',
 'new jersey est occupé au printemps , et il est jamais chaude en mars .',
 'notre fruit est moins aimé le citron , mais mon moins aimé est le raisin .',
 'les états-unis est parfois occupé en janvier , et il est parfois chaud en novembre .']

In [3]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(french_sentences)):
    french_sentence, english_sentence = french_sentences[index], english_sentences[index]
    if is_valid_length(french_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length):
        valid_sentence_indicies.append(index)

len(valid_sentence_indicies)


137860

In [4]:
import sentencepiece as spm
import numpy as np

def english_tokenizer_load():
    sp_eng = spm.SentencePieceProcessor()
    sp_eng.Load('./eng.model')
    return sp_eng

def french_tokenizer_load():
    sp_fr = spm.SentencePieceProcessor()
    sp_fr.Load('./fr.model')
    return sp_fr

In [5]:
fr_tokenizer = french_tokenizer_load()
eng_tokenizer = english_tokenizer_load()

In [6]:
import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 6
max_sequence_length = 200
chn_vocab_size = fr_tokenizer.vocab_size()

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          chn_vocab_size,
                          eng_tokenizer,
                          fr_tokenizer,
                          True,
                          True)

In [7]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(400, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): EncoderLayer(
        (attention): MultiHeadA

In [8]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, french_sentences):
        self.english_sentences = english_sentences
        self.french_sentences = french_sentences

    def __len__(self):
        assert len(english_sentences) == len(french_sentences), "different length"
        return len(self.english_sentences)
    
    def __getitem__(self, index):
        return self.english_sentences[index], self.french_sentences[index]

In [9]:
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]
french_sentences = [french_sentences[i] for i in valid_sentence_indicies]
dataset = TextDataset(english_sentences, french_sentences)
len(dataset)

137860

In [10]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [11]:
NEG_INFTY = -1e9

def create_masks(eng_batch, fr_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, fr_sentence_length = len(eng_batch[idx]), len(fr_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      fr_chars_to_padding_mask = np.arange(fr_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, fr_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, fr_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, fr_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [12]:

from torch import nn

criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     

In [16]:
from tqdm import tqdm

transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    # Initialize tqdm progress bar
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}', leave=False)
    
    for batch_num, batch in progress_bar:
        transformer.train()
        eng_batch, fr_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, fr_batch)
        optim.zero_grad()
        fr_predictions = transformer(eng_batch,
                                     fr_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=True,
                                     enc_end_token=True,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.tokenize(fr_batch, start_token=False, end_token=True)
        loss = criterion(
            fr_predictions.view(-1, chn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == 0, False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"french Translation: {fr_batch[0]}")
            fr_sentence_predicted = torch.argmax(fr_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in fr_sentence_predicted:
              if idx == fr_tokenizer.eos_id():
                break
              if len(predicted_sentence) > 0:
                 predicted_sentence += " "
              predicted_sentence += fr_tokenizer.decode(idx.item())
            print(f"french Prediction: {predicted_sentence}")


            transformer.eval()
            fr_sentence = [""]
            eng_sentence = ["my favourite fruit is grape"]
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, fr_sentence)
                predictions = transformer(eng_sentence,
                                          fr_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=True,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = fr_tokenizer.decode(next_token_index)
                if len(fr_sentence[0]) == 0:
                  fr_sentence = [next_token]
                else:
                  fr_sentence = [fr_sentence[0] + ' ' + next_token]
                if next_token == fr_tokenizer.eos_id():
                  break
            
            print(f"Evaluation translation (my favourite fruit is grape) : {fr_sentence[0]}")
            print("-------------------------------------------")
        
        # Update the epoch loss and tqdm progress bar
        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=epoch_loss/(batch_num+1))

    print(f"Epoch {epoch+1} finished with average loss: {epoch_loss/len(train_loader):.4f}")

Epoch 1/10


Epoch 1:   0%|          | 0/4596 [00:00<?, ?it/s]

Iteration 0 : 1.2086519002914429
English: new jersey is sometimes quiet during autumn , and it is snowy in april .
french Translation: new jersey est parfois calme pendant l' automne , et il est neigeux en avril .
french Prediction: new jersey est parfois sec au l ' automne , mais il est doux en mai .


Epoch 1:   0%|          | 2/4596 [00:03<1:49:21,  1.43s/it, loss=1.08]

Evaluation translation (my favourite fruit is grape) : est   es fruit aimé aimé aimé . ts e if ois f ? . . ?       . . . . l le l le                                                                                                                                                                        
-------------------------------------------


Epoch 1:   2%|▏         | 100/4596 [00:22<14:23,  5.21it/s, loss=0.506]

Iteration 100 : 0.4339633882045746
English: she plans to visit the united states next may .
french Translation: elle envisage de se rendre aux états-unis en mai prochain .
french Prediction: elle est v is a ge de se re n d re au x états - unis en août . .


Epoch 1:   2%|▏         | 102/4596 [00:25<1:02:33,  1.20it/s, loss=0.504]

Evaluation translation (my favourite fruit is grape) : elle a été à la france  e f f f el o ée  ?  ?  ?  ?  ?    .  . . . . . . . .  .                                                                                                                                                                 
-------------------------------------------


Epoch 1:   4%|▍         | 200/4596 [00:44<14:20,  5.11it/s, loss=0.456]  

Iteration 200 : 0.4150015711784363
English: he likes grapes , bananas , and apples.
french Translation: il aime les raisins , les bananes et les pommes .
french Prediction: il aime les poires , les bananes et les citrons .


Epoch 1:   4%|▍         | 202/4596 [00:47<59:29,  1.23it/s, loss=0.456]  

Evaluation translation (my favourite fruit is grape) : elle aime le v ieux it camion ? . b ?   t ?  . ?     .  . . . . . . . . . .  . . .                                                                                                                                                                  
-------------------------------------------


Epoch 1:   7%|▋         | 300/4596 [01:06<13:32,  5.29it/s, loss=0.435]

Iteration 300 : 0.4164653420448303
English: india is sometimes hot during september , but it is never snowy in summer .
french Translation: l' inde est parfois chaud en septembre , mais jamais de neige en été .
french Prediction: chine ' inde est parfois pluvieux en décembre , et il de neige en été .


Epoch 1:   7%|▋         | 302/4596 [01:09<58:25,  1.22it/s, loss=0.434]  

Evaluation translation (my favourite fruit is grape) : la mangue est votre fruit le plus aimé .  .  . . . . . . . . . .  . . .                                                                                                                                                                              
-------------------------------------------


Epoch 1:   9%|▊         | 400/4596 [01:28<13:17,  5.26it/s, loss=0.422]

Iteration 400 : 0.36911264061927795
English: new jersey is sometimes pleasant during march , but it is sometimes wet in summer .
french Translation: new jersey est parfois agréable au mois de mars , mais il est parfois humide en été .
french Prediction: new jersey est parfois humide en mois de mars , mais il est parfois s en avril .


Epoch 1:   9%|▊         | 402/4596 [01:32<59:51,  1.17it/s, loss=0.422]  

Evaluation translation (my favourite fruit is grape) : elle aime le v ieux qu camion o o . o . anc .   .  . .      . . . . . . . . .  . .  .                                                                                                                                                                 
-------------------------------------------


Epoch 1:  11%|█         | 500/4596 [01:50<13:04,  5.22it/s, loss=0.412]

Iteration 500 : 0.45974889397621155
English: france is usually quiet during summer , and it is never dry in june .
french Translation: la france est généralement calme pendant l' été , et il est jamais sec en juin .
french Prediction: la france est généralement occupé en l ' été , mais il est jamais sec en juillet .


Epoch 1:  11%|█         | 502/4596 [01:54<57:06,  1.19it/s, loss=0.413]  

Evaluation translation (my favourite fruit is grape) : la mangue est votre fruit le plus aimé .   . . . . . . . . . . . .  . . l l l l l                                                                                                                                                                         
-------------------------------------------


Epoch 1:  13%|█▎        | 600/4596 [02:13<12:45,  5.22it/s, loss=0.406]

Iteration 600 : 0.40418732166290283
English: her favorite animals are mice .
french Translation: ses animaux sont des souris préférées .
french Prediction: son s anim au x les des ch ou r is . s .


Epoch 1:  13%|█▎        | 602/4596 [02:16<59:08,  1.13it/s, loss=0.406]  

Evaluation translation (my favourite fruit is grape) : elle aime le v ieux qu camion ? o rou ?  ? . . ?         . . . . . . .  .                                                                                                                                                                       
-------------------------------------------


Epoch 1:  15%|█▌        | 700/4596 [02:35<12:20,  5.26it/s, loss=0.401]

Iteration 700 : 0.39341720938682556
English: he likes the new black truck .
french Translation: il aime le nouveau camion noir .
french Prediction: il aime la g au camion j oi r .


Epoch 1:  15%|█▌        | 702/4596 [02:39<53:05,  1.22it/s, loss=0.401]  

Evaluation translation (my favourite fruit is grape) : mon fruit préféré est le ch a v .   . . . . . . . . . . .  .  . . . . . .                                                                                                                                                                         
-------------------------------------------


                                                                       

KeyboardInterrupt: 