## Train model

Dependency import

In [1]:
import os
import sys

Add path of project modules to visible area

In [2]:
nb_dir = os.path.split(os.getcwd())[0]
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

In [3]:
import torch
from torchinfo import summary
from datasets import load_from_disk

import sys
import signal
from datetime import date
from hydra import compose, initialize
from omegaconf import OmegaConf
from loguru import logger

import mlflow

from src.model import Text2Emoji
from src.parser import Text2EmojiParser
from src.dataset import Text2EmojiDataset
from src.train import Text2EmojiTrainer
from src.utils import seed_all, set_logger

Set logger

In [4]:
set_logger()

Set paths

In [5]:
path_config="../configs"
path_load_parser = '../data/parser'
path_load_embbeding = '../data/transfer/embbeding'
path_load_dataset = '../data/datasets/processed'
path_save_checkpoint = '../data/checkpoints'
path_save_model = '../models'

Set configs

In [6]:
initialize(version_base=None, config_path=path_config)
cfg = compose(config_name="experiment")
print(OmegaConf.to_yaml(cfg))

model:
  hidden_size: 350
  num_layers: 2
  dropout: 0.2
  sup_unsup_ratio: 0.9
processing:
  data:
    min_freq_emoji: 5
    min_freq_text: 10
    max_text_length: 128
    train_test_ratio: 0.007
  special_tokens:
    pad:
      id: 0
      token: <pad>
    sos:
      id: 1
      token: <sos>
    eos:
      id: 2
      token: <eos>
    unk:
      id: 3
      token: <unk>
train:
  epoch: 8
  batch_sizes:
  - 32
  - 64
  - 128
  - 256
  batch_milestones:
  - 2
  - 4
  - 7
  lr_0: 0.001
  lr_milestones:
  - 2
  - 4
  - 7
  gamma: 0.464159
  epoch_emb_requires_grad: 4
  print_step: 100
name: 1.0-process-data-and-train-GRU-model
mlflow_server: http://127.0.0.1:5000
seed: 42



In [7]:
st = cfg.processing.special_tokens
pad_token, sos_token, eos_token, unk_token = st.pad.token, st.sos.token, st.eos.token, st.unk.token
pad_idx, sos_idx, eos_idx, unk_idx = st.pad.id, st.sos.id, st.eos.id, st.unk.id

Set seed

In [8]:
seed_all(cfg.seed)

Load data

In [9]:
logger.info(f'Dataset load')
dataset = load_from_disk(path_load_dataset)

2024-11-19 16:43:53 | INFO | Dataset load


In [10]:
dataset = Text2EmojiDataset(dataset)
dataset.train_test_split(cfg.processing.data.train_test_ratio)

In [11]:
parser = Text2EmojiParser(pad_token, sos_token, eos_token, unk_token)
parser.load(path_load_parser + '/parser.pt')

In [12]:
embbedings = torch.load(path_load_embbeding + '/embbeding.pt')
embbeding_size = embbedings.shape[1]

In [13]:
logger.info('Model creating')
model = Text2Emoji(parser.text_vocab_size(), parser.emoji_vocab_size(),
                    sos_idx, eos_idx, pad_idx, embbeding_size,
                   cfg.model.hidden_size,
                   cfg.model.num_layers,
                   cfg.model.dropout,
                   cfg.model.sup_unsup_ratio)
model.init_en_emb(embbedings)

2024-11-19 16:43:57 | INFO | Model creating


In [14]:
summary(model)

Layer (type:depth-idx)                   Param #
Text2Emoji                               --
├─Encoder: 1-1                           --
│    └─Embedding: 2-1                    (1,152,600)
│    └─GRU: 2-2                          3,158,400
│    └─Linear: 2-3                       245,350
├─Decoder: 1-2                           --
│    └─Embedding: 2-4                    133,000
│    └─GRUCell: 2-5                      474,600
│    └─Linear: 2-6                       466,830
├─AttentionLayer: 1-3                    --
│    └─Linear: 2-7                       245,350
│    └─Tanh: 2-8                         --
│    └─Linear: 2-9                       351
Total params: 5,876,481
Trainable params: 4,723,881
Non-trainable params: 1,152,600

Train model

In [15]:
def signal_capture(sig, frame):
    torch.save(model.state_dict(), f'{path_save_model}/SIGINT_model_weights_{date.today()}.pth')
    sys.exit(0)

In [16]:
signal.signal(signal.SIGINT, signal_capture)

<function _signal.default_int_handler(signalnum, frame, /)>

In [17]:
trainer = Text2EmojiTrainer(model, cfg.train)

Save on MLFlow

In [19]:
mlflow.set_tracking_uri(cfg.mlflow_server)
mlflow.set_experiment(cfg.name)

KeyboardInterrupt: 

In [20]:
run_name = "1.0-train-model"

In [None]:
with mlflow.start_run(run_name=run_name) as run:
    model_summary_file = f"../models/model_{date.today()}.txt"
    with open(model_summary_file, "w") as f:
        f.write(str(summary(model)))
    mlflow.log_artifact(model_summary_file)

    logger.info('Model training')
    train_history = trainer.train(dataset, path_save_checkpoint)

    torch.save(model.state_dict(), f'{path_save_model}/trained_model_weights_{date.today()}.pth')

    mlflow.pytorch.log_model(model, "model")