In [1]:
import torch
from preprocess import TripPreProcess, ASAPPreProcess
from models import ABAE
from utils import set_seed
from trainer import ABAE_trainer
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration -------------------------------------------
# Dataset to run on. Names listed by the original author:
#   ['ASAP', 'Trip', 'rest_14', 'rest_15', 'rest_16', 'mams']
# NOTE(review): the loading cell in this notebook only handles 'Trip' and 'ASAP'.
data_name = 'Trip'

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The seed is also embedded in the checkpoint name used by the train/test cells.
seed = 1
set_seed(seed)  # project helper from utils; presumably seeds the RNGs -- TODO confirm

# Values forwarded to the trainer calls further down.
n_epochs = 10
batch_size = 32
negsize = 5  # presumably the number of negative samples per example -- TODO confirm

In [3]:
# Load the dataset splits for the selected corpus and pick the local BERT
# checkpoint directory that goes with it. Only 'Trip' and 'ASAP' are wired up
# here; any other name is rejected immediately instead of silently leaving
# T / train_set / bert_model undefined and failing later with a NameError.
if data_name == 'Trip':
    trip = TripPreProcess()
    T, train_set, dev_set, test_set = trip.get_dataset()
    bert_dir = "./model_params/bert-base-uncased"
elif data_name == 'ASAP':
    asap = ASAPPreProcess()
    T, train_set, dev_set, test_set = asap.get_dataset()
    bert_dir = "./model_params/bert-base-chinese"
else:
    raise ValueError(
        "Unsupported data_name {!r}: this notebook only loads 'Trip' and 'ASAP'.".format(data_name)
    )

# Model and tokenizer must come from the same checkpoint directory, so the
# path is chosen once above and reused for both.
bert_model = BertModel.from_pretrained(bert_dir)
bert_tokenizer = BertTokenizer.from_pretrained(bert_dir)

Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ./model_params/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transf

In [4]:
# Build the ABAE model from the loaded `T` object, the BERT encoder and its
# tokenizer, then move it to the selected device.
model = ABAE(
    T,
    bert_model,
    bert_tokenizer,
).to(device)

# Trainer is created per dataset name.
trainer = ABAE_trainer(data_name)

Preparing...


### Train

In [None]:
# Train on the train split with the dev split passed alongside it. The model
# name carries the seed and matches the "<data_name>_ABAE_<seed>.model" file
# that the test section loads.
trainer.train(
    model=model,
    train_set=train_set,
    dev_set=dev_set,
    device=device,
    n_epochs=n_epochs,
    batch_size=batch_size,
    negsize=negsize,
    data_name=data_name,
    model_name='ABAE_' + str(seed),
)

### Test

In [5]:
# Checkpoint file for this dataset/seed combination.
checkpoint_path = "./model_params/" + data_name + "_ABAE_" + str(seed) + ".model"

# Start from a freshly constructed model and overwrite its parameters with the
# saved weights; map_location lets a checkpoint saved on another device load here.
model = ABAE(T, bert_model, bert_tokenizer).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [6]:
# Evaluate the reloaded model on the test split; the captured output below
# lists precision / recall / F1 for a sweep of thresholds.
trainer.test_acd(
    model,
    test_set,
    batch_size=batch_size,
    device=device,
)

Th: 0.00001 | P: 0.86455 | R: 1.00000 | F1: 0.92736
Th: 0.00006 | P: 0.86455 | R: 1.00000 | F1: 0.92736
Th: 0.00011 | P: 0.86531 | R: 0.96385 | F1: 0.91193
Th: 0.00017 | P: 0.86501 | R: 0.95954 | F1: 0.90982
Th: 0.00022 | P: 0.86347 | R: 0.93741 | F1: 0.89893
Th: 0.00027 | P: 0.86045 | R: 0.89741 | F1: 0.87854
Th: 0.00032 | P: 0.85578 | R: 0.86200 | F1: 0.85888
Th: 0.00037 | P: 0.85536 | R: 0.85735 | F1: 0.85635
Th: 0.00043 | P: 0.86121 | R: 0.84367 | F1: 0.85235
Th: 0.00048 | P: 0.87162 | R: 0.78364 | F1: 0.82529
Th: 0.00053 | P: 0.87109 | R: 0.77836 | F1: 0.82212
Th: 0.00058 | P: 0.86922 | R: 0.76224 | F1: 0.81223
Th: 0.00064 | P: 0.86796 | R: 0.74675 | F1: 0.80281
Th: 0.00069 | P: 0.86732 | R: 0.73966 | F1: 0.79842
Th: 0.00074 | P: 0.86723 | R: 0.73909 | F1: 0.79805
Th: 0.00079 | P: 0.86731 | R: 0.73847 | F1: 0.79772
Th: 0.00084 | P: 0.86736 | R: 0.73807 | F1: 0.79751
Th: 0.00090 | P: 0.86777 | R: 0.73137 | F1: 0.79376
Th: 0.00095 | P: 0.86936 | R: 0.70953 | F1: 0.78135
Th: 0.00100 