In [1]:
import torch
from preprocess import TripPreProcess, ASAPPreProcess
from models import DSPN
from utils import set_seed
from trainer import DSPN_trainer
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration.
# data_name selects the dataset. Only 'Trip' and 'ASAP' have loaders in this notebook;
# other names ('rest_14', 'rest_15', 'rest_16', 'mams') are not handled by the cells below.
data_name = 'Trip'

# Seed via utils.set_seed before any model is constructed later in the notebook.
seed = 1
set_seed(seed)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters forwarded to the training call; batch_size is also used by the test cells.
n_epochs = 10
batch_size = 32
negsize = 5  # forwarded as trainer.train(negsize=...); presumably a negative-sample count — TODO confirm

In [None]:
# Load the dataset splits and the matching pretrained BERT encoder/tokenizer.
# Each supported data_name maps to (preprocessor class, local BERT checkpoint directory).
DATASET_SPECS = {
    'Trip': (TripPreProcess, "./model_params/bert-base-uncased"),
    'ASAP': (ASAPPreProcess, "./model_params/bert-base-chinese"),
}

# Fail loudly here instead of leaving T / train_set / bert_model undefined, which
# would otherwise surface later as an unrelated NameError in the model cell.
if data_name not in DATASET_SPECS:
    raise ValueError(
        f"Unsupported data_name {data_name!r}; expected one of {sorted(DATASET_SPECS)}"
    )

preprocess_cls, bert_path = DATASET_SPECS[data_name]
preprocessor = preprocess_cls()
# get_dataset returns four values: T (used to build DSPN) and the train/dev/test splits.
T, train_set, dev_set, test_set = preprocessor.get_dataset()
bert_model = BertModel.from_pretrained(bert_path)
bert_tokenizer = BertTokenizer.from_pretrained(bert_path)

Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Build DSPN from `T` (first value returned by get_dataset) and the pretrained
# BERT encoder/tokenizer, then move it to the selected device.
model = DSPN(T, bert_model, bert_tokenizer)
model = model.to(device)

# The trainer is constructed with the dataset name.
trainer = DSPN_trainer(data_name)

### Train

In [None]:
# Train on train_set; dev_set is passed alongside, presumably for validation /
# checkpoint selection — TODO confirm in DSPN_trainer.train.
# NOTE(review): model_name presumably feeds the checkpoint filename that the Test
# section reloads ("<data_name>_DSPN_<seed>_<N>.model") — confirm.
trainer.train(
    model,
    train_set,
    dev_set,
    device=device,
    n_epochs=n_epochs,
    batch_size=batch_size,
    negsize=negsize,
    data_name=data_name,
    model_name='DSPN_' + str(seed),
)

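### Test

Reload a checkpoint saved during training into `model` (replacing the in-memory weights from the training cell) and evaluate it on `test_set`.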

In [7]:
# Reload a checkpoint saved to disk into `model`, replacing the in-memory weights.
# NOTE(review): the "_<ckpt_epoch>" suffix was hard-coded as "_5"; presumably it is the
# epoch selected on dev_set — confirm against the training log before reporting results.
ckpt_epoch = 5
ckpt_path = f"./model_params/{data_name}_DSPN_{seed}_{ckpt_epoch}.model"
# torch.load unpickles the file: only load checkpoints produced by this project.
model.load_state_dict(torch.load(ckpt_path, map_location=device))

<All keys matched successfully>

In [8]:
# Rating-prediction metrics on test_set (prints precision / recall / F1 / accuracy).
# NOTE(review): in the run saved with this notebook, recall is exactly 1/3 and an
# undefined-metric warning is raised, which is consistent with the model predicting a
# single class — re-check the loaded checkpoint before trusting these numbers.
trainer.test_rp(
    model,
    test_set,
    batch_size,
    device,
)

Precision: 0.13724816849816848
Recall: 0.3333333333333333
F1-score: 0.1944376875050677
Accuracy: 0.4117445054945055


  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Evaluate the test_acd metric of DSPN_trainer on test_set with the same batch size.
trainer.test_acd(
    model,
    test_set,
    batch_size,
    device,
)

In [10]:
# ACSA evaluation on test_set (prints P/R/F1, accuracy and a confusion matrix).
# best_th is a threshold consumed inside DSPN_trainer.test_acsa.
# NOTE(review): the output_attention cell further down uses best_th=0.01551 for the same
# model — both are called "best_th"; confirm which value was actually tuned.
# NOTE(review): in the saved run every count of the confusion matrix sits in the p=-1
# column, so these ACSA numbers look degenerate.
acsa_best_th = 0.011
trainer.test_acsa(model, test_set, batch_size, device, best_th=acsa_best_th)

P: 0.09787 | R: 0.01398 | F1: 0.02447
ACSA Accuracy: 0.26407967032967034
      p=-1  p=0  p=1
t=-1  5383    0    0
t=0   4317    0    0
t=1   7923    0    0


In [None]:
# Collect per-example model outputs and attention weights on test_set for inspection.
# NOTE(review): best_th=0.01551 here differs from the 0.011 passed to test_acsa in the
# evaluation cell above, although both are named best_th — confirm the intended value.
(
    y, r_senti, ac_gold, ac_pred, w_senti,
    word_att, p_t, flag1, flag2,
) = trainer.output_attention(
    model,
    test_set,
    device,
    best_th=0.01551,
)