In [1]:
import os
import warnings

import nltk
import torch
import yaml
from torchinfo import summary

from config import SharedConfig, TokenizerConfig, DataLoaderConfig, TransformerConfig, TrainerConfig
from data import IWSLT2017DataLoader, Multi30kDataLoader
from trainer import Trainer, EarlyStopper
from transformer import Seq2SeqTransformer
from translate import Processor
# Silence UserWarnings raised from this point on. Warnings emitted while the
# imports above were executing have already been shown and are not affected.
warnings.filterwarnings("ignore", category=UserWarning)

# WordNet is downloaded into a project-local directory (presumably for the
# METEOR score computed during evaluation — TODO confirm in trainer.evaluate).
# Resolve it to an absolute path so later changes of the working directory do
# not break lookups.
NLTK_DATA_DIR = os.path.abspath('./.venv/share/nltk_data')

# NLTK only searches the directories listed in nltk.data.path. The default list
# contains <sys.prefix>/share/nltk_data, which matches this directory only when
# the kernel runs from ./.venv — register it explicitly so the data is found
# regardless of which interpreter backs the kernel.
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DATA_DIR)

# Trailing ';' keeps the boolean return value from being echoed as cell output.
nltk.download('wordnet', download_dir=NLTK_DATA_DIR);

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to ./.venv/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [2]:
# --- Run configuration ------------------------------------------------------
path_to_config = './configs/multi30k-small.yaml'
run_id = 'multi30k-small'

# Use the GPU when one is visible, otherwise fall back to CPU instead of
# failing later with a CUDA error on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before the dataloaders (shuffle=True) and the model
# weights are created, so repeated runs see the same batches and initial
# weights. NOTE(review): Python `random` / NumPy RNGs are not seeded here —
# TODO confirm whether the project modules rely on them.
SEED = 42
torch.manual_seed(SEED)

# safe_load only builds plain Python objects from the YAML file.
with open(path_to_config, encoding='utf-8') as stream:
    config = yaml.safe_load(stream)

tkn_conf = TokenizerConfig()
print(tkn_conf.model_dump())

# Language code -> tokenizer callable, e.g. {'de': ..., 'en': ...}.
tokenizer = {
    tkn_conf.src_language: tkn_conf.src_tokenizer,
    tkn_conf.tgt_language: tkn_conf.tgt_tokenizer,
}

shared_conf = SharedConfig()
dl_conf = DataLoaderConfig(**config['dataloader'])
print(shared_conf.model_dump())
print(dl_conf.model_dump())

# Any dataset name other than "iwslt2017" is loaded as Multi30k.
if dl_conf.dataset == "iwslt2017":
    dataloader = IWSLT2017DataLoader(dl_conf, tokenizer, tkn_conf, shared_conf)
else:
    dataloader = Multi30kDataLoader(dl_conf, tokenizer, tkn_conf, shared_conf)

vocab_transform = dataloader.vocab_transform
text_transform = dataloader.text_transform
train_dataloader = dataloader.train_dataloader
test_dataloader = dataloader.test_dataloader
val_dataloader = dataloader.val_dataloader

# Vocabulary sizes determine the embedding / output-projection sizes below.
SRC_VOCAB_SIZE = len(vocab_transform[tkn_conf.src_language].index2word)
TGT_VOCAB_SIZE = len(vocab_transform[tkn_conf.tgt_language].index2word)
print(SRC_VOCAB_SIZE)
print(TGT_VOCAB_SIZE)

{'src_language': 'de', 'tgt_language': 'en', 'src_tokenizer': functools.partial(<function _spacy_tokenize at 0x7727309040d0>, spacy=<spacy.lang.de.German object at 0x772723176290>), 'tgt_tokenizer': functools.partial(<function _spacy_tokenize at 0x7727309040d0>, spacy=<spacy.lang.en.English object at 0x77271055b820>)}
{'special_symbols': ['<unk>', '<bos>', '<eos>', '<pad>']}
{'dataset': 'multi30k', 'batch_size': 128, 'num_workers': 4, 'pin_memory': True, 'drop_last': False, 'shuffle': True}
Creating DataLoaders:  |████████████████████████████████████████| 100% [3/3] in 1.3s (2.30/s) 
18643
10596


In [3]:
# --- Model, translator and trainer configuration ---------------------------
model_conf = TransformerConfig(
    **config['transformer'],
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
)
print(model_conf.model_dump())

transformer = Seq2SeqTransformer(model_conf)
translator = Processor(transformer, device, shared_conf.special_symbols)

trainer_conf = TrainerConfig(
    **config['trainer'],
    device=device,
)
print(trainer_conf.model_dump())

# Dummy sequence length used only to build the summary's input tensors.
SUMMARY_SEQ_LEN = 256
BATCH_SIZE = dl_conf.batch_size

# NOTE(review): the six shapes are assumed to follow the model's forward()
# argument order (src, tgt, src_mask, tgt_mask, src_padding_mask,
# tgt_padding_mask) — TODO confirm against Seq2SeqTransformer.forward.
# `device=device` keeps torchinfo on the device chosen above rather than
# letting it pick one on its own.
summary(
    transformer,
    [
        (SUMMARY_SEQ_LEN, BATCH_SIZE),
        (SUMMARY_SEQ_LEN, BATCH_SIZE),
        (SUMMARY_SEQ_LEN, SUMMARY_SEQ_LEN),
        (SUMMARY_SEQ_LEN, SUMMARY_SEQ_LEN),
        (BATCH_SIZE, SUMMARY_SEQ_LEN),
        (BATCH_SIZE, SUMMARY_SEQ_LEN),
    ],
    depth=4,
    device=device,
)

{'num_encoder_layers': 3, 'num_decoder_layers': 3, 'emb_size': 512, 'nhead': 8, 'src_vocab_size': 18643, 'tgt_vocab_size': 10596, 'dim_feedforward': 1024, 'dropout': 0.1}
{'learning_rate': 0.0001, 'num_epochs': 200, 'batch_size': 128, 'tgt_batch_size': 128, 'num_cycles': 6}


Layer (type:depth-idx)                             Output Shape              Param #
Seq2SeqTransformer                                 [256, 128, 10596]         --
├─TokenEmbedding: 1-1                              [256, 128, 512]           --
│    └─Embedding: 2-1                              [256, 128, 512]           9,545,216
├─PositionalEncoding: 1-2                          [256, 128, 512]           --
│    └─Dropout: 2-2                                [256, 128, 512]           --
├─TokenEmbedding: 1-3                              [256, 128, 512]           --
│    └─Embedding: 2-3                              [256, 128, 512]           5,425,152
├─PositionalEncoding: 1-4                          [256, 128, 512]           --
│    └─Dropout: 2-4                                [256, 128, 512]           --
├─Transformer: 1-5                                 [256, 128, 512]           --
│    └─TransformerEncoder: 2-5                     [256, 128, 512]           --
│    │    └─ModuleLis

In [4]:
# --- Training, evaluation and a sample translation -------------------------
# Exact semantics of patience / min_delta are defined by trainer.EarlyStopper.
early_stopper = EarlyStopper(min_delta=0, patience=3)

trainer = Trainer(
    transformer,
    translator,
    train_dataloader,
    test_dataloader,
    val_dataloader,
    vocab_transform,
    early_stopper,
    trainer_conf,
    shared_conf,
    run_id,
    device,
)
trainer.train()

print(
    f'\nEvaluation: meteor_score - {trainer.evaluate(tgt_language=tkn_conf.tgt_language)}'
)

# Smoke test: translate one fixed German sentence with the trained model.
TEST_SEQUENCE = "Eine Gruppe Pinguine steht vor einem Iglu und lacht sich tot ."
output = translator.translate(
    TEST_SEQUENCE,
    src_language=tkn_conf.src_language,
    tgt_language=tkn_conf.tgt_language,
    text_transform=text_transform,
    vocab_transform=vocab_transform,
    special_symbols=shared_conf.special_symbols,
)

print(f'Input: {TEST_SEQUENCE}, Output: {output}')

on 0: epoch 1 avg_training_loss: 8.389605979124704
on 0: epoch 1 avg_test_loss:     7.5691564083099365
on 1: epoch 2 avg_training_loss: 6.934060313083507
on 1: epoch 2 avg_test_loss:     6.02173924446106
on 2: epoch 3 avg_training_loss: 5.571786730377762
on 2: epoch 3 avg_test_loss:     5.0379029512405396
on 3: epoch 4 avg_training_loss: 4.866412518201051
on 3: epoch 4 avg_test_loss:     4.565712630748749
on 4: epoch 5 avg_training_loss: 4.424694739006184
on 4: epoch 5 avg_test_loss:     4.127844214439392
on 5: epoch 6 avg_training_loss: 4.011514713366826
on 5: epoch 6 avg_test_loss:     3.7310657501220703
on 6: epoch 7 avg_training_loss: 3.6364861572230303
on 6: epoch 7 avg_test_loss:     3.3684552013874054
on 7: epoch 8 avg_training_loss: 3.292781162041205
on 7: epoch 8 avg_test_loss:     3.0817970633506775
on 8: epoch 9 avg_training_loss: 2.98689505899394
on 8: epoch 9 avg_test_loss:     2.845249891281128
on 9: epoch 10 avg_training_loss: 2.7174276610215506
on 9: epoch 10 avg_test_l