In [1]:
import os

# Root directory of the project checkout; every data, cache and checkpoint path below is
# built from it with os.path.join. The default is the original author's location; set the
# TRIP_DRIVE_PATH environment variable to run the notebook from another checkout without
# editing this cell.
DRIVE_PATH = os.environ.get("TRIP_DRIVE_PATH", "/home/ikhee/Verifiable-Coherent-NLU/")

# Pre-trained encoder family to fine-tune (exactly one line should be active).
# mode = 'bert' # BERT large
mode = 'roberta' # RoBERTa large
# mode = 'roberta_mnli' # RoBERTa large pre-trained on MNLI

task_name = 'trip'
debug = False
config_batch_size = 1
config_lr = 1e-5 # Selected learning rate for best RoBERTa-based model in TRIP paper
config_epochs = 10 #10

# Per-loss weights handed to TieredModelPipeline in the training cell.
# NOTE(review): presumably one weight per tier of the pipeline, with the last entry being
# the story choice loss - confirm the order against TieredModelPipeline.
loss_weights = [0.0, 0.4, 0.4, 0.2, 0.0] # "Omit story choice loss"
# loss_weights = [0.2, 0.4, 0.4, 0.2, 0.0] # "Omit story choice loss"

In [2]:
import os
import json
import sys
import torch
import random
import numpy as np
import spacy
!pip install jsonlines

sys.path.append(DRIVE_PATH)
!pip install 'transformers==4.2.2'
!pip install sentencepiece
!pip3 install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
!pip install deberta

# Decide which kind of task head the later cells pick. Only the tasks listed here are
# configured in this notebook; fail loudly for anything else instead of leaving
# `multiple_choice` undefined (or silently reusing a stale value from an earlier run).
if task_name in ['trip', 'ce']:
  multiple_choice = False
else:
  raise ValueError("Unsupported task_name %r: expected 'trip' or 'ce'." % (task_name,))

# Map the short `mode` switch from the config cell to the pre-trained checkpoint name.
_PRETRAINED_MODEL_NAMES = {
  'bert': 'bert-large-uncased',
  'roberta': 'roberta-large',
  'roberta_mnli': 'roberta-large-mnli',
  'deberta': 'microsoft/deberta-base',
  'deberta_large': 'microsoft/deberta-large',
}
if mode not in _PRETRAINED_MODEL_NAMES:
  raise ValueError("Unsupported mode %r: expected one of %s." % (mode, sorted(_PRETRAINED_MODEL_NAMES)))
model_name = _PRETRAINED_MODEL_NAMES[mode]

from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer, AlbertTokenizer, T5Tokenizer, GPT2Tokenizer
from DeBERTa import deberta

# Choose the tokenizer class that matches the selected model family. An unknown mode is
# rejected here so that `tokenizer_class` can never be left undefined or carried over
# from a previous kernel run with a different mode.
if mode == 'bert':
  tokenizer_class = BertTokenizer
elif mode in ['roberta', 'roberta_mnli']:
  tokenizer_class = RobertaTokenizer
elif mode in ['deberta', 'deberta_large']:
  tokenizer_class = DebertaTokenizer
else:
  raise ValueError("No tokenizer configured for mode %r." % (mode,))

# Load the tokenizer, using the project's `cache` directory for the downloaded files.
tokenizer = tokenizer_class.from_pretrained(model_name,
                                            do_lower_case=False,
                                            cache_dir=os.path.join(DRIVE_PATH, 'cache'))

from transformers import BertForSequenceClassification, RobertaForSequenceClassification, DebertaForSequenceClassification, AlbertForSequenceClassification, AdamW
from transformers import BertForMultipleChoice, RobertaForMultipleChoice, AlbertForMultipleChoice, DebertaModel
from transformers import BertModel, RobertaModel, AlbertModel, DebertaModel, T5Model, T5EncoderModel, GPT2Model
from transformers import RobertaForMaskedLM
from transformers import BertConfig, RobertaConfig, DebertaConfig, AlbertConfig, T5Config, GPT2Config
from www.model.transformers_ext import DebertaForMultipleChoice
from torch.optim import Adam

# Pick the config, encoder and task-head classes for the selected model family.
# The config and encoder classes are the same whether or not the task is multiple choice;
# only the head class (`model_class`) depends on `multiple_choice`.
if mode == 'bert':
  config_class = BertConfig
  emb_class = BertModel
  model_class = BertForMultipleChoice if multiple_choice else BertForSequenceClassification
elif mode in ['roberta', 'roberta_mnli']:
  config_class = RobertaConfig
  emb_class = RobertaModel
  lm_class = RobertaForMaskedLM # only the RoBERTa family defines a masked-LM class here
  model_class = RobertaForMultipleChoice if multiple_choice else RobertaForSequenceClassification
elif mode in ['deberta', 'deberta_large']:
  config_class = DebertaConfig
  emb_class = DebertaModel
  model_class = DebertaForMultipleChoice if multiple_choice else DebertaForSequenceClassification
else:
  # Without this branch an unknown mode would leave the classes undefined (or stale).
  raise ValueError("No model classes configured for mode %r." % (mode,))

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable




In [3]:
# Convert TRIP to Two-Story Classification Task
from www.utils import print_dict
import json
from collections import Counter

# Read the two-story data file; its JSON root unpacks into the cloze and order datasets.
data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
with open(data_file, 'r') as f:
  cloze_dataset_2s, order_dataset_2s = json.load(f)

# Report how the labels are distributed within every partition of the cloze data.
for p, examples in cloze_dataset_2s.items():
  label_dist = Counter(ex['label'] for ex in examples)
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())

# Show one full training example for reference.
print_dict(cloze_dataset_2s['train'][0])

Cloze label distribution (train):
[(1, 400), (0, 399)]
Cloze label distribution (dev):
[(0, 161), (1, 161)]
Cloze label distribution (test):
[(1, 176), (0, 175)]
{
  example_id: 
    0-C0,
  stories: 
    [
      {'story_id': 0, 'worker_id': 'A1F01FVEPYCPHO', 'type': 'cloze', 'idx': 0, 'aug': False, 'actor': 'Tom', 'location': 'kitchen', 'objects': 'dustbin, microwave, pan, plate, cereal, soup', 'sentences': ['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.'], 'length': 5, 'example_id': '0-C0', 'plausible': False, 'breakpoint': 4, 'confl_sents': [3], 'confl_pairs': [[3, 4]], 'states': [{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin', 

In [4]:
# Peek at the first training example: how many stories it holds, then selected fields
# of its first story.
first_example_stories = cloze_dataset_2s["train"][0]["stories"]
print(len(first_example_stories))
for field in ("sentences", "objects", "states"):
  print(first_example_stories[0][field])

# print(cloze_dataset_2s["train"][0]["stories"][0])
# print(cloze_dataset_2s["train"][0]["stories"][0])

2
['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.']
dustbin, microwave, pan, plate, cereal, soup
[{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin', 2]], 'pieces': [['dustbin', 0]], 'wet': [['dustbin', 0]], 'open': [['dustbin', 0]], 'temperature': [['dustbin', 0]], 'solid': [['dustbin', 0]], 'contain': [['dustbin', 0]], 'running': [['dustbin', 0]], 'moveable': [['dustbin', 2]], 'mixed': [['dustbin', 0]], 'edible': [['dustbin', 0]]}, {'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 0], ['plate', 6]], 'exist': [['dustbin', 2], ['plat

In [5]:
# Featurization for Tiered Classification

from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter

seq_length = 16 # Max sequence length to pad to

# Span training is switched off; the span-extraction step is kept below for reference only.
train_spans = False
# # train_spans = True
# if train_spans:
#   tiered_dataset = get_story_spans_2s(tiered_dataset, train_only=True)
#   tiered_dataset['train'] = [ex for ex in tiered_dataset['train'] if ex['label'] != -1] # For now, ignore examples where both stories are plausible :(

# NOTE(review): the `debug` flag from the config cell currently has no effect in this cell,
# because the subsetting step below is disabled.
# # Debug the code on a small amount of data
# if debug:
#   for k in tiered_dataset:
#     tiered_dataset[k] = tiered_dataset[k][:20]

# Convert the cloze two-story data into the tiered format used by the rest of the notebook.
tiered_dataset = get_tiered_data(cloze_dataset_2s)

In [6]:
# For the first two entities of the first story in the first training example, show the
# story sentences, the entity name and its attribute matrix.
first_story_entities = tiered_dataset["train"][0]["stories"][0]["entities"]
for entity_idx in (0, 1):
  entity_record = first_story_entities[entity_idx]
  for field in ("sentences", "entity", "attributes"):
    print(entity_record[field])

['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.']
fridge
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.']
soup
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0

In [7]:
# Add tokenizer features (padded to seq_length, with segment ids) to every example.
tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

# Longest story, counted in sentences, over all partitions (taken from each example's first story).
max_story_length = max(
  len(ex['stories'][0]['sentences'])
  for examples in tiered_dataset.values()
  for ex in examples
)

# Build one tensor dataset per partition.
tiered_tensor_dataset = {
  p: get_tensor_dataset_tiered(tiered_dataset[p], max_story_length, add_segment_ids=True)
  for p in tiered_dataset
}



In [8]:
# Show which partitions the featurized dataset contains.
print(tiered_dataset.keys())

dict_keys(['train', 'dev', 'test'])


In [9]:
# Inspect the tensors stored for the first training example.
train_tensors = tiered_tensor_dataset["train"]
first_example = train_tensors[0]
print(len(train_tensors))
print(len(first_example))

# all_inputs_ids
# batch, all_input_ids, story, max_entities, max_sentences, seqlen
# all_inputs_ids, all_lengths, num_entities
all_input_ids = first_example[0]
print(all_input_ids.shape)
print(all_input_ids[0][0][0])
print(all_input_ids[1][0][0])

# entities 1, 2 and 6 (index [0] in the first and last position)
for entity_idx in (1, 2, 6):
  print(all_input_ids[0][entity_idx][0])

# all_lengths
print(first_example[1][0])

# num_entities
print(first_example[2][0])

# all_attributes -> almost all zero
for idx in range(4):
  print(first_example[4][0][0][idx])

# all_preconditions
print(first_example[5][0][0][0])

# all_effects
print(first_example[6][0][0][0])

799
10
torch.Size([2, 15, 7, 16])
tensor([0.0000e+00, 1.2997e+04, 1.5379e+04, 2.0000e+00, 2.0000e+00, 1.5691e+04,
        2.1620e+03, 1.0000e+01, 9.2000e+01, 8.4020e+03, 9.4130e+03, 1.3000e+01,
        5.0000e+00, 4.6470e+03, 4.0000e+00, 2.0000e+00], dtype=torch.float64)
tensor([0.0000e+00, 1.2997e+04, 1.5379e+04, 2.0000e+00, 2.0000e+00, 1.5691e+04,
        2.1620e+03, 1.0000e+01, 9.2000e+01, 8.4020e+03, 9.4130e+03, 1.3000e+01,
        5.0000e+00, 4.6470e+03, 4.0000e+00, 2.0000e+00], dtype=torch.float64)
tensor([0.0000e+00, 2.9000e+01, 1.8615e+04, 2.0000e+00, 2.0000e+00, 1.5691e+04,
        2.1620e+03, 1.0000e+01, 9.2000e+01, 8.4020e+03, 9.4130e+03, 1.3000e+01,
        5.0000e+00, 4.6470e+03, 4.0000e+00, 2.0000e+00], dtype=torch.float64)
tensor([0.0000e+00, 1.5691e+04, 2.0000e+00, 2.0000e+00, 1.5691e+04, 2.1620e+03,
        1.0000e+01, 9.2000e+01, 8.4020e+03, 9.4130e+03, 1.3000e+01, 5.0000e+00,
        4.6470e+03, 4.0000e+00, 2.0000e+00, 0.0000e+00], dtype=torch.float64)
tensor([0., 0.

In [10]:
# train 

from www.dataset.ann import att_to_idx, att_to_num_classes, att_types

# Sub-task and hyper-parameter grid for the training run (values come from the config cell).
subtask = 'cloze'
epochs = config_epochs
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
eval_batch_size = 16
generate_learning_curve = False # Generate data for training curve figure in TRIP paper

# Number of state labels per attribute index: 3 for 'default'-typed attributes, otherwise
# the attribute's own class count.
# Location attributes fall into the second case since they don't have well-defined pre- and post-conditions yet
num_state_labels = {
  att_to_idx[att]: (3 if att_types[att] == 'default' else att_to_num_classes[att])
  for att in att_to_idx
}

# Ablation options:
# - attributes: skip attribute prediction phase
# - embeddings: DON'T input contextual embeddings to conflict detector
# - states: DON'T input states to conflict detector
# - states-labels: in states input to conflict detector, include predicted labels
# - states-logits: in states input to conflict detector, include state logits (preferred)
# - states-teacher-forcing: train conflict detector on ground truth state labels (not predictions)
# - states-attention: re-weight input to conflict detector with weights conditioned on states representation
ablation = ['attributes', 'states-logits'] # This is the default mode presented in the paper

In [11]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch_tiered
from www.model.eval import evaluate_tiered, save_results, save_preds, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict, get_model_dir
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes
import shutil
import pandas as pd

# Fix the seed of every random number generator used below, for reproducibility.
seed_val = 22
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
  seed_fn(seed_val)

# The validation loader is built once and shared by all runs: sequential order, fixed batch size.
dev_tensors = tiered_tensor_dataset['dev']
dev_sampler = SequentialSampler(dev_tensors)
dev_dataloader = DataLoader(dev_tensors, sampler=dev_sampler, batch_size=eval_batch_size)
dev_dataset_name = subtask + '_%s_dev'
dev_ids = [dev_ex['example_id'] for dev_ex in tiered_dataset['dev']]

# Per-run records collected during the grid search.
all_losses, param_combos, combo_names, all_val_objs, output_dirs = [], [], [], [], []

# Best checkpoint so far by the primary objective (best_*) and the secondary one (best_*2).
best_obj = best_obj2 = 0.0
best_model = best_model2 = '<none>'
best_dir = best_dir2 = ''

# Grid search: for every (batch size, learning rate) pair, build a fresh model, train it for
# `epochs` epochs, evaluate on the dev set after each epoch, save a checkpoint, and keep only
# the checkpoints that are currently best by verifiability or by accuracy.
print('Beginning grid search for the %s sub-task over %s parameter combination(s)!' % (subtask, str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    # Per-epoch training loss and dev objective (verifiability) for this combination
    loss_values = []
    obj_values = []

    # Set up training dataset with new batch size
    train_sampler = RandomSampler(tiered_tensor_dataset['train'])
    train_dataloader = DataLoader(tiered_tensor_dataset['train'], sampler=train_sampler, batch_size=bs)

    # Set up model
    config = config_class.from_pretrained(model_name,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    emb = emb_class.from_pretrained(model_name,
                                          config=config,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    if torch.cuda.is_available():
      emb.cuda()
    # Everything else is placed on whichever device the encoder ended up on
    device = emb.device
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    model = TieredModelPipeline(emb, max_story_length, len(att_to_num_classes), num_state_labels,
                                config_class, model_name, device, 
                                ablation=ablation, loss_weights=loss_weights).to(device)

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_dataloader) * epochs
    # NOTE(review): `scheduler` is created here but is neither passed to train_epoch_tiered
    # nor stepped anywhere in this cell - confirm whether the linear decay is actually applied.
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

    # Learning-curve records handed to train_epoch_tiered. They are re-created for every
    # combination, so after the grid search they only hold the last combination's data.
    train_lc_data = []
    val_lc_data = []
    for epoch in range(epochs):
      # Train the model for one epoch
      print('[%s] Beginning epoch...' % str(epoch))

      epoch_loss, _ = train_epoch_tiered(model, optimizer, train_dataloader, device, seg_mode=False, 
                                         build_learning_curves=generate_learning_curve, val_dataloader=dev_dataloader, 
                                         train_lc_data=train_lc_data, val_lc_data=val_lc_data)
      
      # Save loss
      loss_values.append(epoch_loss)

      # Validate on dev set
      validation_results = evaluate_tiered(model, dev_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
      # Unpack the first 16 returned items: (metrics, predictions, labels) for attributes,
      # preconditions, effects, conflicts and stories, followed by the explanations
      metr_attr, all_pred_atts, all_atts, \
      metr_prec, all_pred_prec, all_prec, \
      metr_eff, all_pred_eff, all_eff, \
      metr_conflicts, all_pred_conflicts, all_conflicts, \
      metr_stories, all_pred_stories, all_stories, explanations = validation_results[:16]
      explanations = add_entity_attribute_labels(explanations, tiered_dataset['dev'], list(att_to_num_classes.keys()))

      print('[%s] Validation results:' % str(epoch))
      print('[%s] Preconditions:' % str(epoch))
      print_dict(metr_prec)
      print('[%s] Effects:' % str(epoch))
      print_dict(metr_eff)
      print('[%s] Conflicts:' % str(epoch))
      print_dict(metr_conflicts)
      print('[%s] Stories:' % str(epoch))
      print_dict(metr_stories)

      # Save accuracy - want to maximize verifiability of tiered predictions
      ver = metr_stories['verifiability']
      acc = metr_stories['accuracy']
      obj_values.append(ver)
      
      # Save model checkpoint
      print('[%s] Saving model checkpoint...' % str(epoch))
      # The directory name encodes model, sub-task, batch size, learning rate, epoch,
      # the loss weights joined by '-', and the active ablation options
      model_param_str = get_model_dir(model_name.replace('/', '-'), subtask, bs, lr, epoch) + '_' +  '-'.join([str(lw) for lw in loss_weights]) +  '_tiered_pipeline_lc'
      if train_spans:
        model_param_str += 'spans'
      if len(model.ablation) > 0:
        model_param_str += '_ablate_'
        model_param_str += '_'.join(model.ablation)
      output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
      output_dirs.append(output_dir)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)

      # Write the dev metrics of every tier, plus the explanations, next to the checkpoint
      save_results(metr_attr, output_dir, dev_dataset_name % 'attributes')
      save_results(metr_prec, output_dir, dev_dataset_name % 'preconditions')
      save_results(metr_eff, output_dir, dev_dataset_name % 'effects')
      save_results(metr_conflicts, output_dir, dev_dataset_name % 'conflicts')
      save_results(metr_stories, output_dir, dev_dataset_name % 'stories')
      save_results(explanations, output_dir, dev_dataset_name % 'explanations')

      # Just save story preds
      save_preds(dev_ids, all_stories, all_pred_stories, output_dir, dev_dataset_name % 'stories')

      # Use the wrapped module if the encoder has a `.module` attribute
      # (presumably a DataParallel-style wrapper - confirm)
      emb = emb.module if hasattr(emb, 'module') else emb
      emb.save_pretrained(output_dir)
      # The whole pipeline object is pickled, not just a state dict
      torch.save(model, os.path.join(output_dir, 'classifiers.pth'))
      tokenizer.save_vocabulary(output_dir)

      # Track the best checkpoint by verifiability (primary) and by accuracy (secondary)
      if ver > best_obj:
        best_obj = ver
        best_model = model_param_str
        best_dir = output_dir
      if acc > best_obj2:
        best_obj2 = acc
        best_model2 = model_param_str
        best_dir2 = output_dir        

      # Delete every recorded checkpoint directory that is not one of the two current bests
      for od in output_dirs:
        if od != best_dir and od != best_dir2 and os.path.exists(od):
          shutil.rmtree(od)

      print('[%s] Finished epoch.' % str(epoch))

    # Record this combination's curves and label
    all_losses.append(loss_values)
    all_val_objs.append(obj_values)
    param_combos.append((bs, lr))
    combo_names.append('bs=%s, lr=%s' % (str(bs), str(lr)))

print('Finished grid search! :)')
print('Best validation *verifiability* %s from model %s.' % (str(best_obj), best_model))
print('Best validation *accuracy* %s from model %s.' % (str(best_obj2), best_model2))

# Optionally export the learning-curve records gathered during training as CSV files,
# stored next to the best checkpoint.
if generate_learning_curve:
  print('Saving learning curve data...')
  train_lc_data = [subrecord for record in train_lc_data for subrecord in record] # flatten
  val_lc_data = [subrecord for record in val_lc_data for subrecord in record] # flatten

  # Prefer the best-verifiability directory and fall back to the best-accuracy one.
  # `best_dir` is initialised to '' (the '<none>' placeholder belongs to `best_model`), so
  # the fallback is resolved once here and the printed path always matches the written one.
  lc_output_dir = best_dir if best_dir != '' else best_dir2
  train_lc_path = os.path.join(lc_output_dir, 'learning_curve_data_train.csv')
  val_lc_path = os.path.join(lc_output_dir, 'learning_curve_data_val.csv')

  train_lc_data = pd.DataFrame(train_lc_data)
  print(train_lc_path)
  train_lc_data.to_csv(train_lc_path, index=False)
  val_lc_data = pd.DataFrame(val_lc_data)
  val_lc_data.to_csv(val_lc_path, index=False)
  print('Learning curve data saved. %s rows saved for training, %s rows saved for validation.' % (str(len(train_lc_data.index)), str(len(val_lc_data.index))))

12082021 09:34:44|INFO|numexpr.utils| Note: NumExpr detected 40 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
12082021 09:34:44|INFO|numexpr.utils| NumExpr defaulting to 8 threads.


Beginning grid search for the cloze sub-task over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05


[                                                                        ]   0%

[0] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[0] Validation results:
[0] Preconditions:
{
  accuracy: 
    0.9934858497174613,
  f1: 
    0.2000756590568481,
  accuracy_0: 
    0.9935553168635876,
  f1_0: 
    0.3322557471264368,
  accuracy_1: 
    0.9995796945780601,
  f1_1: 
    0.6647591851914982,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9784009713725307,
  f1_5: 
    0.10989806476487228,
  accuracy_6: 
    0.9865502264979218,
  f1_6: 
    0.6311841011425269,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9921075981880166,
  f1_8: 
    0.33201272163224477,
  accuracy_9: 
    0.9886284033064027,
  f1_9: 
    0.6358295944861531,
  accuracy_10: 
    0.9950614112922057,
  f1_10: 
    0.33250819771263823,
  accuracy_11: 
    0.9971512632

[                                                                        ]   0%

[0] Finished epoch.
[1] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[1] Validation results:
[1] Preconditions:
{
  accuracy: 
    0.9937695558772708,
  f1: 
    0.2470466437792584,
  accuracy_0: 
    0.9935553168635876,
  f1_0: 
    0.3322557471264368,
  accuracy_1: 
    0.9996263951804978,
  f1_1: 
    0.6649705180553488,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9798136645962733,
  f1_5: 
    0.15672750534430405,
  accuracy_6: 
    0.9866786531546257,
  f1_6: 
    0.632067070055966,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9921075981880166,
  f1_8: 
    0.33201272163224477,
  accuracy_9: 
    0.9885817027039648,
  f1_9: 
    0.6363471769533573,
  accuracy_10: 
    0.9950614112922057,
  f1_10: 
    0.33250819771263823,
  accuracy_11: 
    0.99715126325

[                                                                        ]   0%

[1] Finished epoch.
[2] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[2] Validation results:
[2] Preconditions:
{
  accuracy: 
    0.9941600896651567,
  f1: 
    0.2648526022652693,
  accuracy_0: 
    0.9935553168635876,
  f1_0: 
    0.3322557471264368,
  accuracy_1: 
    0.999684770933545,
  f1_1: 
    0.6652336129689316,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9820085929108485,
  f1_5: 
    0.2124986156286951,
  accuracy_6: 
    0.985604539298557,
  f1_6: 
    0.6295837967785419,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.996520805118386,
  f1_8: 
    0.5840481674919906,
  accuracy_9: 
    0.987028907672909,
  f1_9: 
    0.6319700179112812,
  accuracy_10: 
    0.9950614112922057,
  f1_10: 
    0.33250819771263823,
  accuracy_11: 
    0.997151263251296,

[                                                                        ]   0%

[2] Finished epoch.
[3] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[3] Validation results:
[3] Preconditions:
{
  accuracy: 
    0.9942896838369215,
  f1: 
    0.27029240968659207,
  accuracy_0: 
    0.9963690281604632,
  f1_0: 
    0.540969391980683,
  accuracy_1: 
    0.9989142109933218,
  f1_1: 
    0.6617588163541349,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9833162097791062,
  f1_5: 
    0.22472406366758957,
  accuracy_6: 
    0.984787278755896,
  f1_6: 
    0.6278682424707046,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9964507542147294,
  f1_8: 
    0.5936529127227002,
  accuracy_9: 
    0.98700555737169,
  f1_9: 
    0.6329523114464265,
  accuracy_10: 
    0.9950614112922057,
  f1_10: 
    0.33250819771263823,
  accuracy_11: 
    0.99715126325129

[                                                                        ]   0%

[3] Finished epoch.
[4] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[4] Validation results:
[4] Preconditions:
{
  accuracy: 
    0.9944834913370383,
  f1: 
    0.3807512187368281,
  accuracy_0: 
    0.9964741045159482,
  f1_0: 
    0.5435807793531371,
  accuracy_1: 
    0.9996730957829356,
  f1_1: 
    0.665180458461674,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9845888011955354,
  f1_5: 
    0.3315276215829998,
  accuracy_6: 
    0.9853476859851492,
  f1_6: 
    0.6293120237580562,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9964624293653388,
  f1_8: 
    0.583590042486927,
  accuracy_9: 
    0.9876243403539906,
  f1_9: 
    0.634414633862328,
  accuracy_10: 
    0.9958669966842573,
  f1_10: 
    0.4374261144468537,
  accuracy_11: 
    0.997151263251296,

[                                                                        ]   0%

[4] Finished epoch.
[5] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[5] Validation results:
[5] Preconditions:
{
  accuracy: 
    0.9949510811189465,
  f1: 
    0.4401453304330422,
  accuracy_0: 
    0.9975015177695792,
  f1_0: 
    0.5914887154984411,
  accuracy_1: 
    0.9995213188250128,
  f1_1: 
    0.6644908782195929,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9854410871900248,
  f1_5: 
    0.3684896539639628,
  accuracy_6: 
    0.9854177368888059,
  f1_6: 
    0.6295316407474897,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9970345117452015,
  f1_8: 
    0.5991910311965599,
  accuracy_9: 
    0.9877294167094756,
  f1_9: 
    0.6347550704507898,
  accuracy_10: 
    0.9961705506001027,
  f1_10: 
    0.48663949359455305,
  accuracy_11: 
    0.997151263251

[                                                                        ]   0%

[5] Finished epoch.
[6] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[6] Validation results:
[6] Preconditions:
{
  accuracy: 
    0.9949802689954701,
  f1: 
    0.45225001117488717,
  accuracy_0: 
    0.9970461868958109,
  f1_0: 
    0.5704609357607492,
  accuracy_1: 
    0.9996730957829356,
  f1_1: 
    0.6651800160100936,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9858263671601364,
  f1_5: 
    0.3774913472287672,
  accuracy_6: 
    0.9853360108345398,
  f1_6: 
    0.6294270258123578,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9969994862933732,
  f1_8: 
    0.6005551547300465,
  accuracy_9: 
    0.9876943912576472,
  f1_9: 
    0.6347216140088832,
  accuracy_10: 
    0.9963456778592444,
  f1_10: 
    0.5090506234198781,
  accuracy_11: 
    0.997735020781

[                                                                        ]   0%

[6] Finished epoch.
[7] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[7] Validation results:
[7] Preconditions:
{
  accuracy: 
    0.9950725026852847,
  f1: 
    0.4771663441207421,
  accuracy_0: 
    0.9977233456311586,
  f1_0: 
    0.6000399462692481,
  accuracy_1: 
    0.9995680194274507,
  f1_1: 
    0.6647027746597033,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9864685004436557,
  f1_5: 
    0.39558393448463225,
  accuracy_6: 
    0.985604539298557,
  f1_6: 
    0.6299754907100253,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9966725820763088,
  f1_8: 
    0.5953500752834827,
  accuracy_9: 
    0.987239060383879,
  f1_9: 
    0.6331644618703186,
  accuracy_10: 
    0.9961705506001027,
  f1_10: 
    0.5008609865900019,
  accuracy_11: 
    0.99798019894456

[                                                                        ]   0%

[7] Finished epoch.
[8] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[8] Validation results:
[8] Preconditions:
{
  accuracy: 
    0.9950783402605894,
  f1: 
    0.46291085875871363,
  accuracy_0: 
    0.997653294727502,
  f1_0: 
    0.5979893248261748,
  accuracy_1: 
    0.9996614206323261,
  f1_1: 
    0.6651268615710942,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9858263671601364,
  f1_5: 
    0.37890258594596665,
  accuracy_6: 
    0.9850908326717415,
  f1_6: 
    0.6288407628201346,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9971395881006865,
  f1_8: 
    0.6065792181109931,
  accuracy_9: 
    0.9871806846308318,
  f1_9: 
    0.6333354899687259,
  accuracy_10: 
    0.9963223275580255,
  f1_10: 
    0.518028812993228,
  accuracy_11: 
    0.9980268995470

[                                                                        ]   0%

[8] Finished epoch.
[9] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:19s.
[9] Validation results:
[9] Preconditions:
{
  accuracy: 
    0.9951192032877224,
  f1: 
    0.4655753378117338,
  accuracy_0: 
    0.997629944426283,
  f1_0: 
    0.5964332911663176,
  accuracy_1: 
    0.9995913697286695,
  f1_1: 
    0.6648082649629437,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9863984495399991,
  f1_5: 
    0.3857665224904274,
  accuracy_6: 
    0.9862466725820763,
  f1_6: 
    0.6314773589652468,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9966142063232616,
  f1_8: 
    0.5935937361382374,
  accuracy_9: 
    0.9871923597814412,
  f1_9: 
    0.6324024610902389,
  accuracy_10: 
    0.996205576051931,
  f1_10: 
    0.5080760741994763,
  accuracy_11: 
    0.997781721384205

In [12]:
# Show the longest story length (in sentences) computed for the datasets above.
print(max_story_length)

7


In [13]:
# Load the saved model and test it

# Name of the checkpoint directory (under DRIVE_PATH/saved_models) to evaluate.
# NOTE(review): this name is hard-coded and embeds the loss weights 0.2-0.4-0.4-0.2-0.0,
# whereas the active `loss_weights` in the config cell start with 0.0 - presumably a
# checkpoint from an earlier run; confirm it exists before running this cell.
eval_model_dir = 'roberta-large_cloze_1_1e-05_8_0.2-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits'

from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes, att_to_idx, att_types

probe_model = os.path.join(DRIVE_PATH, 'saved_models', eval_model_dir)

ablation = ['attributes', 'states-logits']

# Infer the sub-task from the checkpoint path; if neither word occurs, `subtask` keeps
# whatever value an earlier cell gave it.
if 'cloze' in probe_model:
  subtask = 'cloze'
elif 'order' in probe_model:
  subtask = 'order'

if subtask == 'cloze':
  subtask_dataset = cloze_dataset_2s
elif subtask == 'order':
  subtask_dataset = order_dataset_2s

# Load the model
# Drop the reference to any previously loaded model first (presumably to release its memory).
model = None
# The checkpoint is a fully pickled module (written with torch.save(model, ...)), so only
# load files from a trusted source. Without CUDA, tensors are remapped to the CPU;
# otherwise loading a checkpoint that was saved on a GPU would fail on a CPU-only machine.
load_location = None if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(os.path.join(probe_model, 'classifiers.pth'), map_location=load_location)
if torch.cuda.is_available():
  model.cuda()
device = model.embedding.device

# Put the precondition and effect classifier heads into evaluation mode.
for layer in model.precondition_classifiers:
  layer.eval()
for layer in model.effect_classifiers:
  layer.eval()

# test

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate_tiered, save_results, save_preds, list_comparison, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

print('Testing model: %s.' % probe_model)

# May alter this depending on which partition(s) you want to run inference on
for p in tiered_dataset:
  if p != 'test':
    continue

  p_dataset = tiered_dataset[p]
  p_tensor_dataset = tiered_tensor_dataset[p]
  p_sampler = SequentialSampler(p_tensor_dataset)
  p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=16)
  dev_dataset_name = subtask + '_%s_' + p
  p_ids = [ex['example_id'] for ex in tiered_dataset[p]]

  # Get preds and metrics on this partition
  metr_attr, all_pred_atts, all_atts, \
  metr_prec, all_pred_prec, all_prec, \
  metr_eff, all_pred_eff, all_eff, \
  metr_conflicts, all_pred_conflicts, all_conflicts, \
  metr_stories, all_pred_stories, all_stories, explanations = evaluate_tiered(model, p_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
  explanations = add_entity_attribute_labels(explanations, tiered_dataset[p], list(att_to_num_classes.keys()))

  save_results(metr_attr, probe_model, dev_dataset_name % 'attributes')
  save_results(metr_prec, probe_model, dev_dataset_name % 'preconditions')
  save_results(metr_eff, probe_model, dev_dataset_name % 'effects')
  save_results(metr_conflicts, probe_model, dev_dataset_name % 'conflicts')
  save_results(metr_stories, probe_model, dev_dataset_name % 'stories')
  save_results(explanations, probe_model, dev_dataset_name % 'explanations')

  print('\nPARTITION: %s' % p)
  print('Stories:')
  print_dict(metr_stories)
  print('Conflicts:')
  print_dict(metr_conflicts)
  print('Preconditions:')
  print_dict(metr_prec)
  print('Effects:')
  print_dict(metr_eff)

[                                                                        ]   0%

Testing model: /home/ikhee/Verifiable-Coherent-NLU/saved_models/roberta-large_cloze_1_1e-05_8_0.2-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits.
	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:02:01s.

PARTITION: test
Stories:
{
  accuracy: 
    0.7663817663817664,
  f1: 
    0.766364699006429,
  verifiability: 
    0.08262108262108261,
}


Conflicts:
{
  accuracy: 
    0.9788284417914047,
  f1: 
    0.6620380432412823,
}


Preconditions:
{
  accuracy: 
    0.9958761060612913,
  f1: 
    0.4208797552750762,
  accuracy_0: 
    0.9986056467537949,
  f1_0: 
    0.5996283229891325,
  accuracy_1: 
    0.9991784621414251,
  f1_1: 
    0.6614106187823334,
  accuracy_2: 
    0.9991633880522769,
  f1_2: 
    0.33319383965767496,
  accuracy_3: 
    0.9993442771220549,
  f1_3: 
    0.3332240103442166,
  accuracy_4: 
    0.9998794072868147,
  f1_4: 
    0.33331323333584584,
  accuracy_5: 
    0.9891617299024706,
  f1_5: 
    0.3498233448673376,
  accuracy_6: 
    0.9850766517433184,
  f1_6: 
    0.6160098332594641,
  accuracy_7: 
    0.9984624429068873,
  f1_7: 
    0.33307687665942554,
  accuracy_8: 
    0.9970530155715341,
  f1_8: 
    

In [14]:
# Compute the "consistency" metric for each saved model and partition.
#
# A prediction is counted as consistent when the story prediction matches its
# label AND the predicted conflict has the same length as the labelled conflict
# with the same first two entries. The per-example flag is written to a
# '*_explanations_consistency_*' file and the aggregate rate to a
# '*_stories_final_*' file next to the inputs.
#
# Relies on DRIVE_PATH and eval_model_dir from earlier cells.
import json
import os

model_directories = [eval_model_dir]

partitions = ['dev', 'test']
expl_fname = 'results_cloze_explanations_%s.json'
endtask_fname = 'results_cloze_stories_%s.json'
endtask_fname_new = 'results_cloze_stories_final_%s.json'
for md in model_directories:
  md_path = os.path.join(DRIVE_PATH, 'saved_models', md)
  for p in partitions:
    # Context managers close the input files promptly instead of leaking handles.
    with open(os.path.join(md_path, expl_fname % p), 'r') as expl_file:
      explanations = json.load(expl_file)
    with open(os.path.join(md_path, endtask_fname % p), 'r') as endtask_file:
      endtask_results = json.load(endtask_file)

    consistent_preds = 0
    verifiable_preds = 0
    total = 0
    for expl in explanations:
      if expl['valid_explanation']:
        verifiable_preds += 1

      # Default to False so every record carries the key, including examples
      # whose story prediction is wrong (previously the key was absent there).
      expl['consistent'] = False
      if expl['story_pred'] == expl['story_label']:
        conflict_pred = expl['conflict_pred']
        conflict_label = expl['conflict_label']
        # Same length and same first two entries; slicing avoids an IndexError
        # when a conflict list has fewer than two entries.
        if len(conflict_pred) == len(conflict_label) and conflict_pred[:2] == conflict_label[:2]:
          expl['consistent'] = True
          consistent_preds += 1
      total += 1

    # Guard against an empty explanations file (would otherwise divide by zero).
    endtask_results['consistency'] = float(consistent_preds) / total if total else 0.0
    print('Found %s consistent preds in %s (versus %s verifiable)' % (str(consistent_preds), p, str(verifiable_preds)))

    # Closing the output files ensures the JSON is fully flushed to disk.
    with open(os.path.join(md_path, (expl_fname % p).replace('explanations', 'explanations_consistency')), 'w') as expl_out:
      json.dump(explanations, expl_out)
    with open(os.path.join(md_path, endtask_fname_new % p), 'w') as endtask_out:
      json.dump(endtask_results, endtask_out)

Found 87 consistent preds in dev (versus 32 verifiable)
Found 80 consistent preds in test (versus 29 verifiable)


# 