## Train model

Import dependencies

In [2]:
import os
import sys

Add the project root to `sys.path` so the project's `src` modules can be imported

In [4]:
# Parent of the working directory: the project root when the notebook is run
# from notebooks/. Putting it on sys.path makes `from src...` imports resolve.
nb_dir = os.path.dirname(os.getcwd())
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

C:\Users\pynex\Projects\github\Text2Emoji\notebooks


In [39]:
import torch
from torch.optim import Adam, lr_scheduler
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from torchinfo import summary
from datasets import load_from_disk

import sys
import signal
from datetime import date
from hydra import compose, initialize
from omegaconf import OmegaConf
from loguru import logger
import tqdm

import mlflow

from src.model import Text2Emoji
from src.parser import Text2EmojiParser
from src.dataset import Text2EmojiDataset
from src.utils import print_model, seed_all, set_logger
from src.utils.train import evaluate_loss_test, print_learn_curve, evaluate_bleu

Training function

In [30]:
def train_model(model, dataset, train_cfg, emoji_vocab_size, pad_idx, path_save, patience=5):
    """Train ``model`` on ``dataset``, logging params, metrics and checkpoints to MLflow.

    Must be called inside an active MLflow run, because ``mlflow.log_*`` is
    called directly.

    Args:
        model: Text2Emoji model. It is called as ``model(en_ids, de_ids)`` and
            must provide ``emb_requires_grad()``.
        dataset: object whose ``get_data_loader(batch_size, pad_idx)`` returns a
            ``(train_loader, test_loader)`` pair. It is also passed to ``evaluate_bleu``.
        train_cfg: training config providing ``epoch``, ``print_step``,
            ``batch_milestones``, ``batch_sizes``, ``epoch_emb_requires_grad``,
            ``gamma``, ``lr_0`` and ``lr_milestones``. ``batch_sizes[k]`` is used
            after the k-th batch milestone, so it needs
            ``len(batch_milestones) + 1`` entries.
        emoji_vocab_size: number of target classes used for the one-hot targets.
        pad_idx: padding id forwarded to the data loaders.
        path_save: directory where checkpoints are written.
        patience: training stops early once the test loss has increased on more
            than ``patience`` consecutive evaluations.

    Returns:
        dict with ``'train_loss'`` and ``'test_loss'`` lists, one entry per evaluation.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_epoch = train_cfg.epoch
    print_step = train_cfg.print_step
    batch_milestones = train_cfg.batch_milestones
    batch_sizes = train_cfg.batch_sizes
    epoch_emb_requires_grad = train_cfg.epoch_emb_requires_grad
    gamma = train_cfg.gamma
    lr_0 = train_cfg.lr_0
    lr_milestones = train_cfg.lr_milestones

    model.to(device=device)
    optimizer = Adam(model.parameters(), lr=lr_0)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=gamma)
    loss = CrossEntropyLoss()

    params = {
        # Taken from train_cfg; the `train_config` global used previously does not exist.
        "epochs": n_epoch,
        "batch_milestones": batch_milestones,
        "batch_sizes": batch_sizes,
        "epoch_emb_requires_grad": epoch_emb_requires_grad,
        "start_learning_rate": lr_0,
        "optimizer": "Adam",
        "lr_scheduler": {
            "type": "MultiStepLR",
            "milestones": lr_milestones,
            "gamma": gamma,
        },
        "loss_function": loss.__class__.__name__,
        "metric_function": 'bleu',
    }

    mlflow.log_params(params)

    history = {'train_loss': [], 'test_loss': []}
    batch_step = 0
    # Sum and count of training losses since the last evaluation.
    train_loss = 0.0
    train_batches = 0

    test_learn_curve_increases = 0
    stop_early = False

    train_data_loader, test_data_loader = dataset.get_data_loader(batch_sizes[batch_step], pad_idx)
    for epoch in range(n_epoch):
        if epoch == epoch_emb_requires_grad:
            model.emb_requires_grad()

        if epoch in batch_milestones:
            # Advance first, then rebuild, so the loaders use the new (and
            # logged/checkpointed) batch size rather than the previous one.
            batch_step += 1
            train_data_loader, test_data_loader = dataset.get_data_loader(batch_sizes[batch_step], pad_idx)
        batch_size = batch_sizes[batch_step]

        logger.info(f'epoch: {epoch + 1}/{n_epoch}, '
                    f'lr: {scheduler.get_last_lr()}, '
                    f'batch_size: {batch_size}')
        for i, batch in enumerate(tqdm.tqdm(train_data_loader)):
            model.train()

            batch_en_ids = batch['en_ids'].to(device=device)
            batch_de_ids = batch['de_ids'].to(device=device)

            optimizer.zero_grad()

            logits = model(batch_en_ids, batch_de_ids)
            # Target: decoder ids transposed, first position dropped, one-hot encoded.
            target = one_hot(batch_de_ids.permute(1, 0)[:, 1:],
                             num_classes=emoji_vocab_size).to(torch.float)
            loss_t = loss(logits, target)
            loss_t.backward()
            optimizer.step()

            train_loss += loss_t.item()
            train_batches += 1
            if i % print_step == 0 and i != 0:
                model.eval()

                # evaluate: average over the batches actually accumulated
                # (the first window and windows spanning epochs differ from print_step)
                mean_train_loss = train_loss / train_batches
                train_loss = 0.0
                train_batches = 0
                mean_test_loss = evaluate_loss_test(model, test_data_loader, loss, emoji_vocab_size, device)
                history['train_loss'].append(mean_train_loss)
                history['test_loss'].append(mean_test_loss)

                # Global evaluation index, so MLflow steps never repeat across epochs.
                log_step = len(history['test_loss'])
                mlflow.log_metric('train_loss', mean_train_loss, step=log_step)
                mlflow.log_metric('test_loss', mean_test_loss, step=log_step)
                logger.info(f'step: {i}/{len(train_data_loader)}, '
                            f'train_loss: {mean_train_loss}, '
                            f'test_loss: {mean_test_loss}')

                # save state
                checkpoint_path = f'{path_save}/checkpoint_{date.today()}.pth'
                torch.save({
                    'epoch': epoch,
                    'history': history,
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'batch_size': batch_size,
                    'loss': loss
                }, checkpoint_path)
                mlflow.log_artifact(checkpoint_path)

                # callbacks: early stopping when the test loss keeps rising
                if len(history['test_loss']) > 1 and history['test_loss'][-2] < history['test_loss'][-1]:
                    test_learn_curve_increases += 1
                else:
                    test_learn_curve_increases = 0

                if test_learn_curve_increases > patience:
                    logger.info(f'early stopping: test loss increased '
                                f'{test_learn_curve_increases} evaluations in a row')
                    stop_early = True
                    break

        if stop_early:
            # Skip BLEU/scheduler and fall through to the common cleanup below.
            break

        # calculate bleu
        model.eval()
        results = evaluate_bleu(model, dataset, device)
        mlflow.log_metric('bleu', results, step=epoch)
        logger.info(f'bleu: {results}')

        scheduler.step()

    # Same final state whether training finished normally or stopped early.
    model.eval()
    model.to('cpu')

    return history

Set logger

In [5]:
# Configure logging through the project helper src.utils.set_logger
# (presumably sets up the loguru `logger` used throughout this notebook — TODO confirm).
set_logger()

Set paths

In [32]:
# Inputs produced by earlier pipeline stages (relative to the notebooks/ directory).
# NOTE(review): "embbeding" presumably matches the on-disk directory/file name; keep the spelling.
path_load_parser = "../data/parser"
path_load_embbeding = "../data/transfer/embbeding"
path_load_dataset = "../data/datasets/processed"
path_config = "../configs"

# Outputs written by this notebook.
path_save_checkpoint = "../data/checkpoints"
path_save_model = "../models"

Set configs

In [33]:
# Compose the "experiment" config. Using `initialize` as a context manager clears
# Hydra's global state on exit. A bare `initialize(...)` call leaves it set, so
# re-running the cell raises "GlobalHydra is already initialized".
with initialize(version_base=None, config_path=path_config):
    cfg = compose(config_name="experiment")
print(OmegaConf.to_yaml(cfg))

In [9]:
# Special-token strings and their vocabulary ids, read from the processing config.
st = cfg.processing.special_tokens
_token_names = ('pad', 'sos', 'eos', 'unk')
pad_token, sos_token, eos_token, unk_token = [getattr(st, name).token for name in _token_names]
pad_idx, sos_idx, eos_idx, unk_idx = [getattr(st, name).id for name in _token_names]

Set seed

In [10]:
# Seed randomness via the project helper so runs are reproducible.
# NOTE(review): which RNGs seed_all covers (random/numpy/torch/CUDA) is defined in src.utils — confirm.
seed_all(cfg.seed)

Load data

In [23]:
# Load the processed dataset from disk (datasets.load_from_disk).
logger.info('Dataset load')
dataset = load_from_disk(path_load_dataset)

2024-11-02 17:49:44 | INFO | Dataset load


In [24]:
# Wrap the loaded dataset and split it into train/test parts. The isinstance
# guard keeps the cell re-runnable; without it, a second run would wrap the
# already-wrapped Text2EmojiDataset.
if not isinstance(dataset, Text2EmojiDataset):
    dataset = Text2EmojiDataset(dataset)
dataset.train_test_split(cfg.processing.data.train_test_ratio)

In [14]:
# Tokenizer/vocabulary parser configured with the special tokens, restored from disk.
parser = Text2EmojiParser(pad_token, sos_token, eos_token, unk_token)
parser.load(f'{path_load_parser}/parser.pt')

In [15]:
# Pretrained encoder embeddings. map_location='cpu' lets a tensor saved from a GPU
# load on CPU-only machines and keeps it on the same device as the freshly created
# (CPU) model; train_model moves the model to the training device later.
embbedings = torch.load(path_load_embbeding + '/embbeding.pt', map_location='cpu')
# The second dimension is used as the embedding width (assumes a (vocab, dim) matrix — confirm).
embbeding_size = embbedings.shape[1]

In [16]:
logger.info('Model creating')
# Positional arguments, in constructor order: vocabulary sizes, special-token ids,
# embedding width, then the model hyper-parameters from cfg.model.
model = Text2Emoji(
    parser.text_vocab_size(),
    parser.emoji_vocab_size(),
    sos_idx,
    eos_idx,
    pad_idx,
    embbeding_size,
    cfg.model.hidden_size,
    cfg.model.num_layers,
    cfg.model.dropout,
    cfg.model.sup_unsup_ratio,
)
# Presumably copies the pretrained vectors into the encoder embedding layer.
model.init_en_emb(embbedings)

2024-11-02 17:46:12 | INFO | Model creating


In [40]:
# Layer-by-layer parameter counts (torchinfo). The encoder embedding is listed as
# non-trainable (presumably frozen by init_en_emb until train_model calls
# emb_requires_grad).
summary(model)

Layer (type:depth-idx)                   Param #
Text2Emoji                               --
├─Encoder: 1-1                           --
│    └─Embedding: 2-1                    (1,152,600)
│    └─GRU: 2-2                          3,158,400
│    └─Linear: 2-3                       245,350
├─Decoder: 1-2                           --
│    └─Embedding: 2-4                    133,000
│    └─GRUCell: 2-5                      474,600
│    └─Linear: 2-6                       466,830
├─AttentionLayer: 1-3                    --
│    └─Linear: 2-7                       245,350
│    └─Tanh: 2-8                         --
│    └─Linear: 2-9                       351
Total params: 5,876,481
Trainable params: 4,723,881
Non-trainable params: 1,152,600

Train model

In [18]:
def signal_capture(sig, frame):
    """SIGINT handler: save the current model weights, then exit.

    Args:
        sig: signal number (unused).
        frame: current stack frame (unused).

    The weights of the notebook-global ``model`` are written to
    ``{path_save_model}/SIGINT_model_weights_<date>.pth`` before ``sys.exit(0)``
    raises SystemExit.
    """
    weights_path = f'{path_save_model}/SIGINT_model_weights_{date.today()}.pth'
    torch.save(model.state_dict(), weights_path)
    sys.exit(0)

In [19]:
# Route SIGINT (Ctrl+C / Jupyter "interrupt kernel") to signal_capture so the weights
# are saved before exiting. The displayed value is the previously installed handler.
signal.signal(signal.SIGINT, signal_capture)

<function _signal.default_int_handler(signalnum, frame, /)>

Track the training run in MLflow

In [31]:
from mlflow import MlflowClient

In [34]:
# `experiment_config` is not defined anywhere in this notebook (NameError on a fresh
# kernel); the composed "experiment" config is `cfg`.
# NOTE(review): assumes `mlflow_server` and `name` are top-level keys of that config — confirm.
mlflow.set_tracking_uri(cfg.mlflow_server)
mlflow.set_experiment(cfg.name)

In [36]:
# Name of the MLflow run started by the training cell.
run_name = "1.0-train-model"

In [None]:
with mlflow.start_run(run_name=run_name) as run:
    # Ensure the output directories exist before anything is written into them.
    os.makedirs(path_save_model, exist_ok=True)
    os.makedirs(path_save_checkpoint, exist_ok=True)

    # Log the torchinfo summary as a text artifact. utf-8 is required because the
    # summary contains box-drawing characters (├─, │, └─) that the Windows default
    # encoding (cp1252) cannot encode.
    model_summary_file = f"{path_save_model}/model_{date.today()}.txt"
    with open(model_summary_file, "w", encoding="utf-8") as f:
        f.write(str(summary(model)))
    mlflow.log_artifact(model_summary_file)

    logger.info('Model training')
    train_history = train_model(model, dataset, cfg.train,
                                parser.emoji_vocab_size(), pad_idx,
                                path_save_checkpoint)

    # Persist the final weights locally and log the full model to MLflow.
    torch.save(model.state_dict(), f'{path_save_model}/trained_model_weights_{date.today()}.pth')

    mlflow.pytorch.log_model(model, "model")