In [None]:
import torch
import torchvision
import torchaudio

In [None]:
# Display the installed PyTorch version as the cell output.
torch.__version__

In [None]:
# Check whether PyTorch can see a CUDA GPU; the result is shown as the cell output.
torch.cuda.is_available()

In [None]:
import json
import os
import pickle
import random
import time

import numpy as np
from transformers import BertTokenizer

In [None]:
# Relation name -> integer label id. The id is the position of the name in
# this list, so the order must not be changed.
_relation_names = [
    'Comment',
    'Contrast',
    'Correction',
    'Question-answer_pair',
    'Acknowledgement',
    'Elaboration',
    'Clarification_question',
    'Conditional',
    'Continuation',
    'Result',
    'Explanation',
    'Q-Elab',
    'Alternation',
    'Narration',
    'Confirmation_question',
    'Sequence',
    'Break',
]
map_relations = {name: idx for idx, name in enumerate(_relation_names)}

In [None]:
# Inverse lookup: integer label id -> relation name.
# Derived from map_relations so the two tables cannot drift apart when a
# relation is added or renumbered.
reverse_relations = {idx: name for name, idx in map_relations.items()}

In [None]:
# Base directory for all data / model paths: the current working directory.
# os.getcwd() replaces the IPython-only `%pwd` magic so this cell is also
# valid plain Python (linting, nbconvert --to script).
home = os.getcwd()
# Test split read by load_data below.
filename = home + '/data/TEST_101_bert.json'

In [None]:
from utils import load_data, input_format, position_ids_compute, tokenize
from bert_format import undersample, format_time, flat_accuracy

In [None]:
# Load the test dialogues with the project helper utils.load_data.
# Later cells index the result per dialogue and read each dialogue's
# 'relations' list (entries with 'x', 'y' and 'type').
# NOTE(review): presumably load_data uses map_relations to turn relation names
# into integer ids -- confirm in utils.py.
test_data = load_data(filename, map_relations)

In [None]:
# Cased BERT tokenizer; the coordinate tokens are added to it in a later cell.
# NOTE(review): `use_fast` is normally an AutoTokenizer option -- confirm it has
# any effect when passed to BertTokenizer directly.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', use_fast=True)

In [None]:
# Character inventories for the five positions of a coordinate token.
# list('abc') splits a string into its single-character elements.
put = list('10')
colors = list('rbgoyp')
listx = list('bcdfghjklmn')
listy = [str(digit) for digit in range(9)]
listz = list('aeioupqrxyz')

In [None]:
# Build every 5-character coordinate token as the concatenation of one
# character from each inventory, in the order put, colors, listx, listy, listz.
# `put` varies slowest and `listz` fastest.
coord_tokens = []
for s in put:
    for t in colors:
        for i in listx:
            for j in listy:
                for k in listz:
                    coord_tokens.append(s + t + i + j + k)

In [None]:
# Register the coordinate tokens with the tokenizer; the call's return value
# is shown as the cell output.
tokenizer.add_tokens(coord_tokens)

In [None]:
# Tokenizer size after adding the coordinate tokens; the same value is passed
# to MultiTaskModel when the model is built.
len(tokenizer)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end-to-end on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Turn the dialogues into model inputs with the project helper utils.input_format.
# `inputs` is passed to the tokenizer, each row of `labels_input` carries the
# relation label at index 3, and `raw` is passed on to position_ids_compute.
# NOTE(review): 10 is presumably the maximum attachment distance (the gold
# filter later uses y - x <= 10) -- confirm in utils.py.
inputs, labels_input, raw = input_format(test_data, 10, relations=True)

In [None]:
# Peek at the first three label rows (index 3 of each row is used as the relation label).
labels_input[:3]

In [None]:
# Number of distinct relation labels that actually occur in this data
# (the relation label is column 3 of each label row).
num_labels = len({row[3] for row in labels_input})

In [None]:
# Size the relation head from the full relation inventory (17 entries in
# map_relations) instead of a hard-coded literal or the labels that happen to
# occur in this test set.
num_labels = len(map_relations)

In [None]:
# Tokenize all inputs in one call: padded batch, truncated where the
# tokenizer requires it, special tokens added, PyTorch tensors returned.
batch_tokenized = tokenizer(
    inputs,
    add_special_tokens=True,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

In [None]:
# Move the three tokenizer outputs (token ids, attention mask, segment ids)
# onto the target device.
input_ids, attention_masks, token_type_ids = (
    batch_tokenized[key].to(device)
    for key in ("input_ids", "attention_mask", "token_type_ids")
)

In [None]:
# Full label rows as one tensor (passed to position_ids_compute), plus a
# tensor holding only the relation id, which is column 3 of each row.
labels_relation = torch.tensor(labels_input)
labels = torch.tensor([row[3] for row in labels_input])

In [None]:
# Compute the position ids for every input with the project helper
# utils.position_ids_compute.
# NOTE(review): the exact scheme is defined in utils.py -- not visible here.
position_ids = position_ids_compute(tokenizer, input_ids, raw, labels_relation)

In [None]:
# Convert the helper's output to a tensor so it can go into a TensorDataset.
# NOTE(review): this cell rebinds `position_ids`, so re-running it wraps an
# existing tensor in torch.tensor again.
position_ids = torch.tensor(position_ids)

In [None]:
# One task id per input, all set to 1 (the id given to the relation task when
# the Task objects are created). torch.Tensor yields a float tensor.
task_ids = torch.Tensor([1] * len(input_ids))

In [None]:
from multitask_format import MultiTaskModel, Task

In [None]:
# Folder holding the trained model; replace the placeholder with your own
# folder name (plain string concatenation, so include any path separators).
model_path = f'{home}<name of your model folder>'

In [None]:
# Task 0: attachment prediction (2 labels). Task 1: relation prediction
# (num_labels labels).
attach_task = Task(id=0, name='attach prediction', type="seq_classification", num_labels=2)
relation_task = Task(id=1, name='relation prediction', type="seq_classification", num_labels=num_labels)
tasks = [attach_task, relation_task]

# len(tokenizer) includes the added coordinate tokens.
# NOTE(review): presumably used to resize the embedding matrix -- confirm in
# multitask_format.py.
model = MultiTaskModel('bert-base-cased', tasks, len(tokenizer))

# Replace the placeholder with your checkpoint file name, e.g. 'multitask_stac.pth'.
output_model = model_path + '<name of your multitask .pth file output>'
print(output_model)

# Map the checkpoint onto whichever device this notebook is using rather than
# a hard-coded 'cuda', so loading follows the `device` setting.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(output_model, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print('loaded')

Relation prediction on the attachments predicted by the linear model

In [None]:
import pickle
# Pickle file with the linear model's predicted attachments; replace the
# placeholder with your own path (plain string concatenation).
data_path = f'{home}<name of your linear preds pickle file>'

In [None]:
# Load the predicted attachments (one list of [x, y] pairs per dialogue).
# NOTE(review): pickle.load can execute arbitrary code -- only open files you trust.
with open(data_path, 'rb') as f:
    test_pred = pickle.load(f)

# Fail fast, before the predictions are consumed: there must be exactly one
# prediction list per test dialogue.
assert len(test_pred) == len(test_data), (
    f"expected one prediction list per dialogue, got {len(test_pred)} "
    f"for {len(test_data)} dialogues"
)

# Rebuild the model inputs from the predicted attachments.
# NOTE: this rebinds `labels` (previously the tensor of relation ids) to the
# label rows returned by input_format.
# Pass relations=False instead to build inputs without relation labels.
pred_inputs, labels, raw = input_format(test_data, 10, relations=True, attach_preds=test_pred)

batch_tokenized = tokenizer(pred_inputs, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True)
input_ids = batch_tokenized["input_ids"].to(device)  # token ids, one row per input
attention_masks = batch_tokenized["attention_mask"].to(device)
token_type_ids = batch_tokenized["token_type_ids"].to(device)

position_ids = position_ids_compute(tokenizer, input_ids, raw, labels)
position_ids = torch.tensor(position_ids)

# Every input goes to task 1, the relation task (float tensor via torch.Tensor).
task_ids = torch.Tensor([1] * len(input_ids))

In [None]:
# Sanity check: one list of predicted attachments per test dialogue.
assert len(test_pred) == len(test_data)

In [None]:
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

In [None]:
# Batch the encoded inputs in their original order; SequentialSampler keeps
# prediction k aligned with input k.
prediction_data = TensorDataset(input_ids, attention_masks, token_type_ids, position_ids, task_ids)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(
    prediction_data,
    sampler=prediction_sampler,
    batch_size=32,
)

# Inference only: evaluation mode here, gradients disabled inside the loop.
model.eval()

predictions, true_labels = [], []

for batch in prediction_dataloader:
    # Make sure every tensor of the batch lives on the model's device.
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_token_types, b_position_ids, b_task_ids = batch

    with torch.no_grad():
        outputs, embed = model(
            b_input_ids,
            token_type_ids=b_token_types,
            attention_mask=b_input_mask,
            position_ids=b_position_ids,
            task_ids=b_task_ids,
        )

    # The first element of `outputs` holds the logits; collect them as numpy arrays.
    logits = outputs[0].detach().cpu().numpy()
    predictions.append(logits)

print('    DONE.')

# Stack the per-batch logits and take the highest-scoring label for each input.
flat_prediction = np.concatenate(predictions, axis=0)
flat_predictions = flat_prediction.argmax(axis=1).flatten()

In [None]:
# Attach the predicted relation labels to the predicted attachments.
# Gold keeps every relation with y - x <= 10, whether or not its attachment
# was predicted.
preds = []
gold = []
i = 0  # running index into flat_predictions: one entry per predicted attachment
for n, g in enumerate(test_pred):
    gold_tmp = [
        [rel['x'], rel['y'], rel['type']]
        for rel in test_data[n]['relations']
        if (rel['y'] - rel['x']) <= 10
    ]
    pred_tmp = []
    for p in g:
        # int(...) stores a plain Python int rather than a numpy scalar, so the
        # pickled output does not depend on numpy types.
        pred_tmp.append([p[0], p[1], int(flat_predictions[i])])
        i += 1
    gold.append(gold_tmp)
    preds.append(pred_tmp)

# Every model prediction must have been consumed; anything else means the
# predictions and the attachments are misaligned.
assert i == len(flat_predictions), (
    f"consumed {i} predictions but the model produced {len(flat_predictions)}"
)

In [None]:
# preds: one list per dialogue of [x, y, predicted relation] for each predicted attachment.
# gold: one list per dialogue of [x, y, gold relation] for every gold relation with
# y - x <= 10, whether or not that attachment was predicted.
# The two outer lists should have the same length.
len(preds), len(gold)

In [None]:
from collections import defaultdict

In [None]:
# Align gold and predicted relations dialogue by dialogue. Each row of
# `comparisons` is
#   [dialogue index, x, y, gold relation, predicted relation]
# and a missing gold or predicted relation is encoded as 16.
NO_RELATION = 16  # same id as 'Break' in map_relations

comparisons = []
for game in range(len(gold)):
    goldgame = gold[game]
    predgame = preds[game]

    # Exact matches: same endpoints and same relation label.
    true_pos = [rel for rel in predgame if rel in goldgame]
    rem_gold = [rel for rel in goldgame if rel not in true_pos]
    rem_pred = [rel for rel in predgame if rel not in true_pos]
    assert len(goldgame) == len(true_pos) + len(rem_gold)
    assert len(predgame) == len(true_pos) + len(rem_pred)

    gold_count = len(true_pos)
    for tp in true_pos:
        comparisons.append([game, tp[0], tp[1], tp[2], tp[2]])

    # Group the leftovers by endpoint pair, so a false negative and a false
    # positive on the same (x, y) end up in a single row.
    rem_dict = defaultdict(list)
    for rg in rem_gold:  # false negatives
        rem_dict[(rg[0], rg[1])].append(('g', rg[2]))
    for rp in rem_pred:  # false positives
        rem_dict[(rp[0], rp[1])].append(('p', rp[2]))

    for endpoints, tagged in rem_dict.items():
        p = NO_RELATION
        t = NO_RELATION
        for source, label in tagged:
            if source == 'p':
                p = label
            if source == 'g':
                t = label
                gold_count += 1
        comparisons.append([game, endpoints[0], endpoints[1], t, p])

    # Every gold relation was counted exactly once.
    assert gold_count == len(goldgame)
      


In [None]:
# Persist the aligned gold/predicted rows; replace the placeholders with your
# own folder and file name.
comparisons_path = home + '<name of your pickle folder>/<name of your multitask preds pickle file>'
with open(comparisons_path, 'wb') as f:
    pickle.dump(comparisons, f)

In [None]:
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, confusion_matrix

In [None]:
# Split every comparison row into its gold id (column 3) and predicted id (column 4).
correct = [row[3] for row in comparisons]
predicted = [row[4] for row in comparisons]

In [None]:
# The same two columns as relation names, for a readable classification report.
corr_all = [reverse_relations[row[3]] for row in comparisons]
pred_all = [reverse_relations[row[4]] for row in comparisons]

In [None]:
# Per-relation precision / recall / F1 over all comparison rows, keyed by relation name.
print(classification_report(corr_all,pred_all))

In [None]:
# Fix the label set to the full relation inventory. Without `labels=`, classes
# that never occur are dropped from the matrix and the default 0..n-1 tick
# labels no longer correspond to relation ids.
relation_ids = sorted(reverse_relations)
cm = confusion_matrix(correct, predicted, labels=relation_ids)
ConfusionMatrixDisplay(
    cm,
    display_labels=[reverse_relations[rid] for rid in relation_ids],
).plot(xticks_rotation='vertical')