In [10]:
import torch
import transformers
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from bert_classifier.data import CustomDataset
from multi_task_model.mtl import AutoModelForMTL
from multi_task_model.trainer import Trainer, Configs

In [39]:
# Run-wide configuration: everything a reader might want to tune lives here.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG (CPU and CUDA) so that RandomSampler's shuffling and
# any randomly initialised layers are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

TRAIN_BATCH_SIZE = 8
TEST_BATCH_SIZE = 4
EPOCHS = 10
# Maximum tokenized length handed to CustomDataset.
# NOTE(review): question titles are short; if CustomDataset pads to MAX_LEN, a
# much smaller value would cut compute substantially — confirm before changing.
MAX_LEN = 512
DATA_COL = 'question_title'   # text column fed to the tokenizer
LABEL_COL = 'label'           # integer label column derived from `category`

DATA_FOLDER = Path('../data/0_external/google-quest-challenge/')
MODEL_DIR = Path('../models/multitask-models/')
CHKPT_PATH = Path('../chkpt/')
HF_MODEL_CARD = 'sentence-transformers/multi-qa-mpnet-base-dot-v1'

In [37]:
# Read the raw train/test splits of the Google QUEST challenge from DATA_FOLDER
# (train first, then test).
df_train, df_test = (pd.read_csv(DATA_FOLDER / fname) for fname in ('train.csv', 'test.csv'))

In [46]:
# Map each category string to a contiguous integer id. Ids follow the order in
# which categories first appear in the training split (Series.unique preserves
# appearance order), so the id of each class — and hence the meaning of each
# model output — depends on the row order of train.csv.
label_dict = dict(enumerate(df_train.category.unique()))                 # id -> category
label_2_id = {category: idx for idx, category in label_dict.items()}    # category -> id

df_train[LABEL_COL] = df_train.category.map(label_2_id)

# The label space is defined by the training split only. Fail with the
# offending values instead of a bare KeyError if the test split has extras.
unseen_test_categories = set(df_test.category.unique()) - set(label_2_id)
if unseen_test_categories:
    raise ValueError(
        f'test.csv contains categories not present in train.csv: {sorted(unseen_test_categories, key=str)}'
    )
df_test[LABEL_COL] = df_test.category.map(label_2_id)

### Prepare training and testing dataloaders

In [47]:
# Tokenizer matching the pretrained encoder; shared by both splits.
pretrained_tokenizer = transformers.AutoTokenizer.from_pretrained(HF_MODEL_CARD)


def _to_dataset(frame):
    """Wrap one split as a single-label CustomDataset over DATA_COL / LABEL_COL."""
    return CustomDataset(frame, DATA_COL, LABEL_COL, pretrained_tokenizer, MAX_LEN, multi_label=False)


training_set = _to_dataset(df_train)
testing_set = _to_dataset(df_test)

# Shuffle training batches; keep evaluation order fixed (presumably so outputs
# line up with df_test rows, assuming CustomDataset indexes rows in frame order).
train_dataloader = DataLoader(training_set, sampler=RandomSampler(training_set), batch_size=TRAIN_BATCH_SIZE)
test_dataloader = DataLoader(testing_set, sampler=SequentialSampler(testing_set), batch_size=TEST_BATCH_SIZE)

### Initialize the multi-task learning (MTL) model

In [None]:
# One output per category in the label mapping (assumed meaning of the second
# positional argument — confirm against AutoModelForMTL's signature).
n_categories = len(label_dict)
model = AutoModelForMTL(HF_MODEL_CARD, n_categories)
model.to(DEVICE)

### Configure the model trainer

In [51]:
# Single-label classification over the category set built above.
# tune_base_model=False presumably freezes the pretrained encoder so only the
# classification layer is trained — TODO confirm in multi_task_model.trainer.
configs = Configs(
    epochs=EPOCHS,
    multi_label=False,
    num_labels=len(label_dict),
    tune_base_model=False,
)

# The checkpoint directory is not guaranteed to exist on a fresh checkout;
# create it before the trainer writes into it (no-op if it already exists).
CHKPT_PATH.mkdir(parents=True, exist_ok=True)

trainer = Trainer(
    model,
    train_dataloader,
    test_dataloader,
    configs,
    device=torch.device(DEVICE),
    chkpt_dir=CHKPT_PATH,
)

### Kick off training

In [None]:
# Run training with the configuration built above. Trainer was given
# chkpt_dir=CHKPT_PATH, so it presumably writes per-epoch checkpoints there
# (e.g. chkpt9.pt, which later cells load) — confirm in Trainer.
trainer.train()

### Continue training from a checkpoint

In [None]:
# Resume from the 'chkpt9.pt' checkpoint inside CHKPT_PATH.
# NOTE(review): under Restart & Run All this executes right after
# trainer.train(); presumably it is only meant for resuming an interrupted run.
resume_checkpoint = CHKPT_PATH / 'chkpt9.pt'
trainer.continue_training(resume_checkpoint)

### Save the model

In [None]:
# Export the trained model weights and the tokenizer side by side.
# NOTE(review): the directory name says "multilabel" but this run trains with
# multi_label=False; the path is left unchanged in case something downstream
# loads from it.
output_dir = MODEL_DIR / 'multitask-multilabel-model-finetuned-classification-layer-20230609'
tokenizer_dir = output_dir / 'tokenizer'
model_file = output_dir / 'mtl.bin'

# Load the weights from the same checkpoint directory the trainer writes to
# (CHKPT_PATH), not an environment-specific absolute path, so this cell works
# wherever the notebook was trained. load_checkpoint returns a 3-tuple; only
# the model is needed here.
checkpoint_file = CHKPT_PATH / 'chkpt9.pt'
model, _, _ = trainer.load_checkpoint(str(checkpoint_file), model)

# Create the export directory explicitly rather than relying on
# save_pretrained creating parent directories as a side effect.
output_dir.mkdir(parents=True, exist_ok=True)
pretrained_tokenizer.save_pretrained(tokenizer_dir)
AutoModelForMTL.save_model(model, model_file)