In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import transformers
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from nlp_models.base.data import CustomDataset, label_to_id_list
from nlp_models.multi_task_model.mtl import AutoModelForMTL
from nlp_models.multi_task_model.trainer import Trainer, Configs

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

TRAIN_BATCH_SIZE = 8
TEST_BATCH_SIZE = 4
EPOCHS = 10
MAX_LEN = 512
DATA_COL = 'question_title'
LABEL_COL = 'label'

DATA_FOLDER = Path('../../data/0_external/google-quest-challenge/')
MODEL_DIR = Path('../../models/multi-task-model/')
CHKPT_PATH = Path(MODEL_DIR / 'chkpt')
HF_MODEL_CARD = 'sentence-transformers/multi-qa-mpnet-base-dot-v1'

In [4]:
df_train = pd.read_csv(DATA_FOLDER / 'train.csv')
df_test = pd.read_csv(DATA_FOLDER / 'test.csv')

In [5]:
label_dict = dict([(v,k) for k, v in enumerate(df_train.category.unique())])
df_train[LABEL_COL] = df_train.category.apply(lambda c: label_to_id_list(c, label_dict.keys()))
df_test[LABEL_COL] = df_test.category.apply(lambda c: label_to_id_list(c, label_dict.keys()))

### Prepare training testing dataloader

In [6]:
pretrained_tokenizer = transformers.AutoTokenizer.from_pretrained(HF_MODEL_CARD)
training_set = CustomDataset(df_train, DATA_COL, LABEL_COL, pretrained_tokenizer, MAX_LEN)
testing_set = CustomDataset(df_test, DATA_COL, LABEL_COL, pretrained_tokenizer, MAX_LEN)

train_dataloader = DataLoader(
            training_set,  
            sampler=RandomSampler(training_set), 
            batch_size=TRAIN_BATCH_SIZE,
        )
test_dataloader = DataLoader(
            testing_set,
            sampler=SequentialSampler(testing_set),
            batch_size=TEST_BATCH_SIZE,
        )

### Initiate MTL model

In [7]:
model = AutoModelForMTL(HF_MODEL_CARD, len(label_dict))
model.to(DEVICE)

AutoModelForMTL(
  (base_model): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )

### Configurations for model trainer

In [8]:
configs = Configs(
    epochs = EPOCHS, 
    num_labels=len(label_dict), 
    tune_base_model=False
    )

trainer = Trainer(
    model, 
    train_dataloader,
    test_dataloader, 
    configs, 
    device=torch.device(DEVICE), 
    chkpt_dir=CHKPT_PATH
    )

### Kick off training

In [None]:
trainer.train()

### Continue training from checkpoint

In [None]:
trainer.continue_training(CHKPT_PATH / 'chkpt9.pt')

### Save the model

In [9]:
output_dir = MODEL_DIR / 'multi-task-model-finetuned-classification-layer-20230609'
tokenizer_dir = output_dir / 'tokenizer'
model_file = output_dir / 'mtl.bin'

model,_,_ = trainer.load_checkpoint(CHKPT_PATH / 'chkpt9.pt', model)
pretrained_tokenizer.save_pretrained(tokenizer_dir)
AutoModelForMTL.save_model(model, model_file)