In [1]:
from torchinfo import summary
import warnings
import yaml

import nltk
from data import IWSLT2017DataLoader, Multi30kDataLoader
from transformer import Seq2SeqTransformer
from trainer import Trainer, EarlyStopper
from config import SharedConfig, TokenizerConfig, DataLoaderConfig, TransformerConfig, TrainerConfig
from translate import Translate
# Silence UserWarning noise so the cell outputs below stay readable.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Fetch the WordNet corpus into a project-local directory.
# NOTE(review): presumably needed for the METEOR evaluation later in the
# notebook -- confirm against Trainer.evaluate.
nltk_data_dir = './.venv/share/nltk_data'
nltk.download('wordnet', download_dir=nltk_data_dir)

[nltk_data] Downloading package wordnet to ./.nltk...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
# --- Run settings ------------------------------------------------------------
path_to_config = './configs/multi30k-small.yaml'
run_id = 'multi30k-small'
device = 'cuda'

# Parse the YAML run configuration into a plain Python object.
with open(path_to_config) as stream:
    config = yaml.safe_load(stream)

# --- Tokenizers --------------------------------------------------------------
tkn_conf = TokenizerConfig()
print(tkn_conf.model_dump())

# Map each language code to its tokenizer callable.
tokenizer = {}
tokenizer[tkn_conf.src_language] = tkn_conf.src_tokenizer
tokenizer[tkn_conf.tgt_language] = tkn_conf.tgt_tokenizer

# --- Shared state and dataloader settings ------------------------------------
shared_conf = SharedConfig()
dl_conf = DataLoaderConfig(**config['dataloader'])
print(shared_conf.model_dump())
print(dl_conf.model_dump())

# Any dataset name other than "iwslt2017" falls back to Multi30k.
_loader_cls = IWSLT2017DataLoader if dl_conf.dataset == "iwslt2017" else Multi30kDataLoader
dataloader = _loader_cls(dl_conf, tokenizer, tkn_conf, shared_conf)

# Expose the pieces the later cells use directly.
vocab_transform = dataloader.vocab_transform
text_transform = dataloader.text_transform
train_dataloader = dataloader.train_dataloader
test_dataloader = dataloader.test_dataloader
val_dataloader = dataloader.val_dataloader

# Vocabulary sizes are read from each language's index-to-word table.
SRC_VOCAB_SIZE = len(vocab_transform[tkn_conf.src_language].index2word)
TGT_VOCAB_SIZE = len(vocab_transform[tkn_conf.tgt_language].index2word)
print(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, sep='\n')

{'src_language': 'de', 'tgt_language': 'en', 'src_tokenizer': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.de.German object at 0x7f24b36a69c0>), 'tgt_tokenizer': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.en.English object at 0x7f24b36a6e10>)}
{'dataset': 'multi30k', 'batch_size': 32, 'num_workers': 4, 'pin_memory': True, 'drop_last': False, 'shuffle': True}
{'token_transform': {'de': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.de.German object at 0x7f24b36a69c0>), 'en': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.en.English object at 0x7f24b36a6e10>)}, 'text_transform': {}, 'vocab_transform': {}, 'dataloaders': [], 'special_symbols': ['<unk>', '<bos>', '<eos>', '<pad>']}
Sample from trainset: ('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.', 'Two young, White males are outside near many bushes.')
Sample from valset:

In [4]:
# Model hyper-parameters: YAML values plus the vocabulary sizes measured above.
model_conf = TransformerConfig(
    **config['transformer'],
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
)
print(model_conf.model_dump())

transformer = Seq2SeqTransformer(model_conf)
translator = Translate(transformer, device, shared_conf.special_symbols)

# Trainer settings: YAML values plus the device chosen for this run.
trainer_conf = TrainerConfig(
    **config['trainer'],
    device=device,
)
print(trainer_conf.model_dump())

# Dummy input sizes handed to torchinfo for the summary pass.
# NOTE(review): presumably (src, tgt, src_mask, tgt_mask, src_padding_mask,
# tgt_padding_mask) with a sequence length of 256 -- confirm against
# Seq2SeqTransformer.forward.
_seq_len = 256
_batch = dl_conf.batch_size
_input_sizes = [
    (_seq_len, _batch),
    (_seq_len, _batch),
    (_seq_len, _seq_len),
    (_seq_len, _seq_len),
    (_batch, _seq_len),
    (_batch, _seq_len),
]

# Keep this as the last expression so the notebook renders the summary table.
summary(transformer, _input_sizes, depth=4)

{'num_encoder_layers': 6, 'num_decoder_layers': 6, 'emb_size': 512, 'nhead': 16, 'src_vocab_size': 18652, 'tgt_vocab_size': 10615, 'dim_feedforward': 512, 'dropout': 0.1, 'shared_store': {'token_transform': {'de': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.de.German object at 0x7f24b36a69c0>), 'en': functools.partial(<function _spacy_tokenize at 0x7f249cfb09a0>, spacy=<spacy.lang.en.English object at 0x7f24b36a6e10>)}, 'text_transform': {'de': <function BaseDataLoader.sequential_transforms.<locals>.func at 0x7f2458f7b600>, 'en': <function BaseDataLoader.sequential_transforms.<locals>.func at 0x7f244eb18ae0>}, 'vocab_transform': {'de': Vocab(), 'en': Vocab()}, 'dataloaders': [<torch.utils.data.dataloader.DataLoader object at 0x7f24b0dc1ca0>, <torch.utils.data.dataloader.DataLoader object at 0x7f249d66e510>, <torch.utils.data.dataloader.DataLoader object at 0x7f244ea7b290>], 'special_symbols': ['<unk>', '<bos>', '<eos>', '<pad>']}}
{'learning_rate':

Layer (type:depth-idx)                                  Output Shape              Param #
Seq2SeqTransformer                                      [1, 32, 10615]            --
├─TokenEmbedding: 1-1                                   [1, 32, 512]              --
│    └─Embedding: 2-1                                   [1, 32, 512]              9,549,824
├─PositionalEncoding: 1-2                               [1, 32, 512]              --
│    └─Dropout: 2-2                                     [1, 32, 512]              --
├─TokenEmbedding: 1-3                                   [1, 32, 512]              --
│    └─Embedding: 2-3                                   [1, 32, 512]              5,434,880
├─PositionalEncoding: 1-4                               [1, 32, 512]              --
│    └─Dropout: 2-4                                     [1, 32, 512]              --
├─Transformer: 1-5                                      [1, 32, 512]              --
│    └─TransformerEncoder: 2-5                

In [5]:
# NOTE(review): presumably stops training after 3 epochs without improvement --
# confirm against EarlyStopper.
early_stopper = EarlyStopper(min_delta=0, patience=3)

trainer = Trainer(
    transformer, translator,
    train_dataloader, test_dataloader, val_dataloader,
    vocab_transform, early_stopper, trainer_conf, shared_conf,
    run_id, device,
)

# Run the training loop, then report the evaluation score for the target language.
trainer.train()
meteor = trainer.evaluate(tgt_language=tkn_conf.tgt_language)
print(f'\nEvaluation: meteor_score - {meteor}')

# Smoke test: translate one hand-written source sentence with the model.
TEST_SEQUENCE = "Ein Mann mit blonden Haar hat ein Haus aus Steinen gebaut ."
output = translator.translate(
    TEST_SEQUENCE,
    src_language=tkn_conf.src_language,
    tgt_language=tkn_conf.tgt_language,
    text_transform=text_transform,
    vocab_transform=vocab_transform,
    special_symbols=shared_conf.special_symbols,
)

# `tokenizer` is the plain dict of per-language tokenizer callables built in
# the setup cell, so it has no `convert_tokens_to_string` method; calling it
# raised AttributeError. Format the translation directly instead.
# NOTE(review): assumes translate() returns either a finished string or an
# iterable of tokens -- confirm against Translate.translate.
output_text = output if isinstance(output, str) else ' '.join(map(str, output))

print(f'Input: {TEST_SEQUENCE}, Output: {output_text}')

Training: <⚠︎                                       > (!) 0/200 [0%] in 1.2s (0.00/s) 


KeyboardInterrupt: 