In [None]:
import functools

import torch
from torch import nn
from model.model import TransformerEncoder
from model.metrices import F1Score
from HSDDataset import ViHSDData

from torch.utils.data import DataLoader, RandomSampler
from tqdm.notebook import tqdm
# from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
from utils import get_device

In [None]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's RNGs up front so the RandomSampler shuffling, dropout masks and
# any randomly initialised layers are repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# --- Optimisation ----------------------------------------------------------
LR = 1e-6          # learning rate handed to the optimizer
n_epochs = 500     # number of passes of the epoch loop
batch_size = 32    # samples per DataLoader batch

# --- Model / data shape ----------------------------------------------------
classes_num = 3                # size of the classifier output, forwarded to TransformerEncoder
checkpoint_batch_size = 1024   # forwarded verbatim to TransformerEncoder
max_len = 38                   # max token length: used as max_seq_length and by make_batch()

# --- Runtime ---------------------------------------------------------------
device = get_device()

# TensorBoard writer (default log directory).
writer = SummaryWriter()

In [None]:
import emoji
import re

def proprocess(x):
    """Normalise one raw utterance before tokenisation.

    The value is coerced to ``str``, emoji are removed, every run of spaces
    is collapsed to a single space, and the result is lower-cased and
    stripped of leading/trailing whitespace.

    Args:
        x: Raw cell value; anything accepted by ``str()``.

    Returns:
        The cleaned, lower-cased string.
    """
    without_emoji = emoji.replace_emoji(str(x), replace='')
    single_spaced = re.sub(r" +", " ", without_emoji)
    lowered = single_spaced.lower()
    return lowered.strip()

In [None]:
def model_collate_fn(batch):
    """Identity collate: hand the list of dataset items to the caller untouched.

    Tokenisation and tensor conversion happen later, in ``make_batch``.
    """
    return batch


def _load_split(csv_path):
    """Load one ViHSD split from ``csv_path`` with the shared column names and preprocessor."""
    return ViHSDData(csv_path,
                     utterance_feild="free_text",
                     label_feild="label_id",
                     text_preprocessor=proprocess)


def _make_loader(dataset, sampler):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch size and collate function."""
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, collate_fn=model_collate_fn)


# Training split: shuffled on every pass.
train_data = _load_split("./data/vihsd/train.csv")
train_sampler = RandomSampler(train_data)
data_loader = _make_loader(train_data, train_sampler)

# Validation / test splits: iterate in file order so evaluation runs are
# reproducible and per-batch numbers are comparable between epochs.
val_data = _load_split("./data/vihsd/dev.csv")
val_sampler = torch.utils.data.SequentialSampler(val_data)
val_loader = _make_loader(val_data, val_sampler)

test_data = _load_split("./data/vihsd/test.csv")
test_sampler = torch.utils.data.SequentialSampler(test_data)
test_loader = _make_loader(test_data, test_sampler)

# Total number of optimizer steps; the LR scheduler is sized from this.
training_step = len(data_loader)*n_epochs
print(f"Total {training_step} training steps for this dataset")

In [None]:
# Build the classifier on top of the local "./weights/multiBERTuncased" checkpoint.
# NOTE(review): the previous header comment mentioned "vinai/phobert-base";
# presumably that is an alternative checkpoint — confirm which one is intended.
model = TransformerEncoder(
    "./weights/multiBERTuncased",
    classes_num=classes_num,
    max_seq_length=max_len,
    checkpoint_batch_size=checkpoint_batch_size,
    dropout_rate=0.5,
    model_args={"output_hidden_states": False},
)

cross_entropy_loss = nn.CrossEntropyLoss()

model.to(device)
cross_entropy_loss.to(device)


# Split parameters into two optimizer groups: biases and LayerNorm parameters
# are excluded from weight decay, everything else gets a small decay.
named_params = list(model.named_parameters())  # included all params from pooler and transformers
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
model_params = [
    {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0001},
    {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

# NOTE(review): transformers' AdamW is reported as deprecated in favour of
# torch.optim.AdamW in recent releases — consider migrating the import.
optimizer = AdamW(model_params, lr=LR)

# Linear warm-up over the first half of training, then linear decay to zero.
# The step counts are passed as integers, as the scheduler's signature documents.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=training_step // 2,
    num_training_steps=training_step,
)

# for param in model.sent_encoder.embeddings.parameters():
#     param.requires_grad = False

In [None]:
def make_batch(batch, tokenizer, max_len=64, device="cuda:0"):
    """Turn a list of ``(text, label)`` pairs into model-ready tensors.

    Args:
        batch: Sequence of ``(text, label)`` pairs.
        tokenizer: Object exposing ``batch_encode_plus``; its result is read
            through the ``"input_ids"`` and ``"attention_mask"`` keys.
        max_len: Every sequence is padded/truncated to this many tokens.
        device: Device the resulting tensors are moved to.

    Returns:
        ``(inputs, labels)``: ``inputs`` is a dict holding the ``input_ids``
        and ``attention_mask`` LongTensors, ``labels`` is a LongTensor of the
        labels in batch order.
    """
    texts, targets = [], []
    for text, label in batch:
        texts.append(text)
        targets.append(label)

    label_tensor = torch.LongTensor(targets).to(device)

    encoded = tokenizer.batch_encode_plus(
        texts,
        max_length=max_len,
        padding='max_length',
        truncation=True,
    )
    model_inputs = {
        key: torch.LongTensor(encoded[key]).to(device)
        for key in ("input_ids", "attention_mask")
    }
    return model_inputs, label_tensor

In [None]:
def train(data_loader, model, cross_entropy_loss, optimizer, scheduler, train_step,
          epoch=None, max_batches=None):
    """Run one training pass over ``data_loader`` and return the updated step counter.

    For every batch: tokenise via ``make_batch``, forward, compute the
    cross-entropy loss, back-propagate, then step the optimizer and the LR
    scheduler. A per-batch macro F1 and the loss are shown on the progress bar.

    Args:
        data_loader: Iterable yielding lists of ``(text, label)`` pairs.
        model: Model exposing ``.tokenizer`` and callable on the inputs dict
            built by ``make_batch``.
        cross_entropy_loss: Loss callable taking ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        train_step: Global step count before this pass.
        epoch: Epoch number to show on the progress bar. When omitted, the
            notebook-level loop variable ``e`` is used if it exists.
        max_batches: Optional cap on the number of batches processed in this
            call (handy for smoke tests). ``None`` processes the whole loader.

    Returns:
        ``train_step`` incremented by the number of batches processed.

    Note:
        The model is left in training mode; ``evaluation`` switches it to
        eval mode itself.
    """
    if epoch is None:
        # Backward compatibility: existing callers rely on the epoch loop's
        # global ``e`` instead of passing the epoch explicitly.
        epoch = globals().get("e")

    progress = tqdm(data_loader)
    f1 = F1Score()
    model.train()
    for batch_idx, batch in enumerate(progress):
        # A stray unconditional ``break`` used to stop every epoch after a
        # single batch, although the scheduler is sized for full passes.
        if max_batches is not None and batch_idx >= max_batches:
            break

        inputs, labels = make_batch(batch, model.tokenizer, max_len=max_len, device=device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = cross_entropy_loss(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            # One class index per sample: argmax over the class axis. Without
            # ``dim`` argmax returns a single index into the flattened batch.
            # Softmax is monotonic, so it is not needed before argmax.
            predict = torch.argmax(logits, dim=-1)
            f1_score = f1(predict, labels, "macro")
        progress.set_postfix(Epoch=epoch, step=train_step, loss=loss.item(), f1=f1_score[0].data.item())
        train_step += 1
    return train_step


def evaluation(dev_loader, model, cross_entropy_loss, dev_step, epoch=None):
    """Run the model over ``dev_loader`` without gradient tracking.

    The per-batch loss and macro F1 are shown on the progress bar; nothing is
    aggregated or returned besides the step counter.

    Args:
        dev_loader: Iterable yielding lists of ``(text, label)`` pairs.
        model: Model exposing ``.tokenizer`` and callable on the inputs dict
            built by ``make_batch``.
        cross_entropy_loss: Loss callable taking ``(logits, labels)``.
        dev_step: Global evaluation step count before this pass.
        epoch: Epoch number to show on the progress bar. When omitted, the
            notebook-level loop variable ``e`` is used if it exists.

    Returns:
        ``dev_step`` incremented by the number of batches processed.
    """
    if epoch is None:
        # Backward compatibility: existing callers rely on the epoch loop's
        # global ``e`` instead of passing the epoch explicitly.
        epoch = globals().get("e")

    progress = tqdm(dev_loader)
    f1 = F1Score()
    model.eval()
    with torch.no_grad():
        for batch in progress:
            inputs, labels = make_batch(batch, model.tokenizer, max_len=max_len, device=device)
            logits = model(inputs)
            loss = cross_entropy_loss(logits, labels)
            # One class index per sample: argmax over the class axis. Without
            # ``dim`` argmax returns a single index into the flattened batch.
            predict = torch.argmax(logits, dim=-1)
            f1_score = f1(predict, labels, "macro")
            progress.set_postfix(Epoch=epoch, step=dev_step, loss=loss.item(), f1=f1_score[0].data.item())
            dev_step += 1
    return dev_step

In [None]:
# Global step counters, carried across epochs so the progress bars keep
# counting instead of restarting at zero every epoch.
train_step = 0
val_step = 0

# The loop variable must stay named `e`: train() and evaluation() read it to
# label their progress bars.
for e in range(n_epochs):
    train_step = train(
        data_loader, model, cross_entropy_loss, optimizer, scheduler, train_step
    )
    val_step = evaluation(val_loader, model, cross_entropy_loss, val_step)