In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import random
import copy
import itertools
from collections import OrderedDict
import yaml
import re
from yaml import Loader
import functools
import torch

from seq2seq.utils import generate_seed
from seq2seq.data.tokenizer import Tokenizer
from seq2seq.data.dataset import ParallelDataset, LazyParallelDataset,  TextDataset
from seq2seq.models.gnmt import GNMT
from seq2seq.train.loss import LabelSmoothing
from seq2seq.data import config as seq2seq_config
from seq2seq.inference.translator import Translator
from seq2seq.train.trainer import Seq2SeqTrainer

In [None]:
def load_config(config_filename):
    """Load a YAML config file and expand its ``${...}`` placeholders.

    A string value may embed references such as ``${a@@@b@@@c}``; each one is
    replaced by ``config["a"]["b"]["c"]``, looked up from the root of the
    loaded document. Nested dictionaries are processed recursively; values
    held inside lists are left untouched. Referenced values are converted
    with ``str()`` before insertion, so numbers and booleans can be
    referenced as well as strings. Expansion is a single pass following dict
    iteration order: a reference to a value that has not been expanded yet
    copies that value's raw text.

    As a side effect, the directory ``config['checkpoint']['save_dir']`` is
    created if it does not exist yet.

    Args:
        config_filename: Path of the YAML file to read.

    Returns:
        dict: The parsed configuration with placeholders expanded in place.

    Raises:
        KeyError: If a placeholder names a key that is missing from the
            config, or if ``checkpoint``/``save_dir`` is absent.
    """
    with open(config_filename) as f:
        # NOTE(review): the full Loader can construct arbitrary Python objects.
        # Only load config files you trust, or switch to a safe loader if the
        # configs never rely on Python-specific tags.
        config = yaml.load(f, Loader)

    placeholder = re.compile(r"\$\{(.+?)\}")

    def lookup(match):
        # Walk the "@@@"-separated key path starting at the config root.
        value = config
        for name in match.group(1).split("@@@"):
            try:
                value = value[name]
            except KeyError:
                raise KeyError(
                    f"Config placeholder {match.group(0)!r} refers to missing key {name!r}"
                ) from None
        # str() lets non-string values (ints, floats, bools) be substituted.
        return str(value)

    def expand(node):
        for key, val in node.items():
            if isinstance(val, str) and "${" in val:
                # A callable replacement is inserted literally, so backslashes
                # in substituted paths are not treated as escape sequences.
                node[key] = placeholder.sub(lookup, val)
            elif isinstance(val, dict):
                expand(val)

    expand(config)
    os.makedirs(config['checkpoint']['save_dir'], exist_ok=True)

    return config

config = load_config("./config.yaml")        

In [None]:
config
# /media/mtb/1268324a-8d38-4c4f-9b71-2a4ddc231fe6/dl/nmt/en-fr/data

In [None]:
# Root directory for data, taken from the DL_PATH environment variable.
DL_PATH = os.environ.get("DL_PATH")
if DL_PATH is None:
    # Fail here with an actionable message instead of the TypeError that
    # os.path.join(None, ...) would raise further down.
    raise EnvironmentError(
        "The DL_PATH environment variable is not set; it is required to build SAVEPATH"
    )

# Language pair read from the config, e.g. used to name the data directory.
lang = config["setup"]["lang"]
src_lang = lang['src']
tgt_lang = lang["tgt"]

# <DL_PATH>/nmt/<src>-<tgt>/data
SAVEPATH = os.path.join(DL_PATH, "nmt", f"{src_lang}-{tgt_lang}", "data")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

In [None]:
# Explicitly switch the cuDNN backend on.
torch.backends.cudnn.enabled = True

In [None]:
# TODO: generate seed. Add epoch.
# Seeds handed to the training data loader (see `seeds=` in get_loader); presumably
# one per epoch — confirm against seq2seq.utils.generate_seed.
datagen_seeds = generate_seed(config['training']["epochs"], seed=0) 
# First seed generated from config['training']['seed']; passed to torch.manual_seed.
train_seed = generate_seed(1, config['training']["seed"])[0]

In [None]:
# Seed PyTorch's global random number generator with the training seed.
# NOTE(review): Python's `random` module is imported but not seeded in the
# visible cells — confirm whether anything relies on it.
torch.manual_seed(train_seed)

In [None]:
# TODO (original note, unspecified) — presumably pad_vocab should come from the
# config rather than being hard-coded; confirm.
pad_vocab = 8

# Vocabulary and BPE code files configured for the dataset.
vocab_fname, bpe_fname = (
    config["setup"]["dataset"]["vocab"],
    config["setup"]["dataset"]["bpe"],
)
tokenizer = Tokenizer(
    vocab_fname=vocab_fname,
    bpe_fname=bpe_fname,
    lang=lang,
    pad=pad_vocab,
)

In [None]:
config['training']["max_size"]

In [None]:
# Source/target training files and the sentence-length bounds from the config.
train_src_fname, train_tgt_fname = (
    config["setup"]["dataset"]["src"]["train"],
    config["setup"]["dataset"]["tgt"]["train"],
)
train_max_len, train_min_len = (
    config["training"]["train_max_len"],
    config["training"]["train_min_len"],
)

# TODO
# Training pairs are read through the lazy dataset variant, unsorted and capped
# at config['training']['max_size'].
train_data = LazyParallelDataset(
    src_fname=train_src_fname,
    tgt_fname=train_tgt_fname,
    tokenizer=tokenizer,
    min_len=train_min_len,
    max_len=train_max_len,
    sort=False,
    max_size=config["training"]["max_size"],
)

In [None]:
len(train_data)

In [None]:
# Source/target validation files and the sentence-length bounds from the config.
valid_src_fname, valid_tgt_fname = (
    config["setup"]["dataset"]["src"]["valid"],
    config["setup"]["dataset"]["tgt"]["valid"],
)
valid_max_len, valid_min_len = (
    config["training"]["valid_max_len"],
    config["training"]["valid_min_len"],
)

# Validation pairs are loaded with sort=True.
valid_data = ParallelDataset(
    src_fname=valid_src_fname,
    tgt_fname=valid_tgt_fname,
    tokenizer=tokenizer,
    min_len=valid_min_len,
    max_len=valid_max_len,
    sort=True,
)

In [None]:
# Source-side test file and the sentence-length bounds from the config.
test_src_fname = config["setup"]["dataset"]["src"]["test"]
test_max_len, test_min_len = (
    config["training"]["test_max_len"],
    config["training"]["test_min_len"],
)

# Source-only dataset for translation; max_size is hard-coded to 1408.
test_data = TextDataset(
    src_fname=test_src_fname,
    tokenizer=tokenizer,
    min_len=test_min_len,
    max_len=test_max_len,
    sort=True,
    max_size=1408,
)

In [None]:
config['setup']['dataset']['src']['test']

In [None]:
config['setup']['dataset']['tgt']['valid']

In [None]:
# Copy of the model section of the config, extended with the tokenizer's vocabulary size.
model_config = {**config['model'], "vocab_size": tokenizer.vocab_size}

In [None]:
model_config

In [None]:
# Build the GNMT model from model_config and move it to the selected device.
model = GNMT(**model_config).to(device)

In [None]:
model

In [None]:
def build_criterion(padding_idx, smoothing=False):
    """Create the training loss module.

    Args:
        padding_idx: Target index excluded from the cross-entropy loss
            (passed as ``ignore_index``); also forwarded to
            ``LabelSmoothing``.
        smoothing: Label-smoothing factor. A falsy value (``0``, ``0.0``,
            ``None`` or the default ``False``) selects plain cross entropy
            summed over the batch; any other value selects
            ``LabelSmoothing`` with that factor.

    Returns:
        A ``torch.nn.CrossEntropyLoss`` or a ``LabelSmoothing`` module.
    """
    # `not smoothing` also covers None (e.g. a YAML `null`), which previously
    # fell through to LabelSmoothing.
    if not smoothing:
        print("Using cross entropy loss")
        # reduction="sum" is the supported spelling of the deprecated
        # size_average=False and yields the same summed loss without a warning.
        return torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="sum")

    print("Using smoothing Label")
    return LabelSmoothing(padding_idx, smoothing=smoothing)
    

In [None]:
# Build the loss for the configured smoothing factor, passing seq2seq_config.PAD
# as the padding index, and move it to the selected device.
criterion = build_criterion(seq2seq_config.PAD, config['loss']['smoothing']).to(device)

In [None]:
# Shuffled training loader; batching strategy, bucket count, batch size and
# worker count all come from the training section of the config.
train_loader = train_data.get_loader(
    batch_size=config["training"]["batch_size"],
    seeds=datagen_seeds,
    batch_first=config["model"]["batch_first"],
    shuffle=True,
    batching=config["training"]["batching"],
    batching_opt={"num_buckets": config["training"]["num_buckets"]},
    num_workers=config["training"]["num_workers"],
    drop_last=True,
)

In [None]:
# Unshuffled validation loader with a hard-coded batch size of 8, loaded in
# the main process (num_workers=0).
# NOTE(review): drop_last=True presumably discards a trailing partial batch,
# which would leave some validation pairs out of val_loss — confirm intended.
valid_loader = valid_data.get_loader(
    batch_size=8,
    batch_first=config["model"]["batch_first"],
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:

# Unshuffled test loader with a hard-coded batch size of 16 and pad=True,
# loaded in the main process (num_workers=0).
test_loader = test_data.get_loader(
    batch_size=16,
    batch_first=config["model"]["batch_first"],
    shuffle=False,
    pad=True,
    num_workers=0,
    drop_last=True,
)

In [None]:
# Translator that runs the model over the test loader; beam-search settings
# come from the `test` section of the config and the reference translations
# from the configured target-side test file.
translator = Translator(
    model=model,
    tokenizer=tokenizer,
    loader=test_loader,
    beam_size=config["test"]["beam_size"],
    max_seq_len=config["training"]["test_max_len"],
    len_norm_const=config["test"]["len_norm_const"],
    len_norm_factor=config["test"]["len_norm_factor"],
    cov_penalty_factor=config["test"]["cov_penalty_factor"],
    print_freq=10,
    reference=config["setup"]["dataset"]["tgt"]["test"],
)

In [None]:
train_loader.sampler.num_samples

In [None]:
# TODO
total_train_iters = len(train_loader) // config['training']['train_iter_size']  * (config["training"]['epochs'] - config['training']['start_epoch'])
total_train_iters

In [None]:
# Extra state passed to the trainer as `save_info`.
save_info = dict(
    model_config=model_config,
    config=config,
    tokenizer=tokenizer.get_state(),
)
# Loss-scaling settings handed to the trainer (math mode below is fp32).
loss_scaling = dict(init_scale=8192, upscale_interval=128)
# Shallow copy of the optimizer section; the scheduler section is passed as is.
opt_config = copy.copy(config["optimizer"])
scheduler_config = config["scheduler"]

trainer_options = {
    "model": model,
    "criterion": criterion,
    "grad_clip": config["training"]["grad_clip"],
    "save_dir": config["checkpoint"]["save_dir"],
    "save_freq": config["checkpoint"]["save_freq"],
    "save_info": save_info,
    "opt_config": opt_config,
    "scheduler_config": scheduler_config,
    "train_iterations": total_train_iters,
    "iter_size": config["training"]["train_iter_size"],
    "keep_checkpoints": config["checkpoint"]["keep_checkpoints"],
    "loss_scaling": loss_scaling,
    "print_freq": 10,
    "intra_epoch_eval": 0,
    "translator": translator,
    "prealloc_mode": "once",
    "warmup": 1,
    "math": "fp32",
}
trainer = Seq2SeqTrainer(**trainer_options)


In [None]:
# Optionally restore trainer state from the checkpoint named in the config.
if config['checkpoint']['resume']:
    checkpoint_file = config['checkpoint']['resume']
    # A directory is resolved to the 'model_best.pth' file inside it.
    checkpoint_file = (
        os.path.join(checkpoint_file, 'model_best.pth')
        if os.path.isdir(checkpoint_file)
        else checkpoint_file
    )
    if not os.path.isfile(checkpoint_file):
        raise ValueError(f"No checkpoint file for {checkpoint_file}")
    trainer.load(checkpoint_file)
    # NOTE(review): sets `last_epoch` on the optimizer object right after
    # loading — confirm the trainer's optimizer actually reads this attribute.
    trainer.optimizer.last_epoch = 0

In [None]:
# Bookkeeping for the training loop: all losses start at +inf so the first
# validation result always counts as an improvement.
train_loss = val_loss = best_loss = float("inf")
training_perf = list()
break_training = False
test_bleu = None
start_epoch = config["training"]["start_epoch"]
print(f"Start epoch {start_epoch}")


In [None]:
# for epoch in range(0, config['training']['epochs']):
#     print(f"Starting epoch {epoch}")
#     train_loader.sampler.set_epoch(epoch)
#     trainer.epoch = epoch
    
#     train_loss, train_perf = trainer.optimize(train_loader)
#     # TODO
#     continue
    
#     training_perf.append(train_perf)
    
#     val_loss, val_perf = trainer.evaluate(valid_loader)
#     if val_loss < best_loss:
#         trainer.save(is_best=True)
    
    

In [None]:
# Main loop: train one epoch, validate, keep the best checkpoint, translate
# the test set for BLEU, then print a summary.
for epoch in range(start_epoch, config['training']['epochs']):
    print(f"Starting epoch {epoch}")
    train_loader.sampler.set_epoch(epoch)
    trainer.epoch = epoch

    # Training pass.
    train_loss, train_perf = trainer.optimize(train_loader)
    training_perf.append(train_perf)

    # Validation pass; a checkpoint is written only when the loss improves.
    val_loss, val_perf = trainer.evaluate(valid_loader)
    if val_loss < best_loss:
        trainer.save(is_best=True)
        best_loss = val_loss

    # Translate the test set and score it with BLEU.
    eval_fname = f'eval_epoch_{epoch}'
    eval_path = os.path.join(config['checkpoint']['save_dir'], eval_fname)
    _, eval_stats = translator.run(calc_bleu=True, epoch=epoch, eval_path=eval_path)
    test_bleu = eval_stats['bleu']

    acc_log = [
        f'Summary: Epoch: {epoch}',
        f'Training Loss: {train_loss:.4f}',
        f'Validation Loss: {val_loss:.4f}',
        f'Test BLEU: {test_bleu:.2f}',
    ]
    perf_log = [
        f'Performance: Epoch: {epoch}',
        # Training throughput is deliberately not reported:
        # f'Training: {train_perf:.0f} Tok/s',
        f'Validation: {val_perf:.0f} Tok/s',
    ]

    # Epoch summary framed by two 100-character separator lines.
    print(
        "*" * 100,
        f"Finished epoch {epoch}",
        '\t'.join(acc_log),
        '\t'.join(perf_log),
        "*" * 100,
        sep="\n",
    )
    