In [None]:
import torch
from preprocess import TripPreProcess, ASAPPreProcess
from models import DSPN
from utils import set_seed
from trainer import DSPN_trainer
from transformers import BertModel, BertTokenizer

In [None]:
# Run configuration shared by the cells below.
# NOTE(review): the trailing list names six datasets, but the loading cell in
# this notebook only has branches for 'Trip' and 'ASAP'.
data_name = 'Trip' # ['ASAP', 'Trip', 'rest_14', 'rest_15', 'rest_16', 'mams']
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The seed is passed to set_seed and is also embedded in the model/checkpoint
# name used by the train and load cells.
seed = 1
set_seed(seed)
# Values forwarded to trainer.train; batch_size is also passed to the test_* calls.
n_epochs = 10
batch_size = 32
negsize = 5  # NOTE(review): presumably the number of negative samples — confirm in trainer

In [None]:
# Load the dataset splits and pick the local BERT checkpoint directory that
# goes with the selected dataset. `T` is forwarded unchanged to the DSPN
# constructor in a later cell.
if data_name == 'Trip':
    trip = TripPreProcess()
    T, train_set, dev_set, test_set = trip.get_dataset()
    bert_dir = "./model_params/bert-base-uncased"
elif data_name == 'ASAP':
    asap = ASAPPreProcess()
    T, train_set, dev_set, test_set = asap.get_dataset()
    bert_dir = "./model_params/bert-base-chinese"
else:
    # Fail here with a clear message instead of a NameError on `T` /
    # `bert_model` several cells later.
    raise ValueError(
        f"Unsupported data_name {data_name!r}: this notebook only loads 'Trip' and 'ASAP'."
    )

# Model and tokenizer always come from the same directory, so load them once
# after the branch instead of repeating the path in every branch.
bert_model = BertModel.from_pretrained(bert_dir)
bert_tokenizer = BertTokenizer.from_pretrained(bert_dir)

In [None]:
# Build the DSPN model from T and the BERT model/tokenizer loaded above, then
# move it to the selected device.
model = DSPN(T, bert_model, bert_tokenizer).to(device)
# The trainer is constructed with the dataset name; the cells below call its
# train, test_rp, test_acd, test_acsa and output_attention methods.
trainer = DSPN_trainer(data_name)

### Train

In [None]:
# Train on train_set (dev_set is passed alongside it). The seed is embedded
# in model_name so runs with different seeds get distinct names.
trainer.train(
    model,
    train_set,
    dev_set,
    device=device,
    n_epochs=n_epochs,
    batch_size=batch_size,
    negsize=negsize,
    data_name=data_name,
    model_name=f'DSPN_{seed}',
)

### Test

In [None]:
# Restore saved weights for the selected dataset and seed, mapped onto `device`.
# NOTE(review): the "_5" suffix is hard-coded — presumably the epoch of the
# checkpoint to evaluate; confirm it matches a file written by trainer.train.
model.load_state_dict(
    torch.load(
        f"./model_params/{data_name}_DSPN_{seed}_5.model",
        map_location=device,
    )
)

In [None]:
# Evaluate the restored model on the test split via trainer.test_rp.
# NOTE(review): "rp" presumably stands for rating prediction — confirm in trainer.
trainer.test_rp(model, test_set, batch_size, device)

In [None]:
# Evaluate the restored model on the test split via trainer.test_acd.
# NOTE(review): "acd" presumably stands for aspect category detection — confirm in trainer.
trainer.test_acd(model, test_set, batch_size, device)

In [None]:
# Evaluate the restored model on the test split via trainer.test_acsa.
# NOTE(review): "acsa" presumably stands for aspect-category sentiment analysis — confirm in trainer.
# NOTE(review): best_th=0.011 is a hard-coded threshold and differs from the
# 0.01551 passed to output_attention in the last cell — confirm both are intended.
trainer.test_acsa(model, test_set, batch_size, device, best_th=0.011)

In [None]:
# Collect per-example outputs from trainer.output_attention on the test split.
# The nine returned values are bound here for later inspection; nothing in the
# visible notebook reads them afterwards.
# NOTE(review): best_th=0.01551 differs from the 0.011 passed to test_acsa in
# the previous cell — confirm which threshold is intended.
y, r_senti, ac_gold, ac_pred, w_senti, word_att, p_t, flag1, flag2 = trainer.output_attention(model, test_set, device, best_th=0.01551)