In [1]:
import torch
import numpy as np
import pandas as pd
from preprocess import TripPreProcess, ASAPPreProcess
from trainer import RP_trainer
from models import BERT
from utils import set_seed
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration.
# NOTE(review): only 'Trip' and 'ASAP' are handled by the data-loading cell below;
# 'rest_14', 'rest_15', 'rest_16' and 'mams' are not wired up there.
data_name = 'Trip'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed once, before the model and trainer are built
# (set_seed comes from the project's utils; presumably seeds torch/numpy/random — confirm).
seed = 1
set_seed(seed)

# NOTE(review): the training log below reports "EPOCH: 11" with n_epochs = 10;
# check how RP_trainer counts epochs.
n_epochs, batch_size = 10, 32

In [3]:
# Load the train/dev/test splits for the selected dataset together with the
# matching pretrained encoder (bert-base-uncased for Trip, bert-base-chinese for ASAP).
# `T` is the first value returned by get_dataset(); it is not referenced in the cells below.
if data_name == 'Trip':
    trip = TripPreProcess()
    T, train_set, dev_set, test_set = trip.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-uncased")
elif data_name == 'ASAP':
    asap = ASAPPreProcess()
    T, train_set, dev_set, test_set = asap.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-chinese")
else:
    # Fail fast: without this branch an unsupported name leaves the variables
    # above undefined (NameError later) or, on a reused kernel, silently keeps the
    # splits/encoder from a previous run and trains/saves under the wrong name.
    raise ValueError(
        f"Unsupported data_name {data_name!r}; this cell only handles 'Trip' and 'ASAP'."
    )

Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias',

In [4]:
# Wrap the pretrained encoder in the project's `BERT` model and move it to the
# selected device. The same trainer object is used for both training and testing below.
model = BERT(bert_model)
model = model.to(device)
trainer = RP_trainer()

### Train

In [5]:
# Fine-tune on train_set, reporting dev-set F1 each epoch (see log below).
# NOTE(review): the Test section reloads "./model_params/<data_name>_BERT_<seed>.model",
# so presumably the trainer saves its checkpoint under data_name + "_" + model_name — confirm.
trainer.train(
    model=model,
    train_set=train_set,
    dev_set=dev_set,
    device=device,
    n_epochs=n_epochs,
    batch_size=batch_size,
    data_name=data_name,
    model_name=f'BERT_{seed}',
)

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
100%|██████████| 734/734 [02:07<00:00,  5.78it/s]


EPOCH: 1 TRAIN-F1: 0.39927499521333926 TRAIN-LOSS 0.9655647069946621 DEV-F1 0.4924243696046138


100%|██████████| 734/734 [02:03<00:00,  5.92it/s]


EPOCH: 2 TRAIN-F1: 0.46821657716370046 TRAIN-LOSS 0.8644337038172039 DEV-F1 0.6037045406830427


100%|██████████| 734/734 [02:05<00:00,  5.86it/s]


EPOCH: 3 TRAIN-F1: 0.5163019926924807 TRAIN-LOSS 0.8041903697069291 DEV-F1 0.6424856429307655


100%|██████████| 734/734 [02:06<00:00,  5.79it/s]


EPOCH: 4 TRAIN-F1: 0.5511008981247292 TRAIN-LOSS 0.759835324308528 DEV-F1 0.6497596716122036


100%|██████████| 734/734 [02:05<00:00,  5.85it/s]


EPOCH: 5 TRAIN-F1: 0.5782249555757742 TRAIN-LOSS 0.7227736123738562 DEV-F1 0.6629175049620016


100%|██████████| 734/734 [02:06<00:00,  5.81it/s]


EPOCH: 6 TRAIN-F1: 0.6037725450056436 TRAIN-LOSS 0.6874381757453828 DEV-F1 0.6524343538328259


100%|██████████| 734/734 [02:05<00:00,  5.86it/s]


EPOCH: 7 TRAIN-F1: 0.6273434988518866 TRAIN-LOSS 0.653494805036676 DEV-F1 0.6500868910890744


100%|██████████| 734/734 [02:05<00:00,  5.84it/s]


EPOCH: 8 TRAIN-F1: 0.6491056415157189 TRAIN-LOSS 0.6210301849047072 DEV-F1 0.6644514773007084


100%|██████████| 734/734 [02:03<00:00,  5.96it/s]


EPOCH: 9 TRAIN-F1: 0.6693583007275599 TRAIN-LOSS 0.5899983353015061 DEV-F1 0.6474845607541414


100%|██████████| 734/734 [02:06<00:00,  5.82it/s]


EPOCH: 10 TRAIN-F1: 0.6878992355015671 TRAIN-LOSS 0.5611291193786491 DEV-F1 0.6497498661673011


100%|██████████| 734/734 [02:07<00:00,  5.74it/s]


EPOCH: 11 TRAIN-F1: 0.7046178957709934 TRAIN-LOSS 0.5342976972209288 DEV-F1 0.6550099446981986
Early stopping at epoch 11


### Test

In [6]:
# Reload the checkpoint written during training before evaluating on the test split.
# NOTE(review): whether this is the best-dev or last-epoch checkpoint depends on RP_trainer — confirm.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint_path = f"./model_params/{data_name}_BERT_{seed}.model"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [7]:
# Evaluate on the held-out test split; the trainer prints precision, recall, F1 and accuracy.
trainer.test_rp(model, test_set, device=device, batch_size=batch_size)

Precision: 0.6681974972197118
Recall: 0.6685988543568341
F1-score: 0.6669256792440642
Accuracy: 0.7142857142857143
