In [3]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the project modules imported below are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from mint.model.transformer import Transformer, TransformerConfig
from mint.common import create_config, to_dict
from mint.trainer import Trainer, TrainerConfig
import os
from mint.translator import Translator
from mint.dataset import Dataset
from mint.tokenizer import Tokenizer

In [4]:
# Root directory of the dataset (the `en_sk` folder; presumably English-Slovak
# -- confirm against the data). Everything loaded in this cell lives under it.
DATASET_PATH = "../datasets/en_sk/"

# Derive the tokenizer locations from DATASET_PATH instead of repeating the
# literal, so pointing the notebook at another dataset cannot silently pair
# it with tokenizers from the old one. The resulting paths are unchanged:
# "../datasets/en_sk/source_tokenizer/" and "../datasets/en_sk/target_tokenizer/".
SOURCE_TOKENIZER_PATH = os.path.join(DATASET_PATH, "source_tokenizer/")
TARGET_TOKENIZER_PATH = os.path.join(DATASET_PATH, "target_tokenizer/")

source_tokenizer = Tokenizer.load(SOURCE_TOKENIZER_PATH)
target_tokenizer = Tokenizer.load(TARGET_TOKENIZER_PATH)

dataset = Dataset.load(DATASET_PATH)


In [8]:
# --- Hyperparameters -------------------------------------------------------
# Shared settings written to `config.glob`.
D_MODEL = 512
N_HEADS = 8
MAX_SEQ_LEN = 128 + 1  # NOTE(review): the "+ 1" presumably reserves a slot for a special token -- confirm
D_FEEDFORWARD = 2048
P_DROPOUT = 0.1

# Per-stack settings; encoder and decoder currently use identical values.
N_BLOCKS = 10
VOCAB_SIZE = 10000 + 1  # NOTE(review): presumably 10000 tokens plus one special token -- confirm against the tokenizers

# --- Build the configuration -----------------------------------------------
config = create_config(TransformerConfig())

config.glob.d_model = D_MODEL
config.glob.n_heads = N_HEADS
config.glob.max_seq_len = MAX_SEQ_LEN
config.glob.d_feedforward = D_FEEDFORWARD
config.glob.p_dropout = P_DROPOUT

config.encoder_config.n_blocks = N_BLOCKS
config.encoder_config.vocab_size = VOCAB_SIZE
config.decoder_config.n_blocks = N_BLOCKS
config.decoder_config.vocab_size = VOCAB_SIZE

# --- Instantiate the model and report its size -----------------------------
model = Transformer(config)

n_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_parameters}")

{'n_blocks': 10, 'vocab_size': 10001, 'transformer_block_config': {'d_model': 512, 'd_feedforward': 2048, 'p_dropout': 0.1, 'attention_config': {'n_heads': 8, 'd_model': 512, 'max_seq_len': 129, 'context_window': None}}, 'embedding_config': {'vocab_size': 10001, 'd_model': 512, 'max_seq_len': 129, 'learnable_positional_embeddings': True}}
Number of parameters: 126475776


In [None]:
# Step budgets. NOTE(review): 32 is presumably the batch size, making these
# sample counts converted to steps -- confirm against TrainerConfig.
MAX_STEPS_PER_EPOCH = 100_000 // 32  # too big dataset to run locally
MAX_STEPS_PER_VALIDATION = 1_000 // 32

# Start from the default trainer configuration and override what this
# experiment needs.
trainer_config = create_config(TrainerConfig())
trainer_config.logger_config.experiment_name = "exp2"
trainer_config.warmup_steps = 3000
trainer_config.learning_rate = 1e-4
trainer_config.use_cuda = True
trainer_config.max_steps_per_epoch = MAX_STEPS_PER_EPOCH
trainer_config.max_steps_per_validation = MAX_STEPS_PER_VALIDATION

# Expand the configuration into keyword arguments and hand the tokenizers to
# the trainer alongside the model and dataset.
trainer_kwargs = to_dict(trainer_config)
trainer = Trainer(
    model,
    dataset,
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    **trainer_kwargs,
)

# The argument appears as the total in the "Train 1/10" progress output.
trainer.train(10)

Train 1/10 avg loss: 9.3770:   0%|          | 12/3125 [00:07<33:43,  1.54it/s]

In [17]:
# Bundle the model with both tokenizers so raw text can be translated directly.
translator = Translator(
    model,
    source_tokenizer,
    target_tokenizer,
)

# Quick smoke test on a single phrase. The call is kept as the cell's last
# expression so the notebook displays its return value.
sample_sentence = "Good evening"
translator.translate(sample_sentence, max_length=128)

['Dobré večera<|endoftext|>']