In [None]:
import time
import torch
from preprocess import TripPreProcess, ASAPPreProcess
from models import DSPN
from utils import set_seed
from trainer import DSPN_trainer
from transformers import BertModel, BertTokenizer
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoTokenizer, AutoModel
from transformers import AlbertTokenizer, AlbertModel
# import warnings
# warnings.filterwarnings("ignore")

In [None]:
# Dataset selection. Only 'Trip' and 'ASAP' are handled by the data-loading
# cell below; the other names in this list are not loaded by this notebook.
data_name = 'Trip'  # ['ASAP', 'Trip', 'rest_14', 'rest_15', 'rest_16', 'mams']

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Reproducibility: the seed is also embedded in the checkpoint name used for
# training and for reloading the model in the Test section.
seed = 1
set_seed(seed)

# Training schedule.
n_epochs, batch_size = 5, 32

In [None]:
# Load the selected dataset and the pre-trained encoder/tokenizer paired with it.
# Each preprocessor's get_dataset() returns a 4-tuple: T is passed to the DSPN
# constructor, and the three splits feed training, validation and testing.
# Only 'Trip' (bert-base-uncased) and 'ASAP' (bert-base-chinese) are handled.
# Any other data_name would leave T / train_set / bert_model undefined and fail
# later with a confusing NameError, so it is rejected here with a clear message.
if data_name == 'Trip':
    trip = TripPreProcess()
    T, train_set, dev_set, test_set = trip.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-uncased")
    bert_tokenizer = BertTokenizer.from_pretrained("./model_params/bert-base-uncased")
    # Alternative backbones (swap in for the two lines above):
    #bert_model = RobertaModel.from_pretrained("./model_params/roberta-base")
    #bert_tokenizer = RobertaTokenizer.from_pretrained("./model_params/roberta-base")
    #bert_model = AlbertModel.from_pretrained("./model_params/albert-base-v2")
    #bert_tokenizer = AlbertTokenizer.from_pretrained('./model_params/albert-base-v2')
elif data_name == 'ASAP':
    asap = ASAPPreProcess()
    T, train_set, dev_set, test_set = asap.get_dataset()
    bert_model = BertModel.from_pretrained("./model_params/bert-base-chinese")
    bert_tokenizer = BertTokenizer.from_pretrained("./model_params/bert-base-chinese")
    # Alternative backbones (swap in for the two lines above):
    #bert_model = AutoModel.from_pretrained("./model_params/roberta-base-chinese")
    #bert_tokenizer = AutoTokenizer.from_pretrained("./model_params/roberta-base-chinese")
    #bert_model = AlbertModel.from_pretrained("./model_params/albert-chinese-base/")
    #bert_tokenizer = AutoTokenizer.from_pretrained("./model_params/albert-chinese-base/")
else:
    raise ValueError(
        f"Unsupported data_name {data_name!r}: this notebook only loads 'Trip' and 'ASAP'."
    )

In [None]:
# Build DSPN on top of the chosen encoder/tokenizer, then move it to `device`.
model = DSPN(T, bert_model, bert_tokenizer)
model = model.to(device)
# The trainer is configured per dataset.
trainer = DSPN_trainer(data_name)

### Train

In [None]:
# Report model size in millions of parameters.
# Count model.parameters() rather than state_dict() entries. state_dict() also
# contains persistent buffers (non-trainable tensors) and lists a tied/shared
# tensor once per key, both of which inflate the count. parameters() yields
# each Parameter exactly once.
num_params = sum(p.numel() for p in model.parameters())
print(f"Params_size: {num_params/1000000}M")

In [None]:
# Train DSPN on train_set (with dev_set for validation) and report the elapsed time.
# NOTE(review): model_name='DSPN_<seed>' together with data_name presumably
# determines the checkpoint file reloaded in the Test section
# ("./model_params/<data_name>_DSPN_<seed>.model"); verify against trainer.py.
start_time = time.perf_counter()  # monotonic clock: unaffected by system-clock changes
trainer.train(
    model, train_set, dev_set,
    device=device, n_epochs=n_epochs, batch_size=batch_size,
    data_name=data_name, model_name='DSPN_' + str(seed),
)
end_time = time.perf_counter()
used_mins = (end_time - start_time) / 60
print(f"Time: {used_mins} Minutes")

### Test

In [None]:
# Reload the checkpoint for this dataset/seed (presumably the one saved by
# trainer.train above — verify) onto the current device before evaluation.
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# (PyTorch >= 1.13) if checkpoints may come from an untrusted source.
checkpoint_path = f"./model_params/{data_name}_DSPN_{seed}.model"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Evaluate the trained model on the test split.
# NOTE(review): "rp" presumably means rating prediction — confirm in trainer.py.
trainer.test_rp(
    model,
    test_set,
    batch_size,
    device,
)

In [None]:
# Evaluate the trained model on the test split.
# NOTE(review): "acd" presumably means aspect-category detection — confirm in trainer.py.
trainer.test_acd(
    model,
    test_set,
    batch_size,
    device,
)

In [None]:
# Evaluate the trained model on the test split.
# NOTE(review): "acsa" presumably means aspect-category sentiment analysis, and
# best_th a decision threshold used by the trainer — confirm in trainer.py.
acsa_threshold = 1e-4
trainer.test_acsa(
    model,
    test_set,
    batch_size,
    device,
    best_th=acsa_threshold,
)