In [1]:
import time
import torch
from preprocess import TripPreProcess, ASAPPreProcess
from models import DSPN
from utils import set_seed
from trainer import DSPN_trainer
from transformers import BertModel, BertTokenizer
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoTokenizer, AutoModel
from transformers import AlbertTokenizer, AlbertModel
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Experiment settings shared by every cell below.
# Dataset to run; the loading cell in this notebook only has branches for 'Trip' and 'ASAP'.
data_name = 'Trip'  # ['ASAP', 'Trip', 'rest_14', 'rest_15', 'rest_16', 'mams']

# Seed value; it is also embedded in the checkpoint file name used later.
# set_seed comes from utils -- presumably it seeds the Python/NumPy/torch RNGs (TODO confirm).
seed = 1
set_seed(seed)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training budget passed to trainer.train and the evaluation calls.
n_epochs, batch_size = 5, 32

In [3]:
# Load the dataset splits and the pretrained encoder/tokenizer that match `data_name`.
# Each branch defines T, train_set, dev_set, test_set, bert_model and bert_tokenizer,
# all of which are consumed by the cells below. The commented-out lines are alternative
# backbones (RoBERTa / ALBERT) that can be swapped in for the default BERT checkpoint.
if data_name == 'Trip':
    trip = TripPreProcess()
    T, train_set, dev_set, test_set = trip.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-uncased")
    bert_tokenizer = BertTokenizer.from_pretrained("./model_params/bert-base-uncased")
    #bert_model = RobertaModel.from_pretrained("./model_params/roberta-base")
    #bert_tokenizer = RobertaTokenizer.from_pretrained("./model_params/roberta-base")
    #bert_model = AlbertModel.from_pretrained("./model_params/albert-base-v2")
    #bert_tokenizer = AlbertTokenizer.from_pretrained('./model_params/albert-base-v2')
elif data_name == 'ASAP':
    asap = ASAPPreProcess()
    T, train_set, dev_set, test_set = asap.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-chinese")
    bert_tokenizer = BertTokenizer.from_pretrained("./model_params/bert-base-chinese")
    #bert_model = AutoModel.from_pretrained("./model_params/roberta-base-chinese")
    #bert_tokenizer = AutoTokenizer.from_pretrained("./model_params/roberta-base-chinese")
    #bert_model = AlbertModel.from_pretrained("./model_params/albert-chinese-base/")
    #bert_tokenizer = AutoTokenizer.from_pretrained("./model_params/albert-chinese-base/")
else:
    # Without this branch an unsupported name would skip both loaders silently and
    # surface later as a NameError on T / bert_model in the model-construction cell.
    raise ValueError(
        f"Unsupported data_name {data_name!r}: this notebook only loads 'Trip' and 'ASAP'."
    )

Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transf

In [4]:
# Build the DSPN network from T and the pretrained encoder/tokenizer,
# then move it to the selected device.
model = DSPN(T, bert_model, bert_tokenizer)
model = model.to(device)

# Trainer object used by the training and evaluation cells below.
trainer = DSPN_trainer(data_name)

### Train

In [None]:
# Train the model and report how long it took plus the model size.
# perf_counter() is a monotonic clock intended for measuring intervals, so the
# reported duration is not distorted if the system clock is adjusted mid-run.
start_time = time.perf_counter()
trainer.train(model, train_set, dev_set, device=device, n_epochs=n_epochs, batch_size=batch_size, data_name=data_name, model_name='DSPN_'+str(seed))
end_time = time.perf_counter()
used_mins = (end_time - start_time) / 60
print(f"Time: {used_mins} Minutes")
# Sum the element counts of every tensor in the state dict; note that a state dict
# holds persistent buffers as well as parameters, so this is an upper bound on the
# number of trainable weights.
num_params = sum(p.numel() for p in model.state_dict().values())
print(f"Params_size: {num_params/1000000}M")

### Test

In [6]:
# Restore the weights stored for this dataset/seed combination.
# The file name mirrors the data_name / 'DSPN_<seed>' arguments given to trainer.train,
# so it is presumably the checkpoint written by the training cell (TODO confirm in trainer).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_path = f"./model_params/{data_name}_DSPN_{seed}.model"
# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [7]:
# Evaluate on the test split with trainer.test_rp; per the cell output it reports
# precision, recall, F1-score and accuracy.
# NOTE(review): "rp" presumably stands for review-level rating prediction -- confirm in trainer.
trainer.test_rp(model, test_set, batch_size, device)

Precision: 0.6703865569383266
Recall: 0.6638848513096409
F1-score: 0.6593709032195875
Accuracy: 0.7228708791208791


In [8]:
# Evaluate on the test split with trainer.test_acd; per the cell output it prints
# precision / recall / F1 for a sweep of threshold values (Th).
# NOTE(review): "acd" presumably stands for aspect category detection -- confirm in trainer.
trainer.test_acd(model, test_set, batch_size, device)

Th: 0.00000000 | P: 0.86455063 | R: 1.00000000 | F1: 0.92735549
Th: 0.00052632 | P: 0.86729645 | R: 0.87521988 | F1: 0.87124015
Th: 0.00105263 | P: 0.86375141 | R: 0.69715712 | F1: 0.77156404
Th: 0.00157895 | P: 0.85492979 | R: 0.63570334 | F1: 0.72919582
Th: 0.00210526 | P: 0.85355550 | R: 0.61913409 | F1: 0.71768730
Th: 0.00263158 | P: 0.85053352 | R: 0.59252114 | F1: 0.69846154
Th: 0.00315789 | P: 0.85993038 | R: 0.53265619 | F1: 0.65783664
Th: 0.00368421 | P: 0.85979401 | R: 0.50212790 | F1: 0.63399606
Th: 0.00421053 | P: 0.85656836 | R: 0.47137264 | F1: 0.60810366
Th: 0.00473684 | P: 0.85616512 | R: 0.45429269 | F1: 0.59360866
Th: 0.00526316 | P: 0.85369128 | R: 0.43307042 | F1: 0.57463389
Th: 0.00578947 | P: 0.85343635 | R: 0.41996255 | F1: 0.56292071
Th: 0.00631579 | P: 0.85668677 | R: 0.41111048 | F1: 0.55559816
Th: 0.00684211 | P: 0.85944505 | R: 0.40248539 | F1: 0.54823002
Th: 0.00736842 | P: 0.86084063 | R: 0.39630029 | F1: 0.54274168
Th: 0.00789474 | P: 0.86151731 | R: 0.39

In [9]:
# Evaluate on the test split with trainer.test_acsa at threshold best_th=0; per the cell
# output it prints P / R / F1, an accuracy figure and a 3x3 table over labels -1 / 0 / 1.
# NOTE(review): best_th=0 matches the first row of the test_acd sweep above, where recall
# is 1.0 -- confirm this is the intended operating point.
trainer.test_acsa(model, test_set, batch_size, device, best_th=0)

P: 0.43828 | R: 0.50695 | F1: 0.47012
ACSA Accuracy: 0.506951143392158
      p=-1   p=0   p=1
t=-1  3656   704  1023
t=0   1817   687  1813
t=1   2164  1168  4591


In [None]:
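# Optional (currently disabled): collect per-example outputs and attention weights for inspection.
# NOTE(review): best_th=0.01551 here differs from the best_th=0 passed to test_acsa above -- confirm which threshold is intended before enabling.
# y, r_senti, ac_gold, ac_pred, w_senti, word_att, p_t, flag1, flag2 = trainer.output_attention(model, test_set, device, best_th=0.01551)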