### **Toward Consistent, Verifiable, and Coherent Commonsense Reasoning in Large LMs**

This notebook provides source code for our two papers in Findings of EMNLP 2021:


1.  Shane Storks, Qiaozi Gao, Yichi Zhang, and Joyce Y. Chai (2021). *Tiered Reasoning for Intuitive Physics: Toward Verifiable Commonsense Language Understanding.* Findings of EMNLP 2021.
2.   Shane Storks and Joyce Y. Chai (2021). *Beyond the Tip of the Iceberg: Assessing Coherence of Text Classifiers.* Findings of EMNLP 2021.

*If you have any questions or problems, please open an issue on our [GitHub repo](https://github.com/sled-group/Verifiable-Coherent-NLU) or email Shane Storks.*

***First, configure the execution mode by selecting a few settings (expand cell if needed):***




0.   (Colab only) Insert the path in your Google Drive to the folder where this notebook is located.

In [1]:
DRIVE_PATH = './'

1.   Model type (choose from BERT large, RoBERTa large, RoBERTa large + MNLI, DeBERTa base, and DeBERTa large).






In [2]:
#mode = 'bert' # BERT large
mode = 'roberta' # RoBERTa large
# mode = 'roberta_mnli' # RoBERTa large pre-trained on MNLI
# mode = 'deberta' # DeBERTa base for training on TRIP
# mode = 'deberta_large' # DeBERTa large for training on CE and ART

2.   Name of the task we want to train or evaluate on. Set `debug` to `True` to run quick training/evaluation jobs on only a small amount of data.

In [3]:
task_name = 'trip'
# task_name = 'ce'
# task_name = 'art'

debug = False

3.   (If training models) Training batch size, learning rate, and maximum number of epochs. Settings for results in the paper are provided as examples.

In [4]:
config_batch_size = 1
config_lr = 1e-5 # Selected learning rate for best RoBERTa-based model in TRIP paper
config_epochs = 10

4.   (For training TRIP models only) Configure the loss weighting scheme for training models here. We provide the 4 modes from the paper as examples.


In [5]:
# Loss weights for (attributes, preconditions, effects, conflicts, story choices)
if task_name != 'trip':
  print("We do not need a loss weighting scheme for %s dataset. Ignoring this cell." % task_name)
# loss_weights = [0.0, 0.4, 0.4, 0.1, 0.1] # "All losses"
loss_weights = [0.0, 0.4, 0.4, 0.2, 0.0] # "Omit story choice loss"
# loss_weights = [0.0, 0.4, 0.4, 0.0, 0.2] # "Omit conflict detection loss"
# loss_weights = [0.0, 0.0, 0.0, 0.5, 0.5] # "Omit state classification losses"
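As a rough illustration of how such a weighting scheme could be applied, the sketch below combines five per-task losses into a single scalar. The function name `combine_losses` and the example loss values are hypothetical. The real combination happens inside `TieredModelPipeline`, and this sketch only assumes that it is a weighted sum.

```python
LOSS_NAMES = ("attributes", "preconditions", "effects", "conflicts", "story choices")

def combine_losses(losses, weights):
    """Return the weighted sum of per-task losses (assumed combination rule)."""
    if len(losses) != len(LOSS_NAMES) or len(weights) != len(LOSS_NAMES):
        raise ValueError("Expected %d losses and %d weights." % (len(LOSS_NAMES), len(LOSS_NAMES)))
    return sum(w * l for w, l in zip(weights, losses))

# Hypothetical per-task loss values, for illustration only
example_losses = [0.9, 0.5, 0.7, 1.2, 0.6]
omit_story_choice = [0.0, 0.4, 0.4, 0.2, 0.0]  # "Omit story choice loss" scheme from the cell above
total = combine_losses(example_losses, omit_story_choice)
print("Combined loss: %.2f" % total)
```

A weight of 0.0 removes the corresponding loss from the objective, which is how the "omit" schemes above switch individual training signals off.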

5.   (If evaluating models) Provide the name of the pre-trained model directory here. This should be the name of a directory within the *saved_models* directory, which should be located where this notebook is. Names of provided pre-trained model directories are listed.

In [6]:
# TRIP, all losses
# eval_model_dir = 'bert-large-uncased_cloze_1_5e-06_4_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_1e-05_7_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-06_5_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'

# TRIP, no story classification loss
# eval_model_dir = 'bert-large-uncased_cloze_1_5e-05_8_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_ablate_attributes_states-logits'
#eval_model_dir = 'roberta-large_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits' # Best model trained in the TRIP paper
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_ablate_attributes_states-logits'

# TRIP, no conflict detection loss
# eval_model_dir = 'bert-large-uncased_cloze_1_1e-06_1_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_5e-06_8_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_1e-06_3_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'

# TRIP, no physical state classification loss
# eval_model_dir = 'bert-large-uncased_cloze_1_1e-05_3_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_1e-06_7_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-06_9_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'

# CE
# eval_model_dir = 'bert-large-uncased_ConvEnt_32_7.5e-06_7_xval'
# eval_model_dir = 'roberta-large_ConvEnt_32_7.5e-06_9_xval'
# eval_model_dir = 'roberta-large-mnli_ConvEnt_32_7.5e-06_7_xval'
# eval_model_dir = 'microsoft-deberta-large_ConvEnt_16_1e-05_9_xval'

# ART
# eval_model_dir = 'bert-large-uncased_art_64_5e-06_8'
# eval_model_dir = 'roberta-large_art_64_2.5e-06_4'
# eval_model_dir = 'DeBERTa-deberta-large_art_32_1e-06_8'
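The TRIP directory names above appear to encode the training settings as underscore-separated fields. The sketch below splits one such name back into its parts. The helper `parse_trip_model_dir` is hypothetical, and the field order (model, sub-task, batch size, learning rate, epoch, loss weights) is an assumption inferred from the listed names and from the call to `get_model_dir` in the training cell. It is not meant for the CE or ART directory names, which follow a different pattern.

```python
def parse_trip_model_dir(name):
    """Split a TRIP model directory name into its (assumed) fields."""
    parts = name.split('_')
    return {
        'model': parts[0],
        'subtask': parts[1],
        'batch_size': int(parts[2]),
        'learning_rate': float(parts[3]),
        'epoch': int(parts[4]),
        'loss_weights': [float(w) for w in parts[5].split('-')],
        'suffix': '_'.join(parts[6:]),  # e.g. pipeline type and ablation modes
    }

info = parse_trip_model_dir(
    'roberta-large_cloze_1_1e-05_7_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'
)
for key, value in info.items():
    print('%s: %s' % (key, value))
```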

**For more configuration options, scroll down to the Train Models > Configure Hyperparameters cell for the task you're working on.**

# Setup
Run this block every time you start the notebook. It gets Colab ready, preprocesses the data, and loads the model packages and classes we'll need later. The first run may take several minutes.

**If you get a `ModuleNotFoundError` for the `www` code base, try the following:**


1.   Ensure the `DRIVE_PATH` is set properly above.
2.   (Colab only) Verify that this notebook has access to your Google Drive (click the folder icon on the left and then the Google Drive icon).
3.   Try to restart the runtime and refresh your browser window.
4.   (Colab only) If the problem persists, revoke access to Google Drive and re-enable it.
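Before working through the steps above, it can help to confirm that the folder actually contains what the notebook imports and loads. The following is a minimal sketch: the helper `find_missing_paths` is hypothetical, and the folder names `www` and `all_data` are taken from the import statements and data paths used later in this notebook.

```python
import os

def find_missing_paths(drive_path, required=('www', 'all_data')):
    """Return the entries of `required` that do not exist under `drive_path`."""
    return [name for name in required
            if not os.path.exists(os.path.join(drive_path, name))]

missing = find_missing_paths('./')  # Replace './' with your DRIVE_PATH
if missing:
    print('Missing under DRIVE_PATH: %s' % ', '.join(missing))
else:
    print('DRIVE_PATH looks complete.')
```

If `www` is reported as missing, the `ModuleNotFoundError` most likely comes from the path setting rather than from Drive permissions.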





## Colab Setup

Enable auto reloading of code libraries from Google Drive, set up connection to Google Drive, and import some packages. 🔌

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
# Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install numpy
import numpy



In [9]:
import os
import json
import sys
import torch
import random
import numpy as np
import spacy
!{sys.executable} -m pip install jsonlines

sys.path.append(DRIVE_PATH)



## Model Setup

Next, we'll load up the transformer model, tokenizer, etc. ⏳

### Install HuggingFace transformers and other dependencies

In [10]:
#first time
'''
%pip install 'transformers==4.2.2'
%pip install sentencepiece
%pip install --upgrade setuptools
%pip install --upgrade pip
%pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
%pip install Cython
%pip install deberta
'''


In [11]:
#%pip install ipywidgets
#%pip install matplotlib

### Get Model Components

Specify which model parameters from transformers we want to use:

In [12]:
if task_name in ['trip', 'ce']:
  multiple_choice = False
elif task_name == 'art':
  multiple_choice = True
else:
  raise ValueError("Task name should be set to 'trip', 'ce', or 'art' in the first cell of the notebook!")

if mode == 'bert':
  model_name = 'bert-large-uncased'
elif mode == 'roberta':
  model_name = 'roberta-large'
elif mode == 'roberta_mnli':
  model_name = 'roberta-large-mnli'
elif mode == 'deberta':
  model_name = 'microsoft/deberta-base'
elif mode == 'deberta_large':
  model_name = 'microsoft/deberta-large'
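The `if`/`elif` chain above leaves `model_name` undefined when `mode` is misspelled. One possible alternative, sketched here with a hypothetical helper `resolve_model_name`, is a dictionary lookup that fails early with a clear message. The mapping itself is copied from the cell above.

```python
MODEL_NAMES = {
    'bert': 'bert-large-uncased',
    'roberta': 'roberta-large',
    'roberta_mnli': 'roberta-large-mnli',
    'deberta': 'microsoft/deberta-base',
    'deberta_large': 'microsoft/deberta-large',
}

def resolve_model_name(mode):
    """Map a notebook `mode` string to its pre-trained model identifier."""
    try:
        return MODEL_NAMES[mode]
    except KeyError:
        raise ValueError("Unknown mode '%s'. Choose from: %s"
                         % (mode, ', '.join(sorted(MODEL_NAMES)))) from None

print(resolve_model_name('roberta'))
```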

Load the tokenizer:

In [13]:
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer, AlbertTokenizer, T5Tokenizer, GPT2Tokenizer

from DeBERTa import deberta
if mode in ['bert']:
  tokenizer_class = BertTokenizer
elif mode in ['roberta', 'roberta_mnli']:
  tokenizer_class = RobertaTokenizer
elif mode in ['deberta', 'deberta_large']:
  tokenizer_class = DebertaTokenizer

tokenizer = tokenizer_class.from_pretrained(model_name, 
                                                do_lower_case = False, 
                                                cache_dir=os.path.join(DRIVE_PATH, 'cache'))

Load the model and optimizer:



In [14]:
from transformers import BertForSequenceClassification, RobertaForSequenceClassification, DebertaForSequenceClassification, AlbertForSequenceClassification, AdamW
from transformers import BertForMultipleChoice, RobertaForMultipleChoice, AlbertForMultipleChoice, DebertaModel
from transformers import BertModel, RobertaModel, AlbertModel, DebertaModel, T5Model, T5EncoderModel, GPT2Model
from transformers import RobertaForMaskedLM
from transformers import BertConfig, RobertaConfig, DebertaConfig, AlbertConfig, T5Config, GPT2Config
from www.model.transformers_ext import DebertaForMultipleChoice
from torch.optim import Adam
if not multiple_choice:
  if mode == 'bert':
    model_class = BertForSequenceClassification
    config_class = BertConfig
    emb_class = BertModel
  elif mode in ['roberta', 'roberta_mnli']:
    model_class = RobertaForSequenceClassification
    config_class = RobertaConfig
    emb_class = RobertaModel
    lm_class = RobertaForMaskedLM
  elif mode in ['deberta', 'deberta_large']:
    model_class = DebertaForSequenceClassification
    config_class = DebertaConfig
    emb_class = DebertaModel
else:
  if mode == 'bert':
    model_class = BertForMultipleChoice
    config_class = BertConfig
    emb_class = BertModel    
  elif mode in ['roberta', 'roberta_mnli']:
    model_class = RobertaForMultipleChoice
    config_class = RobertaConfig
    emb_class = RobertaModel
    lm_class = RobertaForMaskedLM
  elif mode in ['deberta', 'deberta_large']:
    model_class = DebertaForMultipleChoice
    config_class = DebertaConfig
    emb_class = DebertaModel

## Data Setup

Preprocess the dataset.

### Preprocessing

Load the preprocessed dataset (constructed from the .txt files collected from AMT) from `all_data/www.json` and print a few examples.

In [15]:
from www.utils import print_dict

partitions = ['train', 'dev', 'test']
subtasks = ['cloze', 'order']

# We can split the data into multiple json files later
data_file = os.path.join(DRIVE_PATH, 'all_data/www.json')
with open(data_file, 'r') as f:
  dataset = json.load(f)

print('Preprocessed examples:')
for ex_idx in [0,1,5,10]:
  ex = dataset['dev'][list(dataset['dev'].keys())[ex_idx]]
  print_dict(ex)

Preprocessed examples:
{
  story_id: 
    13,
  worker_id: 
    A32W24TWSWXW,
  type: 
    None,
  idx: 
    None,
  aug: 
    False,
  actor: 
    John,
  location: 
    kitchen,
  objects: 
    cabinet, counter, knife, pan, potato, pizza,
  sentences: 
    [
      John was getting the snacks ready for the party.
      John opened the cabinet, took out a pan and put it on the counter.
      John opened the fridge and got out the pizza.
      John put the pizza on the pan and put them into the oven.
      John took a knife and cut the hot pizza in eight slices.
    ],
  length: 
    5,
  example_id: 
    13,
  plausible: 
    True,
  breakpoint: 
    -1,
  confl_sents: 
    [],
  confl_pairs: 
    [],
  states: 
    [
      {'h_location': [['John', 0]], 'conscious': [['John', 2]], 'wearing': [['John', 0]], 'h_wet': [['John', 0]], 'hygiene': [['John', 0]], 'location': [['snacks', 0], ['party', 0]], 'exist': [['snacks', 4], ['party', 2]], 'clean': [['snacks', 0], ['party', 0]], 'power': 

In [16]:
cloze_dataset = {p: [] for p in dataset}
order_dataset = {p: [] for p in dataset}

for p in dataset:
  for exid in dataset[p]:
    ex = dataset[p][exid]

    if ex['type'] is None:
      continue
    
    ex_plaus = dataset[p][str(ex['story_id'])]

    if ex['type'] == 'cloze':
      cloze_dataset[p].append(ex)
      cloze_dataset[p].append(ex_plaus) # For every implausible story, add a copy of its corresponding plausible story

    # Exclude augmented ordering examples from dev and test, since the breakpoints aren't always accurate in those
    elif ex['type'] == 'order' and not (p != 'train' and ex['aug']): 
      order_dataset[p].append(ex)
      order_dataset[p].append(ex_plaus)
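To make the pairing logic above concrete, here is a small self-contained sketch that runs the same rules on a toy dataset. The function `split_by_type`, the toy keys, and the toy records are hypothetical and carry only the fields the loop reads (`story_id`, `type`, `aug`).

```python
def split_by_type(dataset):
    """Pair each implausible story with its plausible source, grouped by sub-task."""
    cloze = {p: [] for p in dataset}
    order = {p: [] for p in dataset}
    for p in dataset:
        for ex in dataset[p].values():
            if ex['type'] is None:
                continue
            ex_plaus = dataset[p][str(ex['story_id'])]
            if ex['type'] == 'cloze':
                cloze[p] += [ex, ex_plaus]
            # Augmented ordering examples are kept only in the training partition
            elif ex['type'] == 'order' and not (p != 'train' and ex['aug']):
                order[p] += [ex, ex_plaus]
    return cloze, order

# Toy data, for illustration only
toy = {
    'train': {
        '0': {'story_id': 0, 'type': None, 'aug': False},
        '0-C0': {'story_id': 0, 'type': 'cloze', 'aug': False},
        '0-O0': {'story_id': 0, 'type': 'order', 'aug': True},
    },
    'dev': {
        '1': {'story_id': 1, 'type': None, 'aug': False},
        '1-O0': {'story_id': 1, 'type': 'order', 'aug': True},
    },
}
toy_cloze, toy_order = split_by_type(toy)
print({p: len(toy_cloze[p]) for p in toy_cloze})
print({p: len(toy_order[p]) for p in toy_order})
```

In this toy run the augmented ordering example survives in `train` but is dropped from `dev`, matching the exclusion rule in the cell above.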



### Convert TRIP to Two-Story Classification Task

Ready the TRIP dataset for two-story classification.

In [17]:
from www.utils import print_dict
import json
from collections import Counter

data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
with open(data_file, 'r') as f:
  cloze_dataset_2s, order_dataset_2s = json.load(f)  

for p in cloze_dataset_2s:
  label_dist = Counter([ex['label'] for ex in cloze_dataset_2s[p]])
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())
print_dict(cloze_dataset_2s['train'][0])

Cloze label distribution (train):
[(1, 400), (0, 399)]
Cloze label distribution (dev):
[(0, 161), (1, 161)]
Cloze label distribution (test):
[(1, 176), (0, 175)]
{
  example_id: 
    0-C0,
  stories: 
    [
      {'story_id': 0, 'worker_id': 'A1F01FVEPYCPHO', 'type': 'cloze', 'idx': 0, 'aug': False, 'actor': 'Tom', 'location': 'kitchen', 'objects': 'dustbin, microwave, pan, plate, cereal, soup', 'sentences': ['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.'], 'length': 5, 'example_id': '0-C0', 'plausible': False, 'breakpoint': 4, 'confl_sents': [3], 'confl_pairs': [[3, 4]], 'states': [{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin', 

---

# TRIP Results

Contains code for the tiered and random TRIP baselines.

In [18]:
if task_name != 'trip':
  raise ValueError('Please configure task_name in first cell to "trip" to run TRIP results!')

## Transformer-Based Tiered Classifier for TRIP

This is the baseline model presented in the paper. Based on the settings above, the cells below can be used for training and evaluating models.


### Featurization for Tiered Classification

Get the data ready for input to the model.
To use feature augmentation, set `feature_augmentation` to `True`.

In [19]:
from www.dataset.ann import att_to_idx, att_change_dir, att_types, att_default_values
from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter
import spacy
feature_augmentation=False
tiered_dataset = cloze_dataset_2s
train_spans = False
debug_mode = debug # Reuse the `debug` flag configured at the top of the notebook
if debug_mode:
    tiered_dataset['train'] = tiered_dataset['train'][:20]
    tiered_dataset['dev'] = tiered_dataset['dev'][:20]
if(not feature_augmentation):
    # train_spans = True
    train_spans = False
    if train_spans:
      tiered_dataset = get_story_spans_2s(tiered_dataset, train_only=True)
      tiered_dataset['train'] = [ex for ex in tiered_dataset['train'] if ex['label'] != -1] # For now, ignore examples where both stories are plausible :(

    seq_length = 16 # Max sequence length to pad to

    tiered_dataset = get_tiered_data(tiered_dataset)
    tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

    tiered_tensor_dataset = {}
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    for p in tiered_dataset:
      tiered_tensor_dataset[p] = get_tensor_dataset_tiered(tiered_dataset[p], max_story_length, add_segment_ids=True)
else:
    dataset=tiered_dataset
  # states = {p: [] for p in dataset}
    max_story_length = max([len(ex['sentences']) for p in dataset for ex_2s in dataset[p] for ex in ex_2s['stories']])
    nlp = spacy.load("en_core_web_sm")
    for p in dataset:
        for ex_2s in dataset[p]:
          for s_idx, ex in enumerate(ex_2s['stories']):
            if 'states' in ex:
              ent_sent_examples = {}
              all_entities = set()
              for i, sent_ann in enumerate(ex['states']):
                #each sentence
                entities = []
                entity_anns = {}
                noun_to_verb={}
                doc = nlp(ex['sentences'][i])
                for chunk in doc.noun_chunks:
                    if(chunk.root.dep_=='dobj'):
                        noun_to_verb[chunk.root.text]=chunk.root.head.text
                #create a dict which relates verbs and entity
                for att in sent_ann:
                #for each attribute
                  for ent, v in [tuple(ann) for ann in sent_ann[att]]:
                    #entity, label
                    entities.append(ent)
                    ent_verb=noun_to_verb.get(ent)

                    if ent_verb is not None:
                        entities.append(ent_verb)
                    if ent not in entity_anns:
                      entity_anns[ent] = [[0] * len(att_to_idx), [0] * len(att_to_idx)] # pre/post condition, then value for each attribute
                    if(ent_verb not in entity_anns):
                      entity_anns[ent_verb] = [[0] * len(att_to_idx), [0] * len(att_to_idx)] # pre/post condition, then value for each attribute  
                    if 'location' not in att:
                      entity_anns[ent][0][att_to_idx[att]] = att_change_dir['default'][v][0] + 1
                      entity_anns[ent][1][att_to_idx[att]] = att_change_dir['default'][v][1] + 1
                      entity_anns[ent_verb][0][att_to_idx[att]] = att_change_dir['default'][v][0] + 1
                      entity_anns[ent_verb][1][att_to_idx[att]] = att_change_dir['default'][v][1] + 1
                    else:
                      # For location, just use original label space - NOTE: missing this caused an issue for the "location" attribute in our original submission, but not the "h_location" attribute
                      entity_anns[ent][0][att_to_idx[att]] = v
                      entity_anns[ent][1][att_to_idx[att]] = v
                      entity_anns[ent_verb][0][att_to_idx[att]] = v
                      entity_anns[ent_verb][1][att_to_idx[att]] = v


                entities = list(set(entities))
                all_entities = all_entities.union(set(entities))
                for ent in entities:
                  states_ex = {}
                  states_ex['example_id'] = ex_2s['example_id'] + '-%s-%s-%s' % (str(s_idx), str(i), ent)
                  states_ex['base_id'] = ex['example_id']
                  states_ex['sentence_idx'] = i
                  states_ex['entity'] = ent
                  states_ex['sentence'] = ex['sentences'][i]
                  states_ex['preconditions'] = entity_anns[ent][0]
                  states_ex['effects'] = entity_anns[ent][1]

                  ent_sent_examples[(ent, i)] = states_ex

            ex_2s['stories'][s_idx]['entities'] = [None for _ in range(len(all_entities))] # entity-story data: preconditions, effects, etc.
            for ei, ent in enumerate(all_entities):
              ent_ex = {}
              ent_ex['example_id'] = ex_2s['example_id'] + '-' + str(s_idx) + '-' + ent
              ent_ex['base_id'] = ex_2s['example_id'] + '-' + str(s_idx)
              ent_ex['sentences'] = ex['sentences']
              ent_ex['entity'] = ent
              ent_ex['attributes'] = np.zeros((max_story_length, len(att_to_idx)))
              ent_ex['preconditions'] = np.zeros((max_story_length, len(att_to_idx)))
              ent_ex['effects'] = np.zeros((max_story_length, len(att_to_idx)))

              ent_ex['conflict_span'] = (0,0)
              ent_ex['conflict_span_onehot'] = np.zeros((max_story_length))
              ent_ex['plausible'] = 1
              if s_idx != ex_2s['label'] and ex_2s['label'] != -1:
                # print(ex_2s['confl_sents'])
                conflict_span = (max([s+1 for s in ex_2s['confl_sents'] if s < ex_2s['breakpoint']]), ex_2s['breakpoint']+1) 

                # Check if the entity has some nontrivial annotated states in the boundaries of the conflict span
                if (ent, conflict_span[0]-1) in ent_sent_examples:
                  for i, att in enumerate(att_default_values):
                    if (ent_sent_examples[(ent, conflict_span[0]-1)]['preconditions'][i] != att_default_values[att] or ent_sent_examples[(ent, conflict_span[0]-1)]['effects'][i] != att_default_values[att]):
                      ent_ex['conflict_span'] = conflict_span
                      ent_ex['plausible'] = 0
                if (ent, conflict_span[1]-1) in ent_sent_examples:
                  for i, att in enumerate(att_default_values):
                    if (ent_sent_examples[(ent, conflict_span[1]-1)]['preconditions'][i] != att_default_values[att] or ent_sent_examples[(ent, conflict_span[1]-1)]['effects'][i] != att_default_values[att]):
                      ent_ex['conflict_span'] = conflict_span
                      ent_ex['plausible'] = 0

              for cs in ent_ex['conflict_span']:
                if cs > 0:
                  ent_ex['conflict_span_onehot'][cs-1] = 1

              # Get binary label for each span of text as well (for alternative formulation)
              ent_ex['span_labels'] = np.zeros((max_story_length * (max_story_length - 1) // 2))
              if s_idx == 1 - ex_2s['label']: # If this is the implausible choice
                span_idx = 0
                for s2 in range(1, len(ex['sentences'])):
                  for s1 in range(s2):
                    # print(ex['confl_pairs'])
                    for p1, p2 in ex_2s['confl_pairs']:
                      if s1 <= p1 and s2 >= p2:
                        # Check if the entity has some nontrivial annotated states in the boundaries of the conflict span
                        if (ent, p1) in ent_sent_examples:
                          for i, att in enumerate(att_default_values):
                            if ent_sent_examples[(ent, p1)]['preconditions'][i] != att_default_values[att] or ent_sent_examples[(ent, p1)]['effects'][i] != att_default_values[att]:
                              ent_ex['span_labels'][span_idx] = 1
                        if (ent, p2) in ent_sent_examples:
                          for i, att in enumerate(att_default_values):
                            if ent_sent_examples[(ent, p2)]['preconditions'][i] != att_default_values[att] or ent_sent_examples[(ent, p2)]['effects'][i] != att_default_values[att]:
                              ent_ex['span_labels'][span_idx] = 1
                    span_idx += 1

              for i in range(ex_2s['length']):
                if (ent, i) in ent_sent_examples:
                  ent_ex['preconditions'][i,:] = ent_sent_examples[(ent, i)]['preconditions']
                  ent_ex['effects'][i,:] = ent_sent_examples[(ent, i)]['effects']
                  for j, att in enumerate(att_default_values):
                    if ent_ex['preconditions'][i,j] != att_default_values[att] or ent_ex['effects'][i,j] != att_default_values[att]:
                      ent_ex['attributes'][i,j] = 1
              ex_2s['stories'][s_idx]['entities'][ei] = ent_ex
    tiered_dataset=dataset
    seq_length = 16 # Max sequence length to pad to
    tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)
    tiered_tensor_dataset = {}
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    for p in tiered_dataset:
      tiered_tensor_dataset[p] = get_tensor_dataset_tiered(tiered_dataset[p], max_story_length, add_segment_ids=True)
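The `span_labels` vector built in the feature-augmentation branch above has one slot per sentence pair `(s1, s2)` with `s1 < s2`. The sketch below reproduces only that indexing and the span-containment test. It is a simplification: the hypothetical helpers `enumerate_spans` and `label_spans` skip the per-entity check for nontrivial annotated states, and they size the vector by the story's own length rather than by `max_story_length`.

```python
def enumerate_spans(num_sentences):
    """List (s1, s2) sentence pairs with s1 < s2, in the order the nested loops visit them."""
    return [(s1, s2) for s2 in range(1, num_sentences) for s1 in range(s2)]

def label_spans(num_sentences, confl_pairs):
    """Mark every span that fully contains at least one conflicting sentence pair."""
    labels = []
    for s1, s2 in enumerate_spans(num_sentences):
        contains_conflict = any(s1 <= p1 and s2 >= p2 for p1, p2 in confl_pairs)
        labels.append(1 if contains_conflict else 0)
    return labels

spans = enumerate_spans(5)
labels = label_spans(5, [(3, 4)])  # Conflict pair taken from the example story printed earlier
print(len(spans))  # 5 * 4 // 2 = 10 spans
print(labels)
```

With five sentences and the conflict pair `(3, 4)`, only the four spans ending at sentence 4 are marked.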



### Train Models

#### Configure Hyperparameters
We will perform grid search over (batch size, learning rate). Configure the training sub-task, search space and set the maximum number of training epochs here. Currently configured for re-training the best RoBERTa-based model instance. Read code comments for more information.

**Additional configuration options:**
* Change the `generate_learning_curve` variable to `True` to generate data for training curves in the style presented in the paper.
* You may ablate the input to the Conflict Detector based on a few pre-defined ablation modes. To do so, change the `ablation` variable based on the comments in the code.

In [21]:
from www.dataset.ann import att_to_idx, att_to_num_classes, att_types

subtask = 'cloze'
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
epochs = config_epochs
eval_batch_size = 16
generate_learning_curve = True # Generate data for training curve figure in TRIP paper

num_state_labels = {}
for att in att_to_idx:
  if att_types[att] == 'default':
    num_state_labels[att_to_idx[att]] = 3
  else:
    num_state_labels[att_to_idx[att]] = att_to_num_classes[att] # Location attributes fall into this branch since they don't have well-defined pre- and post-conditions yet

# Ablation options:
# - attributes: skip attribute prediction phase
# - embeddings: DON'T input contextual embeddings to conflict detector
# - states: DON'T input states to conflict detector
# - states-labels: in states input to conflict detector, include predicted labels
# - states-logits: in states input to conflict detector, include state logits (preferred)
# - states-teacher-forcing: train conflict detector on ground truth state labels (not predictions)
# - states-attention: re-weight input to conflict detector with weights conditioned on states representation
ablation = ['attributes', 'states-logits'] # This is the default mode presented in the paper
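The batch size and epoch count configured here also determine the learning-rate schedule, because the training cell passes `len(train_dataloader) * epochs` as the total step count to `get_linear_schedule_with_warmup`. The sketch below is a plain-Python approximation of the shape of that schedule, not a copy of the library code. The helper `linear_lr_multiplier` is hypothetical, and the figure of 799 training examples comes from the dataset size printed in this notebook.

```python
def linear_lr_multiplier(step, total_steps, warmup_steps=0):
    """Scale factor for the base learning rate: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

base_lr = 1e-5
total_steps = 799 * 10  # 799 training examples at batch size 1, for 10 epochs

for step in (0, total_steps // 2, total_steps):
    print('step %d: lr = %.2e' % (step, base_lr * linear_lr_multiplier(step, total_steps)))
```

With no warmup steps, the learning rate starts at its configured value and reaches zero at the final step, so changing `config_epochs` changes how quickly it decays.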

#### Perform Grid Search

Perform hyperparameter tuning to find the best story classification model.
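Stripped of model details, the search amounts to trying every (batch size, learning rate) pair and remembering the best score. The sketch below shows that skeleton. The function `grid_search` and the score table are hypothetical stand-ins for a full training run, and the scores are made up for illustration. The actual cell also tracks the best checkpoint per epoch, separately for verifiability and accuracy.

```python
from itertools import product

def grid_search(batch_sizes, learning_rates, train_and_score):
    """Try every (batch size, learning rate) pair and return the best one with its score."""
    best_combo, best_score = None, float('-inf')
    for bs, lr in product(batch_sizes, learning_rates):
        score = train_and_score(bs, lr)
        if score > best_score:
            best_combo, best_score = (bs, lr), score
    return best_combo, best_score

# Made-up validation scores standing in for real training runs
fake_scores = {(1, 1e-5): 0.10, (1, 5e-6): 0.08, (2, 1e-5): 0.07, (2, 5e-6): 0.05}
best_combo, best_score = grid_search([1, 2], [1e-5, 5e-6],
                                     lambda bs, lr: fake_scores[(bs, lr)])
print('Best combination: bs=%s, lr=%s' % best_combo)
```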


In [22]:
print(len(tiered_dataset['train']))

799


In [23]:
batch_sizes

[1]

In [24]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch_tiered
from www.model.eval import evaluate_tiered, save_results, save_preds, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict, get_model_dir
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes
import shutil
import pandas as pd

seed_val = 22 # Fix the random seed for reproducibility
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# We'll keep the validation data here with a constant eval batch size
dev_sampler = SequentialSampler(tiered_tensor_dataset['dev'])
dev_dataloader = DataLoader(tiered_tensor_dataset['dev'], sampler=dev_sampler, batch_size=eval_batch_size)
dev_dataset_name = subtask + '_%s_dev'
dev_ids = [ex['example_id'] for ex in tiered_dataset['dev']]

all_losses = []
param_combos = []
combo_names = []
all_val_objs = []
output_dirs = []
best_obj = 0.0
best_model = '<none>'
best_dir = ''
best_obj2 = 0.0
best_model2 = '<none>'
best_dir2 = ''

print('Beginning grid search for the %s sub-task over %s parameter combination(s)!' % (subtask, str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    loss_values = []
    obj_values = []

    # Set up training dataset with new batch size
    train_sampler = RandomSampler(tiered_tensor_dataset['train'])
    train_dataloader = DataLoader(tiered_tensor_dataset['train'], sampler=train_sampler, batch_size=bs)

    # Set up model
    config = config_class.from_pretrained(model_name,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    emb = emb_class.from_pretrained(model_name,
                                          config=config,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    if torch.cuda.is_available():
      emb.cuda()
    device = emb.device
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    model = TieredModelPipeline(emb, max_story_length, len(att_to_num_classes), num_state_labels,
                                config_class, model_name, device, 
                                ablation=ablation, loss_weights=loss_weights).to(device)

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

    train_lc_data = []
    val_lc_data = []
    for epoch in range(epochs):
      # Train the model for one epoch
      print('[%s] Beginning epoch...' % str(epoch))

      epoch_loss, _ = train_epoch_tiered(model, optimizer, train_dataloader, device, seg_mode=False, 
                                         build_learning_curves=generate_learning_curve, val_dataloader=dev_dataloader, 
                                         train_lc_data=train_lc_data, val_lc_data=val_lc_data)
      
      # Save loss
      loss_values.append(epoch_loss)

      # Validate on dev set
      validation_results = evaluate_tiered(model, dev_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
      metr_attr, all_pred_atts, all_atts, \
      metr_prec, all_pred_prec, all_prec, \
      metr_eff, all_pred_eff, all_eff, \
      metr_conflicts, all_pred_conflicts, all_conflicts, \
      metr_stories, all_pred_stories, all_stories, explanations = validation_results[:16]
      explanations = add_entity_attribute_labels(explanations, tiered_dataset['dev'], list(att_to_num_classes.keys()))

      print('[%s] Validation results:' % str(epoch))
      print('[%s] Preconditions:' % str(epoch))
      print_dict(metr_prec)
      print('[%s] Effects:' % str(epoch))
      print_dict(metr_eff)
      print('[%s] Conflicts:' % str(epoch))
      print_dict(metr_conflicts)
      print('[%s] Stories:' % str(epoch))
      print_dict(metr_stories)

      # Track verifiability and accuracy - we want to maximize verifiability of tiered predictions
      ver = metr_stories['verifiability']
      acc = metr_stories['accuracy']
      obj_values.append(ver)
      
      # Save model checkpoint
      print('[%s] Saving model checkpoint...' % str(epoch))
      model_param_str = get_model_dir(model_name.replace('/', '-'), subtask, bs, lr, epoch) + '_' +  '-'.join([str(lw) for lw in loss_weights]) +  '_tiered_pipeline_lc'
      if train_spans:
        model_param_str += 'spans'
      if len(model.ablation) > 0:
        model_param_str += '_ablate_'
        model_param_str += '_'.join(model.ablation)
      output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
      output_dirs.append(output_dir)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)

      save_results(metr_attr, output_dir, dev_dataset_name % 'attributes')
      save_results(metr_prec, output_dir, dev_dataset_name % 'preconditions')
      save_results(metr_eff, output_dir, dev_dataset_name % 'effects')
      save_results(metr_conflicts, output_dir, dev_dataset_name % 'conflicts')
      save_results(metr_stories, output_dir, dev_dataset_name % 'stories')
      save_results(explanations, output_dir, dev_dataset_name % 'explanations')

      # Just save story preds
      save_preds(dev_ids, all_stories, all_pred_stories, output_dir, dev_dataset_name % 'stories')

      emb = emb.module if hasattr(emb, 'module') else emb
      emb.save_pretrained(output_dir)
      torch.save(model, os.path.join(output_dir, 'classifiers.pth'))
      tokenizer.save_vocabulary(output_dir)

      if ver > best_obj:
        best_obj = ver
        best_model = model_param_str
        best_dir = output_dir
      if acc > best_obj2:
        best_obj2 = acc
        best_model2 = model_param_str
        best_dir2 = output_dir        

      for od in output_dirs:
        if od != best_dir and od != best_dir2 and os.path.exists(od):
          shutil.rmtree(od)

      print('[%s] Finished epoch.' % str(epoch))

    all_losses.append(loss_values)
    all_val_objs.append(obj_values)
    param_combos.append((bs, lr))
    combo_names.append('bs=%s, lr=%s' % (str(bs), str(lr)))

print('Finished grid search! :)')
print('Best validation *verifiability* %s from model %s.' % (str(best_obj), best_model))
print('Best validation *accuracy* %s from model %s.' % (str(best_obj2), best_model2))

if generate_learning_curve:
  print('Saving learning curve data...')
  train_lc_data = [subrecord for record in train_lc_data for subrecord in record] # flatten
  val_lc_data = [subrecord for record in val_lc_data for subrecord in record] # flatten

  # Save next to the best-verifiability checkpoint, falling back to the best-accuracy one
  lc_dir = best_dir if best_dir != '' else best_dir2
  train_lc_data = pd.DataFrame(train_lc_data)
  print(os.path.join(lc_dir, 'learning_curve_data_train.csv'))
  train_lc_data.to_csv(os.path.join(lc_dir, 'learning_curve_data_train.csv'), index=False)
  val_lc_data = pd.DataFrame(val_lc_data)
  val_lc_data.to_csv(os.path.join(lc_dir, 'learning_curve_data_val.csv'), index=False)
  print('Learning curve data saved. %s rows saved for training, %s rows saved for validation.' % (str(len(train_lc_data.index)), str(len(val_lc_data.index))))

Beginning grid search for the cloze sub-task over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05


[                                                                        ]   0%

[0] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[0] Validation results:
[0] Preconditions:
{
  accuracy: 
    0.9931437444543034,
  f1: 
    0.22141924956196202,
  accuracy_0: 
    0.9951020408163266,
  f1_0: 
    0.33251500272776874,
  accuracy_1: 
    0.9996273291925466,
  f1_1: 
    0.6644556163617592,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9749068322981367,
  f1_5: 
    0.13124384434240646,
  accuracy_6: 
    0.9841348713398402,
  f1_6: 
    0.6284467292912538,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9921472937000887,
  f1_8: 
    0.3320193899442502,
  accuracy_9: 
    0.9857675244010647,
  f1_9: 
    0.6321814703817625,
  accuracy_10: 
    0.9941082519964508,
  f1_10: 
    0.33234847406141727,
  accuracy_11: 
    0.9969476486

[                                                                        ]   0%

[0] Finished epoch.
[1] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[1] Validation results:
[1] Preconditions:
{
  accuracy: 
    0.9939232475598935,
  f1: 
    0.269721190873293,
  accuracy_0: 
    0.9951020408163266,
  f1_0: 
    0.33251500272776874,
  accuracy_1: 
    0.9996273291925466,
  f1_1: 
    0.6644562779436248,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.980452528837622,
  f1_5: 
    0.22924285727881497,
  accuracy_6: 
    0.9867613132209405,
  f1_6: 
    0.6339576851037935,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9944454303460515,
  f1_8: 
    0.49386101646031805,
  accuracy_9: 
    0.9886069210292813,
  f1_9: 
    0.638113791762034,
  accuracy_10: 
    0.9941082519964508,
  f1_10: 
    0.33234847406141727,
  accuracy_11: 
    0.9969476486246

[                                                                        ]   0%

[1] Finished epoch.
[2] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[2] Validation results:
[2] Preconditions:
{
  accuracy: 
    0.9944236912156167,
  f1: 
    0.30264904960764183,
  accuracy_0: 
    0.9951020408163266,
  f1_0: 
    0.33251500272776874,
  accuracy_1: 
    0.999600709849157,
  f1_1: 
    0.6642983121964052,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.981561668145519,
  f1_5: 
    0.2349934915106866,
  accuracy_6: 
    0.9857763975155279,
  f1_6: 
    0.6314756389982349,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9959893522626442,
  f1_8: 
    0.5620784551812226,
  accuracy_9: 
    0.9878527062999113,
  f1_9: 
    0.6359564928247222,
  accuracy_10: 
    0.9960337178349601,
  f1_10: 
    0.524785832706712,
  accuracy_11: 
    0.996947648624667

[                                                                        ]   0%

[2] Finished epoch.
[3] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[3] Validation results:
[3] Preconditions:
{
  accuracy: 
    0.9946211180124224,
  f1: 
    0.35726352496637964,
  accuracy_0: 
    0.9951020408163266,
  f1_0: 
    0.33251500272776874,
  accuracy_1: 
    0.9996805678793257,
  f1_1: 
    0.6647706475353331,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9819875776397515,
  f1_5: 
    0.2758555677546053,
  accuracy_6: 
    0.9862555456965395,
  f1_6: 
    0.6336836367646085,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9969032830523514,
  f1_8: 
    0.5953346979884918,
  accuracy_9: 
    0.9878349600709849,
  f1_9: 
    0.6368114400919253,
  accuracy_10: 
    0.9962821650399291,
  f1_10: 
    0.5472702227393053,
  accuracy_11: 
    0.996947648624

[                                                                        ]   0%

[3] Finished epoch.
[4] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[4] Validation results:
[4] Preconditions:
{
  accuracy: 
    0.9949272404614019,
  f1: 
    0.3591101140708064,
  accuracy_0: 
    0.996681455190772,
  f1_0: 
    0.49623559795341304,
  accuracy_1: 
    0.9996184560780834,
  f1_1: 
    0.664403843120053,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9827506654835847,
  f1_5: 
    0.2791196840685756,
  accuracy_6: 
    0.9870807453416149,
  f1_6: 
    0.6353171532800164,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9966903283052352,
  f1_8: 
    0.5892119148784073,
  accuracy_9: 
    0.9888464951197871,
  f1_9: 
    0.6392451032500026,
  accuracy_10: 
    0.9961756876663709,
  f1_10: 
    0.5424600964380273,
  accuracy_11: 
    0.996947648624667

[                                                                        ]   0%

[4] Finished epoch.
[5] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[5] Validation results:
[5] Preconditions:
{
  accuracy: 
    0.9949170363797692,
  f1: 
    0.466601179096695,
  accuracy_0: 
    0.9975510204081632,
  f1_0: 
    0.5588133876798828,
  accuracy_1: 
    0.9996184560780834,
  f1_1: 
    0.6644031659960009,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9831854480922804,
  f1_5: 
    0.38423820065083475,
  accuracy_6: 
    0.9862910381543922,
  f1_6: 
    0.6333914990000288,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9965306122448979,
  f1_8: 
    0.584674989850725,
  accuracy_9: 
    0.9882342502218279,
  f1_9: 
    0.637871561737185,
  accuracy_10: 
    0.996335403726708,
  f1_10: 
    0.5555949992731832,
  accuracy_11: 
    0.9969476486246672,

[                                                                        ]   0%

[5] Finished epoch.
[6] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[6] Validation results:
[6] Preconditions:
{
  accuracy: 
    0.9950483584738243,
  f1: 
    0.4506916185008121,
  accuracy_0: 
    0.9973114463176574,
  f1_0: 
    0.5431911914005032,
  accuracy_1: 
    0.999600709849157,
  f1_1: 
    0.6642990203760476,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9827861579414374,
  f1_5: 
    0.36397872497028927,
  accuracy_6: 
    0.9872315882874889,
  f1_6: 
    0.6355161812723328,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9966370896184561,
  f1_8: 
    0.5874438529868354,
  accuracy_9: 
    0.9892280390417036,
  f1_9: 
    0.6399503623570415,
  accuracy_10: 
    0.9961756876663709,
  f1_10: 
    0.5402254581406759,
  accuracy_11: 
    0.99750665483584

[                                                                        ]   0%

[6] Finished epoch.
[7] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[7] Validation results:
[7] Preconditions:
{
  accuracy: 
    0.9949188110026619,
  f1: 
    0.5078437440551713,
  accuracy_0: 
    0.9978260869565218,
  f1_0: 
    0.5749156390418649,
  accuracy_1: 
    0.9996184560780834,
  f1_1: 
    0.664403843120053,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9822537710736469,
  f1_5: 
    0.41791718678432094,
  accuracy_6: 
    0.9865572315882875,
  f1_6: 
    0.6337198671566654,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9963265306122449,
  f1_8: 
    0.5794289993850253,
  accuracy_9: 
    0.9885093167701864,
  f1_9: 
    0.638036186236956,
  accuracy_10: 
    0.9963708961845608,
  f1_10: 
    0.5581356425474282,
  accuracy_11: 
    0.997701863354037

[                                                                        ]   0%

[7] Finished epoch.
[8] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[8] Validation results:
[8] Preconditions:
{
  accuracy: 
    0.9951712511091393,
  f1: 
    0.5275966033098849,
  accuracy_0: 
    0.997941437444543,
  f1_0: 
    0.5826323524948002,
  accuracy_1: 
    0.9995563442768411,
  f1_1: 
    0.6640372367812474,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9833362910381543,
  f1_5: 
    0.43618472215106124,
  accuracy_6: 
    0.9869476486246672,
  f1_6: 
    0.6347100409378332,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9966992014196984,
  f1_8: 
    0.5908403040918948,
  accuracy_9: 
    0.9887843833185448,
  f1_9: 
    0.638659567745795,
  accuracy_10: 
    0.9963708961845608,
  f1_10: 
    0.5583823366925348,
  accuracy_11: 
    0.997710736468500

[                                                                        ]   0%

[8] Finished epoch.
[9] Beginning epoch...


[########################################################################] 100%
[                                                                        ]   0%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:42s.
[9] Validation results:
[9] Preconditions:
{
  accuracy: 
    0.9951193433895297,
  f1: 
    0.5494736911309374,
  accuracy_0: 
    0.9966548358473825,
  f1_0: 
    0.49481103895686424,
  accuracy_1: 
    0.9996362023070098,
  f1_1: 
    0.6645087284112018,
  accuracy_2: 
    0.999148181011535,
  f1_2: 
    0.3331913030098593,
  accuracy_3: 
    0.9989352262644188,
  f1_3: 
    0.33315577651515155,
  accuracy_4: 
    0.9997692990239574,
  f1_4: 
    0.3332948787349029,
  accuracy_5: 
    0.9838775510204082,
  f1_5: 
    0.48073235464755076,
  accuracy_6: 
    0.9871073646850045,
  f1_6: 
    0.6353651975790507,
  accuracy_7: 
    0.9981277728482697,
  f1_7: 
    0.3330210030981383,
  accuracy_8: 
    0.9967879325643301,
  f1_8: 
    0.593063005137538,
  accuracy_9: 
    0.9889174800354925,
  f1_9: 
    0.6393310459064251,
  accuracy_10: 
    0.9961579414374445,
  f1_10: 
    0.5555334171375877,
  accuracy_11: 
    0.9975865128660

Delete all non-best model checkpoints:

In [31]:
import shutil

# Delete non-best model checkpoints
for od in output_dirs:
  if od != best_dir and od != best_dir2 and os.path.exists(od):
    shutil.rmtree(od)
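
The pruning step can be tried safely on throwaway folders before pointing it at real checkpoints. The following is a minimal sketch, not part of the notebook's `www` package: `prune_checkpoints` and the `epoch_N` folder names are hypothetical, and the "best" folders are chosen arbitrarily for illustration.

```python
import os
import shutil
import tempfile

def prune_checkpoints(output_dirs, keep):
    """Delete every directory in output_dirs that is not in keep; return the removed paths."""
    removed = []
    for od in output_dirs:
        if od not in keep and os.path.exists(od):
            shutil.rmtree(od)
            removed.append(od)
    return removed

# Hypothetical demo: three fake per-epoch checkpoint folders in a temp dir
root = tempfile.mkdtemp()
demo_dirs = [os.path.join(root, 'epoch_%d' % e) for e in range(3)]
for d in demo_dirs:
    os.makedirs(d)

# Pretend epoch 1 had the best verifiability and epoch 2 the best accuracy
removed = prune_checkpoints(demo_dirs, keep={demo_dirs[1], demo_dirs[2]})
remaining = sorted(os.listdir(root))
print(remaining)
shutil.rmtree(root)  # clean up the demo folder
```

Keeping two "best" folders mirrors the training loop, which tracks one checkpoint for verifiability and one for accuracy; when both point to the same epoch, only a single folder survives.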

In [25]:
print(best_dir)
print(best_dir2)

./saved_models/bert-large-uncased_cloze_1_1e-05_7_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits
./saved_models/bert-large-uncased_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits


### Test Models

Evaluate accuracy, consistency, and verifiability on the test set.

#### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset. Paths to the pre-trained models presented in the paper are already provided (download links are in the GitHub repo).

In [21]:
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes, att_to_idx, att_types

#probe_model = eval_model_dir
#probe_model = os.path.join(DRIVE_PATH, 'saved_models', probe_model)
probe_model=best_dir
ablation = ['attributes', 'states-logits']

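# Infer the sub-task (and matching dataset) from the checkpoint path
if 'cloze' in probe_model:
  subtask = 'cloze'
  subtask_dataset = cloze_dataset_2s
elif 'order' in probe_model:
  subtask = 'order'
  subtask_dataset = order_dataset_2s
else:
  raise ValueError('Could not infer sub-task from model path: %s' % probe_model)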

# Load the model
# Fall back to CPU when no GPU is available
map_location = None if torch.cuda.is_available() else torch.device('cpu')
model = torch.load(os.path.join(probe_model, 'classifiers.pth'), map_location=map_location)
if torch.cuda.is_available():
  model.cuda()
device = model.embedding.device

for layer in model.precondition_classifiers:
  layer.eval()
for layer in model.effect_classifiers:
  layer.eval()

In [25]:
probe_model

'./saved_models/roberta-large_cloze_1_1e-05_6_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits'
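
The checkpoint folder name encodes the training configuration, so it can be decoded instead of relying on substring checks. The sketch below assumes the field order seen in the printed paths (model name, sub-task, batch size, learning rate, epoch, loss weights, then an optional `_ablate_` suffix) and that the model name contains no underscores. `parse_model_dir` is a hypothetical helper, not a function from `www`.

```python
import os

def parse_model_dir(path):
    """Split a checkpoint folder name into its hyperparameter fields (assumed layout)."""
    name = os.path.basename(os.path.normpath(path))
    # Everything after '_ablate_' lists the ablated components, joined by '_'
    head, _, ablation_str = name.partition('_ablate_')
    model_name, subtask, bs, lr, epoch, weights = head.split('_')[:6]
    return {
        'model_name': model_name,
        'subtask': subtask,
        'batch_size': int(bs),
        'lr': float(lr),
        'epoch': int(epoch),
        'loss_weights': [float(w) for w in weights.split('-')],
        'ablation': ablation_str.split('_') if ablation_str else [],
    }

info = parse_model_dir('./saved_models/roberta-large_cloze_1_1e-05_6_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits')
print(info['subtask'], info['epoch'], info['ablation'])
```

A parser like this fails loudly on an unexpected folder name, whereas `'cloze' in probe_model` would also match a parent directory that happens to contain the word.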

#### Test the Model

Run inference on the test set of TRIP. To run inference on other partitions, edit the top-level `for` loop.

In [22]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate_tiered, save_results, save_preds, list_comparison, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

print('Testing model: %s.' % probe_model)
train_spans = False
if train_spans:
  tiered_dataset = get_story_spans_2s(tiered_dataset, train_only=True)
  tiered_dataset['train'] = [ex for ex in tiered_dataset['train'] if ex['label'] != -1] # For now, ignore examples where both stories are plausible :(

# May alter this depending on which partition(s) you want to run inference on
for p in tiered_dataset:
  if p != 'test':
    continue

  p_dataset = tiered_dataset[p]
  p_tensor_dataset = tiered_tensor_dataset[p]
  p_sampler = SequentialSampler(p_tensor_dataset)
  p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=16)
  dev_dataset_name = subtask + '_%s_' + p
  p_ids = [ex['example_id'] for ex in tiered_dataset[p]]

  # Get preds and metrics on this partition
  metr_attr, all_pred_atts, all_atts, \
  metr_prec, all_pred_prec, all_prec, \
  metr_eff, all_pred_eff, all_eff, \
  metr_conflicts, all_pred_conflicts, all_conflicts, \
  metr_stories, all_pred_stories, all_stories, explanations = evaluate_tiered(model, p_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
  explanations = add_entity_attribute_labels(explanations, tiered_dataset[p], list(att_to_num_classes.keys()))

  save_results(metr_attr, probe_model, dev_dataset_name % 'attributes')
  save_results(metr_prec, probe_model, dev_dataset_name % 'preconditions')
  save_results(metr_eff, probe_model, dev_dataset_name % 'effects')
  save_results(metr_conflicts, probe_model, dev_dataset_name % 'conflicts')
  save_results(metr_stories, probe_model, dev_dataset_name % 'stories')
  save_results(explanations, probe_model, dev_dataset_name % 'explanations')

  print('\nPARTITION: %s' % p)
  print('Stories:')
  print_dict(metr_stories)
  print('Conflicts:')
  print_dict(metr_conflicts)
  print('Preconditions:')
  print_dict(metr_prec)
  print('Effects:')
  print_dict(metr_eff)

[                                                                        ]   0%

Testing model: ./saved_models/roberta-large_cloze_1_1e-05_6_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits.
	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:01:59s.

PARTITION: test
Stories:
{
  accuracy: 
    0.7407407407407407,
  f1: 
    0.7406059917003013,
  verifiability: 
    0.09971509971509972,
}


Conflicts:
{
  accuracy: 
    0.9795519980705166,
  f1: 
    0.6476142695064642,
}


Preconditions:
{
  accuracy: 
    0.9960253395438581,
  f1: 
    0.41978344061428163,
  accuracy_0: 
    0.9979876090987202,
  f1_0: 
    0.540558272677171,
  accuracy_1: 
    0.99926136963174,
  f1_1: 
    0.6619444103712383,
  accuracy_2: 
    0.9991633880522769,
  f1_2: 
    0.33319383965767496,
  accuracy_3: 
    0.9993442771220549,
  f1_3: 
    0.3332240103442166,
  accuracy_4: 
    0.9998794072868147,
  f1_4: 
    0.33331323333584584,
  accuracy_5: 
    0.9889280815206741,
  f1_5: 
    0.349403873331483,
  accuracy_6: 
    0.9870136721988574,
  f1_6: 
    0.6203380977446721,
  accuracy_7: 
    0.9984624429068873,
  f1_7: 
    0.33307687665942554,
  accuracy_8: 
    0.9973620343990715,
  f1_8: 
    0.

#### Add Consistency Metric to Model Results
The intermediate consistency metric isn't included in the originally calculated metrics. This block adds the consistency metric to a pre-existing model directory based on the tiered predictions, generating a new `results_cloze_stories_final_[partition].json` file that includes it.



In [23]:
import json
import os
best_dir=probe_model
model_directories = [best_dir]

partitions = ['dev', 'test']
expl_fname = 'results_cloze_explanations_%s.json'
endtask_fname = 'results_cloze_stories_%s.json'
endtask_fname_new = 'results_cloze_stories_final_%s.json'
for md in model_directories:
  for p in partitions:
    with open(os.path.join(DRIVE_PATH, md, expl_fname % p), 'r') as f:
      explanations = json.load(f)
    with open(os.path.join(DRIVE_PATH, md, endtask_fname % p), 'r') as f:
      endtask_results = json.load(f)
    consistent_preds = 0
    verifiable_preds = 0
    total = 0
    for expl in explanations:
      if expl['valid_explanation']:
        verifiable_preds += 1
      # Consistent: the story prediction is correct and so is the predicted conflicting sentence pair.
      # Always set the flag so every saved record has a 'consistent' field.
      expl['consistent'] = (expl['story_pred'] == expl['story_label']
                            and len(expl['conflict_pred']) == len(expl['conflict_label'])
                            and expl['conflict_pred'][0] == expl['conflict_label'][0]
                            and expl['conflict_pred'][1] == expl['conflict_label'][1])
      if expl['consistent']:
        consistent_preds += 1
      total += 1

    endtask_results['consistency'] = float(consistent_preds) / total
    print('Found %s consistent preds in %s (versus %s verifiable)' % (str(consistent_preds), p, str(verifiable_preds)))
    with open(os.path.join(DRIVE_PATH, md, (expl_fname % p).replace('explanations', 'explanations_consistency')), 'w') as f:
      json.dump(explanations, f)
    with open(os.path.join(DRIVE_PATH, md, endtask_fname_new % p), 'w') as f:
      json.dump(endtask_results, f)

Found 84 consistent preds in dev (versus 34 verifiable)
Found 73 consistent preds in test (versus 35 verifiable)
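
To make the relationship between the three story-level metrics concrete, here is a minimal sketch that computes them from explanation records. It assumes the same keys used in the cell above (`story_pred`, `story_label`, `conflict_pred`, `conflict_label`, `valid_explanation`). `story_metrics` is a hypothetical helper, the toy records are made up, and it compares the full conflict lists where the cell above compares only their first two entries.

```python
def story_metrics(explanations):
    """Return accuracy, consistency, and verifiability over a list of explanation dicts."""
    total = len(explanations)
    correct = consistent = verifiable = 0
    for expl in explanations:
        story_ok = expl['story_pred'] == expl['story_label']
        conflict_ok = list(expl['conflict_pred']) == list(expl['conflict_label'])
        correct += story_ok
        consistent += story_ok and conflict_ok
        verifiable += bool(expl['valid_explanation'])
    return {
        'accuracy': correct / total,
        'consistency': consistent / total,
        'verifiability': verifiable / total,
    }

# Toy records (made up for illustration, not TRIP data)
toy = [
    {'story_pred': 0, 'story_label': 0, 'conflict_pred': [1, 3], 'conflict_label': [1, 3], 'valid_explanation': True},
    {'story_pred': 0, 'story_label': 0, 'conflict_pred': [1, 3], 'conflict_label': [1, 3], 'valid_explanation': False},
    {'story_pred': 1, 'story_label': 1, 'conflict_pred': [0, 2], 'conflict_label': [1, 2], 'valid_explanation': False},
    {'story_pred': 1, 'story_label': 0, 'conflict_pred': [1, 3], 'conflict_label': [1, 3], 'valid_explanation': False},
]
scores = story_metrics(toy)
print(scores)
```

In this toy set three stories are classified correctly, two of those also identify the right conflicting sentences, and one is flagged as a valid explanation. The ordering matches the counts printed above, where fewer predictions are verifiable than consistent.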



# Conversational Entailment (CE) Results

Code for the coherence experiments on CE.

In [None]:
if task_name != 'ce':
  raise ValueError('Please configure task_name in first cell to "ce" to run CE results!')

## Load Conversational Entailment Dataset

In [None]:
import xml.etree.ElementTree as ET
import pickle
cache_train = os.path.join(DRIVE_PATH, 'all_data/ConvEnt/ConvEnt_train_resplit.json')
cache_dev = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_dev_resplit.json')
cache_test = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_test_resplit.json')
ConvEnt_train = json.load(open(cache_train))
ConvEnt_dev = json.load(open(cache_dev))
ConvEnt_test = json.load(open(cache_test))

# Combine train and dev and do cross-validation
cache_folds = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_folds.pkl') # Folds used for results presented in paper
ConvEnt_train = ConvEnt_train + ConvEnt_dev
train_sources = list(set([ex['dialog_source'] for ex in ConvEnt_train]))
print("Reserved %s dialog sources for training and validation." % len(train_sources))

no_folds = 8
if not os.path.exists(cache_folds):
  folds = []
  for k in range(no_folds):
    folds.append(np.random.choice(train_sources, size=5, replace=False))
    train_sources = [s for s in train_sources if s not in folds[-1]]
  assert len(train_sources) == 0
  print(folds)
  pickle.dump(folds, open(cache_folds, 'wb'))
else:
  folds = pickle.load(open(cache_folds, 'rb'))

Reserved 40 dialog sources for training and validation.
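
The fold construction above groups examples by `dialog_source`, so all examples from one dialog land in the same fold. A minimal, seeded sketch of that idea follows. It uses the standard-library `random` module rather than `np.random.choice`; `make_source_folds`, `split_by_fold`, and the toy source IDs are hypothetical and do not reproduce the folds used in the paper.

```python
import random

def make_source_folds(sources, n_folds, seed=0):
    """Randomly partition unique sources into n_folds disjoint, equally sized folds."""
    unique = sorted(set(sources))
    if len(unique) % n_folds != 0:
        raise ValueError('Number of sources must be divisible by n_folds.')
    rng = random.Random(seed)
    rng.shuffle(unique)
    size = len(unique) // n_folds
    return [unique[i * size:(i + 1) * size] for i in range(n_folds)]

def split_by_fold(examples, fold):
    """Hold out every example whose dialog source belongs to the given fold."""
    held_out = set(fold)
    train = [ex for ex in examples if ex['dialog_source'] not in held_out]
    dev = [ex for ex in examples if ex['dialog_source'] in held_out]
    return train, dev

# Hypothetical dialog source IDs, two toy examples per source
toy_sources = ['dialog_%02d' % i for i in range(40)]
toy_examples = [{'dialog_source': s, 'label': i % 2} for s in toy_sources for i in range(2)]

toy_folds = make_source_folds(toy_sources, n_folds=8, seed=22)
flat = [s for fold in toy_folds for s in fold]
train, dev = split_by_fold(toy_examples, toy_folds[0])
print(len(toy_folds), len(train), len(dev))
```

Splitting by source rather than by example keeps utterances from the same dialog out of both the training and validation sides of a fold.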


In [None]:
print('train examples:', len(ConvEnt_train))
print('dev examples:', len(ConvEnt_dev))
print('test examples:', len(ConvEnt_test))

train examples: 703
dev examples: 110
test examples: 172


## Featurize Conversational Entailment

In [None]:
from www.dataset.featurize import add_bert_features_ConvEnt, get_tensor_dataset
import pickle
seq_length = 128

ConvEnt_train = add_bert_features_ConvEnt(ConvEnt_train, tokenizer, seq_length, add_segment_ids=True)
ConvEnt_dev = add_bert_features_ConvEnt(ConvEnt_dev, tokenizer, seq_length, add_segment_ids=True)
ConvEnt_test = add_bert_features_ConvEnt(ConvEnt_test, tokenizer, seq_length, add_segment_ids=True)

ConvEnt_train_folds = [[] for _ in range(no_folds)]
ConvEnt_dev_folds = [[] for _ in range(no_folds)]
for k in range(no_folds):
  ConvEnt_train_folds[k] = [ex for ex in ConvEnt_train if ex['dialog_source'] not in folds[k]]
  ConvEnt_dev_folds[k] = [ex for ex in ConvEnt_train if ex['dialog_source'] in folds[k]]

  if debug:
    ConvEnt_train_folds[k] = ConvEnt_train_folds[k][:10]
    ConvEnt_dev_folds[k] = ConvEnt_dev_folds[k][:10]

if debug:
  ConvEnt_train = ConvEnt_train[:10]
  ConvEnt_dev = ConvEnt_dev[:10]
  ConvEnt_test = ConvEnt_test[:10]

ConvEnt_train_tensor = get_tensor_dataset(ConvEnt_train, label_key='label', add_segment_ids=True)
ConvEnt_test_tensor = get_tensor_dataset(ConvEnt_test, label_key='label', add_segment_ids=True)

# Training sets for each validation fold
ConvEnt_train_folds_tensor = [get_tensor_dataset(ConvEnt_train_folds[k], label_key='label', add_segment_ids=True) for k in range(no_folds)]
ConvEnt_dev_folds_tensor = [get_tensor_dataset(ConvEnt_dev_folds[k], label_key='label', add_segment_ids=True) for k in range(no_folds)]

In [None]:
print('train examples:', len(ConvEnt_train))
print('dev examples:', len(ConvEnt_dev))
print('test examples:', len(ConvEnt_test))

train examples: 10
dev examples: 10
test examples: 10


## Train Models on Conversational Entailment

### Train Models

#### Configure Hyperparameters

In [None]:
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
epochs = config_epochs
eval_batch_size = 128

#### Grid Search and Cross-Validation
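
The cell in this section sums validation accuracy per `(bs, lr, epoch)` key across folds, divides by the number of folds, and lists the top combinations. That bookkeeping can be sketched in isolation as follows; `average_over_folds` is a hypothetical helper and the accuracy values are made up.

```python
from collections import Counter

def average_over_folds(fold_accs, n_folds):
    """fold_accs: iterable of (bs, lr, epoch, accuracy) records, one per fold and epoch."""
    totals = Counter()
    for bs, lr, epoch, acc in fold_accs:
        totals[(bs, lr, epoch)] += acc  # Counter treats unseen keys as 0
    return Counter({key: total / n_folds for key, total in totals.items()})

# Made-up accuracies for 2 folds x 2 epochs
records = [
    (1, 1e-5, 0, 0.50), (1, 1e-5, 1, 0.70),  # fold 1
    (1, 1e-5, 0, 0.60), (1, 1e-5, 1, 0.80),  # fold 2
]
mean_accs = average_over_folds(records, n_folds=2)
best_key, best_acc = mean_accs.most_common(1)[0]
print(best_key, round(best_acc, 2))
```

Because the epoch index is part of the key, the ranking selects the number of training epochs together with the batch size and learning rate.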

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch
from www.model.eval import evaluate, save_results, save_preds
from sklearn.metrics import accuracy_score
from www.utils import print_dict, get_model_dir
from collections import Counter

seed_val = 22 # Set random seed for reproducibility
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

assert len(batch_sizes) == 1
train_fold_sampler = [RandomSampler(f) for f in ConvEnt_train_folds_tensor]
train_fold_dataloader = [DataLoader(f, sampler=train_fold_sampler[i], batch_size=batch_sizes[0]) for i, f in enumerate(ConvEnt_train_folds_tensor)]

dev_fold_sampler = [SequentialSampler(f) for f in ConvEnt_dev_folds_tensor]
dev_fold_dataloader = [DataLoader(f, sampler=dev_fold_sampler[i], batch_size=eval_batch_size) for i, f in enumerate(ConvEnt_dev_folds_tensor)]

all_val_accs = Counter()
print('Beginning grid search for ConvEnt over %s parameter combination(s)!' % (str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    for k in range(no_folds):
      print('Beginning fold %s/%s...' % (str(k+1), str(no_folds)))

      # Set up model
      if 'mnli' not in mode:
        model = model_class.from_pretrained(model_name, 
                                            cache_dir=os.path.join(DRIVE_PATH, 'cache'))
      else:
        config = config_class.from_pretrained(model_name.replace('-mnli',''),
                                        num_labels=3,
                                        cache_dir=os.path.join(DRIVE_PATH, 'cache'))
        model = model_class.from_pretrained(model_name, 
                                            config=config,
                                            cache_dir=os.path.join(DRIVE_PATH, 'cache'))
        config.num_labels = 2
        model.num_labels = 2
        model.classifier = cls_head_class(config=config) # Need to bring in a classification head for only 2 labels
    
      model.cuda()
      device = model.device 

      # Set up optimizer
      optimizer = AdamW(model.parameters(), lr=lr)
      total_steps = len(train_fold_dataloader[k]) * epochs
      scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

      for epoch in range(epochs):
        # Train the model for one epoch
        print('[%s] Beginning epoch...' % str(epoch))

        epoch_loss, _ = train_epoch(model, optimizer, train_fold_dataloader[k], device, seg_mode='roberta' not in mode)

        # Validate on dev set
        results, _, _ = evaluate(model, dev_fold_dataloader[k], device, [(accuracy_score, 'accuracy')], seg_mode='roberta' not in mode)
        print('[%s] Validation results:' % str(epoch))
        print_dict(results)

        # Accumulate accuracy across folds (Counter returns 0 for unseen keys)
        acc = results['accuracy']
        all_val_accs[(bs, lr, epoch)] += acc
        
        print('[%s] Finished epoch.' % str(epoch))

      # Free GPU memory before starting the next fold
      model.cpu()
      del model
      del optimizer
      del results
      del scheduler
      del total_steps

for k in all_val_accs:
  all_val_accs[k] /= no_folds

print('Top performing param combos:')
print(all_val_accs.most_common(5))

save_fname = os.path.join(DRIVE_PATH, 'saved_models/%s_ConvEnt_xval_%s.pkl' % (model_name.replace('/','-'), '_'.join([str(lr) for lr in learning_rates])))
pickle.dump(all_val_accs, open(save_fname, 'wb'))

Beginning grid search for ConvEnt over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05
Beginning fold 1/8...


Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classif

[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.5,
}


[0] Finished epoch.
Beginning fold 2/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.7,
}


[0] Finished epoch.
Beginning fold 3/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.3,
}


[0] Finished epoch.
Beginning fold 4/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.6,
}


[0] Finished epoch.
Beginning fold 5/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.4,
}


[0] Finished epoch.
Beginning fold 6/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.3,
}


[0] Finished epoch.
Beginning fold 7/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.5,
}


[0] Finished epoch.
Beginning fold 8/8...



[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.4,
}


[0] Finished epoch.
Top performing param combos:
[((1, 1e-05, 0), 0.46249999999999997)]


#### Re-Train Best Model from Cross-Validation

Re-train a model with the best parameters from the search above. If this cell isn't run directly after the one above, replace `save_fname.split('/')[-1]` in `xval_fnames` with the name of the `pkl` file previously generated in the `saved_models` directory.
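
As a small illustration of how the best configuration is read off, the sketch below merges per-run `Counter` objects keyed by `(batch_size, learning_rate, epoch)` and takes the top entry. The accuracy values and the names `run_a` and `run_b` are made up for illustration; only the `Counter` usage mirrors the cell below.

```python
from collections import Counter

# Hypothetical averaged validation accuracies keyed by (batch_size, learning_rate, epoch).
run_a = Counter({(1, 1e-5, 0): 0.46, (1, 1e-5, 1): 0.52})
run_b = Counter({(1, 5e-6, 0): 0.41})

merged = Counter()
for run in (run_a, run_b):
    # Counter addition sums the values of shared keys and drops non-positive entries.
    merged += run

(batch_size, learning_rate, best_epoch), best_acc = merged.most_common(1)[0]

# Epochs are 0-indexed in the search, so re-training runs for best_epoch + 1 epochs.
epochs = best_epoch + 1
```

Because `Counter` addition sums values, listing two result files that share a parameter combination would add their accuracies rather than average them, so each combination should probably appear in only one file.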

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch
from www.model.eval import evaluate, save_results, save_preds
from sklearn.metrics import accuracy_score
from www.utils import print_dict, get_model_dir
from collections import Counter

# Re-train the model with the best parameters from the grid search/cross-validation (with all folds)
xval_fnames = []
xval_fnames.append(save_fname.split('/')[-1])

xval_results = Counter()
for fname in xval_fnames:
  xval_results += pickle.load(open(os.path.join(DRIVE_PATH, 'saved_models/', fname), 'rb'))

batch_size, learning_rate, epochs = xval_results.most_common(1)[0][0]
epochs += 1

# Set up model
if 'mnli' not in mode:
  model = model_class.from_pretrained(model_name, 
                                      cache_dir=os.path.join(DRIVE_PATH, 'cache'))
else:
  config = config_class.from_pretrained(model_name.replace('-mnli',''),
                                  num_labels=3,
                                  cache_dir=os.path.join(DRIVE_PATH, 'cache'))
  model = model_class.from_pretrained(model_name, 
                                      config=config,
                                      cache_dir=os.path.join(DRIVE_PATH, 'cache'))
  config.num_labels = 2
  model.num_labels = 2
  model.classifier = cls_head_class(config=config) # Need to bring in a classification head for only 2 labels

model.cuda()
device = model.device 

train_sampler = RandomSampler(ConvEnt_train_tensor)
train_dataloader = DataLoader(ConvEnt_train_tensor, sampler=train_sampler, batch_size=batch_size)

# Set up optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

for epoch in range(epochs):
  print('[%s] Beginning epoch...' % str(epoch))
  epoch_loss, _ = train_epoch(model, optimizer, train_dataloader, device, seg_mode=True if 'roberta' not in mode else False)

print('[%s] Saving model checkpoint...' % str(epoch))
model_param_str = get_model_dir(model_name.replace('/','-'), 'ConvEnt', batch_size, learning_rate, epoch) + '_xval'
output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
if not os.path.exists(output_dir):
  os.makedirs(output_dir)
model = model.module if hasattr(model, 'module') else model
model.save_pretrained(output_dir)
tokenizer.save_vocabulary(output_dir)


[0] Beginning epoch...
[0] Saving model checkpoint...


('drive/My Drive/Colab Notebooks/Research/TRIP_replication/saved_models/roberta-large_ConvEnt_1_1e-05_0_xval/vocab.json',
 'drive/My Drive/Colab Notebooks/Research/TRIP_replication/saved_models/roberta-large_ConvEnt_1_1e-05_0_xval/merges.txt')

## Test Models on Conversational Entailment

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.eval import evaluate, save_results, save_preds
from sklearn.metrics import accuracy_score
from www.utils import print_dict, get_model_dir

best_model = eval_model_dir


best_model = os.path.join(DRIVE_PATH, 'saved_models', best_model)

# Load the model
model = model_class.from_pretrained(best_model)
model.cuda()
device = model.device

# Select appropriate dataset
if 'cloze' in best_model:
  subtask = 'cloze'
elif 'order' in best_model:
  subtask = 'order'

test_sampler = SequentialSampler(ConvEnt_test_tensor)
test_dataloader = DataLoader(ConvEnt_test_tensor, sampler=test_sampler, batch_size=128)
test_dataset_name = '%s_%s' % ('ConvEnt', 'test')
test_ids = [str(ex['example_id']) for ex in ConvEnt_test]

print('Testing model: %s.' % best_model.split('/')[-1])

results, preds, labels = evaluate(model, test_dataloader, device, [(accuracy_score, 'accuracy')], seg_mode=True if 'roberta' not in mode else False)
save_results(results, best_model, test_dataset_name)
save_preds(test_ids, labels, preds, best_model, test_dataset_name)

print('Results (test):')
print_dict(results)

Testing model: roberta-large_ConvEnt_1_1e-05_0_xval.
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
Results (test):
{
  accuracy: 
    0.7,
}




## Coherence Checks on Conversational Entailment

### Load and Featurize Span Data

In [None]:
from www.dataset.featurize import add_bert_features_ConvEnt, get_tensor_dataset
from www.dataset.prepro import get_ConvEnt_spans
import pickle
seq_length = 128

merged_file = os.path.join(DRIVE_PATH, 'all_data/ConvEnt/ConvEnt_test_annotation_merged2.json')
ConvEnt_test = json.load(open(merged_file))

ConvEnt_test = add_bert_features_ConvEnt(ConvEnt_test, tokenizer, seq_length, add_segment_ids=True)

if debug:
  ConvEnt_test = ConvEnt_test[:10]

# Some of the annotated examples are no longer in the test set :(
# ConvEnt_test = [ex for ex in ConvEnt_test if ex['id'] in test_ids]

# Make span versions of the datasets
ConvEnt_test_spans = get_ConvEnt_spans(ConvEnt_test)

# Convert to tensor datasets
ConvEnt_test_tensor = get_tensor_dataset(ConvEnt_test, label_key='label', add_segment_ids=True)
ConvEnt_test_spans_tensor = get_tensor_dataset(ConvEnt_test_spans, label_key='label', add_segment_ids=True)

### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset.

In [None]:
probe_model = eval_model_dir
probe_model = os.path.join(DRIVE_PATH, 'saved_models', probe_model)

# Load the model
model = model_class.from_pretrained(probe_model)
if torch.cuda.is_available():
  model.cuda()
device = model.device 

#### Load Trained Model's Base Predictions

For comparison, we also need the trained model's predictions and labels on the end task, which were saved during testing above.

In [None]:
from www.model.eval import load_preds
from www.utils import print_dict

preds_base = {}
preds_base['test'] = load_preds(os.path.join(probe_model, 'preds_ConvEnt_test.tsv'))
print(preds_base['test'].keys())

dict_keys(['73', '74', '75', '76', '77', '78', '79', '80', '81', '82'])


### Check a Model

This cell prints the strict and lenient coherence metrics.
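
To make the two metrics concrete, here is a simplified sketch of how they are aggregated, based on a reading of the cell below. The input layout (dicts with `pred`, `label`, `span_preds`, `span_labels`) and the function name `coherence_scores` are assumptions for illustration, and the sketch ignores the early exit that the real loop takes once an inconsistency is found.

```python
import numpy as np

def coherence_scores(examples):
    """Return (lenient, strict) coherence for a hypothetical list of example dicts."""
    lenient, strict = [], []
    for ex in examples:
        # One boolean per span: did the span-level prediction match its label?
        span_hits = [p == l for p, l in zip(ex['span_preds'], ex['span_labels'])]
        end_task_correct = ex['pred'] == ex['label']
        # Lenient: fraction of correct spans, credited only if the end task is correct.
        lenient.append(sum(span_hits) / len(span_hits) if end_task_correct else 0.0)
        # Strict: every span prediction must be correct.
        strict.append(1.0 if all(span_hits) else 0.0)
    return float(np.mean(lenient)), float(np.mean(strict))

# Toy data (made-up values).
toy_examples = [
    {'pred': 1, 'label': 1, 'span_preds': [0, 1, 1], 'span_labels': [0, 1, 1]},
    {'pred': 1, 'label': 1, 'span_preds': [0, 0, 1, 1], 'span_labels': [0, 1, 1, 1]},
]
lenient_coherence, strict_coherence = coherence_scores(toy_examples)
```

In this toy case the second example gets partial lenient credit (3 of 4 spans) but no strict credit, which is the difference between the two metrics.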

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate, save_results, save_preds, list_comparison
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

def is_polarized(smax, thres):
  return (abs(smax[0] - smax[1]) >= thres)

print('Testing model: %s.' % probe_model)

all_results = {}
p = 'test'

p_dataset = ConvEnt_test_spans
p_tensor_dataset = ConvEnt_test_spans_tensor
p_sampler = SequentialSampler(p_tensor_dataset)
p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=512)
p_dataset_name = '%s_spans_%s' % ('ConvEnt', p)
p_dataset_name_co = '%s_consistent_%s' % ('ConvEnt', p)
p_dataset_name_bp = '%s_breakpoints_%s' % ('ConvEnt', p)
p_dataset_name_ev = '%s_evidence_%s' % ('ConvEnt', p)
p_dataset_name_coh = '%s_coherent_%s' % ('ConvEnt', p)
p_ids = [str(ex['example_id']) for ex in ConvEnt_test_spans]
p_labels = [ex['label'] for ex in ConvEnt_test_spans]

# Get span preds and save metrics
results, preds, labels = evaluate(model, p_dataloader, device, metrics, seg_mode=True if 'roberta' not in mode else False)
save_results(results, probe_model, p_dataset_name)
save_preds(p_ids, labels, preds, probe_model, p_dataset_name)

# Convert substory preds into breakpoint preds for each example
ids_base = [str(ex['example_id']) for ex in ConvEnt_test]

id_to_pred = {k: v for k,v in zip(p_ids, preds)}
id_to_label = {k: v for k,v in zip(p_ids, p_labels)}

preds_entailment = []
labels_entailment = []
preds_consistent = []
preds_breakpoint = []
labels_breakpoint = []
preds_evidence = []
labels_evidence = []    
span_accuracies = []
span_accuracies_strict = []
preds_coherent = []

for i, exid in enumerate(ids_base):
  ex = ConvEnt_test[i]
  ex['length'] = len(ex['turns'])

  label_entailment = preds_base[p][exid]['label']
  pred_entailment = preds_base[p][exid]['pred']
  labels_entailment.append(label_entailment)
  preds_entailment.append(pred_entailment)

  # Get ground truth breakpoint and evidence
  label_breakpoint = ex['conflict_pair'][1] if ex['conflict_pair'] is not None and len(ex['conflict_pair']) > 0 else 0
  labels_breakpoint.append(label_breakpoint)
  if label_breakpoint > 0:
    label_ev = ex['conflict_pair'][0]
  else:
    label_ev = -1
  labels_evidence.append(label_ev)

  # Check consistency - if a span entails the hypothesis, all of its superspans should also entail it
  pred_consistent = True
  span_accuracy = 0.0
  span_accuracy_strict = 0.0
  pred_coherent = True
  
  no_spans = 0
  for sp1 in range(ex['length']):
    if not pred_consistent:
      break

    for sp2 in range(sp1, ex['length']):
      if not pred_consistent:
        break

      span_pred = id_to_pred[exid + '-sp%s:%s' % (str(sp1), str(sp2))]
      span_label = id_to_label[exid + '-sp%s:%s' % (str(sp1), str(sp2))]

      if span_pred == span_label:
        span_accuracy += 1.0
        if label_entailment == pred_entailment:
            span_accuracy_strict += 1.0
      else:
        pred_coherent = False
      no_spans += 1
      # print('%s:%s\t%s\t(%s, %s)' % (str(sp1), str(sp2), str(span_pred), str(span_prob[0]), str(span_prob[1])))      

      if span_pred == 1:
        if pred_entailment == 1:
          for sp3 in range(sp1+1):
            if not pred_consistent:
              break

            for sp4 in range(sp2, ex['length']):
              if not pred_consistent:
                break

              sspan_pred = id_to_pred[exid + '-sp%s:%s' % (str(sp3), str(sp4))]

              if sspan_pred == 0:
                pred_consistent = False
                break
        elif pred_entailment == 0:
          pred_consistent = False

  preds_consistent.append(1 if pred_consistent else 0)
  span_accuracies.append(span_accuracy / no_spans)
  span_accuracies_strict.append(span_accuracy_strict / no_spans)
  preds_coherent.append(1 if pred_coherent else 0)

  # Check pred. breakpoint (verifiability) - the first turn ss (ss >= 1) at which the span 0:ss is predicted as 1
  pred_breakpoint = 0 # For now, 0 means -1, i.e., stories are entirely plausible - this shouldn't happen but it will (inconsistent?)
  for ss in range(1, ex['length']):
    if id_to_pred[exid + '-sp%s:%s' % (str(0), str(ss))] == 1:
      pred_breakpoint = ss
      break
  preds_breakpoint.append(pred_breakpoint)

  # Check pred. evidence (verifiability)
  if pred_breakpoint > 0:
    pred_evidence = -1 
    for ss in range(0, pred_breakpoint+1):
      if id_to_pred[exid + '-sp%s:%s' % (str(0), str(ss))] == 1:
        pred_evidence = ss
  else:
    pred_evidence = -1 # This should never happen - it would be inconsistent if it did
  preds_evidence.append(pred_evidence)

# Calculate tiered accuracy for model
acc = 0
acc_con = 0
acc_con_vbp = 0
acc_con_vbp_vev = 0
no_ex = len(ids_base)
for p_plaus, l_plaus, con, p_bp, l_bp, p_ev, l_ev in zip(preds_entailment, labels_entailment, preds_consistent, preds_breakpoint, labels_breakpoint, preds_evidence, labels_evidence):
  # Accuracy
  if p_plaus == l_plaus:
    acc += 1
    
    # Consistency
    if con == 1:
      acc_con += 1
    
      # Verifiability (breakpoint)
      if p_bp == l_bp:
        acc_con_vbp += 1

        # Verifiability (evidence)
        if p_ev == l_ev:
          acc_con_vbp_vev += 1

acc /= no_ex
acc_con /= no_ex
acc_con_vbp /= no_ex
acc_con_vbp_vev /= no_ex

# all_results['acc'] = acc
# all_results['acc_con'] = acc_con
# all_results['acc_con_vbp'] = acc_con_vbp
# all_results['acc_con_vbp_vev'] = acc_con_vbp_vev
# all_results['span_accuracy'] = np.mean(span_accuracies)

all_results['lenient_coherence'] = np.mean(span_accuracies_strict)
all_results['strict_coherence'] = np.mean(preds_coherent)

best_preds_entailment = preds_entailment
best_preds_consistent = preds_consistent
best_preds_breakpoint = preds_breakpoint
best_preds_evidence = preds_evidence
best_preds_coherent = preds_coherent
    
print('\nPARTITION: %s' % p)
print_dict(all_results)

# Save preds for breakpoint and evidence
save_preds(ids_base, np.array(labels_breakpoint), best_preds_breakpoint, probe_model, p_dataset_name_bp)
save_preds(ids_base, np.array(labels_evidence), best_preds_evidence, probe_model, p_dataset_name_ev)
save_preds(ids_base, np.array([1 for _ in best_preds_coherent]), best_preds_coherent, probe_model, p_dataset_name_coh)

p_dataset_name_agg = '%s_tiers_agg_nostates_lenient_%s' % ('ConvEnt', p)
save_results(all_results, probe_model, p_dataset_name_agg)

Testing model: drive/My Drive/Colab Notebooks/Research/TRIP_replication/saved_models/roberta-large_ConvEnt_1_1e-05_0_xval.
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:02s.

PARTITION: test
{
  lenient_coherence: 
    0.4269000933706816,
  strict_coherence: 
    0.4,
}






# ART Results

Code for the coherence experiments on ART.

In [None]:
if task_name != 'art':
  raise ValueError('Please configure task_name in first cell to "art" to run ART results!')

## Load ART dataset

ART was originally obtained from [HuggingFace datasets](https://huggingface.co/docs/datasets/), and we added some of our own annotations for the coherence evaluation.

In [None]:
import json
import os
fname = os.path.join(DRIVE_PATH, 'all_data/ART/art.json')
with open(fname, 'r') as f:
  art = json.load(f)

## Train Models on ART

### Featurize ART

In [None]:
from www.dataset.featurize import add_bert_features_art, get_tensor_dataset
seq_length = 32

for p in art:
  for i in range(len(art[p])):
    art[p][i]['label'] -= 1 # Do this so labels start at 0

  if debug:
    # Take 20 examples that we've annotated as the debug set so we can run the coherence metrics
    merged_file = os.path.join(DRIVE_PATH, 'all_data/ART/ART_test_rand200_annotation_merged2.json')
    ann_ids = [ex['id'] for ex in json.load(open(merged_file))]
     
    if p == 'train':
      art[p] = art[p][:20]
    elif p == 'val':
      art[p] = [ex for ex in art[p] if ex['id'] in ann_ids][:20]

art_tensor = {}
for p in art:
  art[p] = add_bert_features_art(art[p], tokenizer, seq_length)
  art_tensor[p] = get_tensor_dataset(art[p], label_key='label')

### Train Models

Train models on ART. Note that ART's test set is not public, so we cannot test the model (unless we submit to their [leaderboard](https://leaderboard.allenai.org/anli/submissions/public)).

#### Configure Hyperparameters

In [None]:
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
epochs = config_epochs
eval_batch_size = 128

#### Grid Search

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup, RobertaForMultipleChoice, BertForMultipleChoice, RobertaConfig, BertConfig
from www.model.train import train_epoch
from www.model.eval import evaluate, save_results, save_preds
from sklearn.metrics import accuracy_score
from www.utils import print_dict, get_model_dir

seed_val = 22 # Fix the random seed for reproducibility
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# We'll keep the validation data here with a constant eval batch size
dev_sampler = SequentialSampler(art_tensor['val'])
dev_dataloader = DataLoader(art_tensor['val'], sampler=dev_sampler, batch_size=eval_batch_size)
dev_dataset_name = 'art_val'
dev_ids = [str(ex['example_id']) for ex in art['val']]

all_losses = []
param_combos = []
combo_names = []
all_val_accs = []
output_dirs = []
best_acc = 0.0

print('Beginning grid search for ART over %s parameter combination(s)!' % (str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    loss_values = []
    acc_values = []

    # Set up training dataset with new batch size
    train_sampler = RandomSampler(art_tensor['train'])
    train_dataloader = DataLoader(art_tensor['train'], sampler=train_sampler, batch_size=bs)

    # Set up model
    config = config_class.from_pretrained(model_name,
                                          num_labels=2,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))
    model = model_class.from_pretrained(model_name,
                                        config=config,
                                        cache_dir=os.path.join(DRIVE_PATH, 'cache'))

    model.cuda()
    device = model.device 

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

    for epoch in range(epochs):
      # Train the model for one epoch
      print('[%s] Beginning epoch...' % str(epoch))

      epoch_loss, _ = train_epoch(model, optimizer, train_dataloader, device)
      
      # Save loss
      loss_values.append(epoch_loss)

      # Validate on dev set
      results, preds, labels = evaluate(model, dev_dataloader, device, [(accuracy_score, 'accuracy')])
      print('[%s] Validation results:' % str(epoch))
      print_dict(results)

      # Save accuracy
      acc = results['accuracy']
      acc_values.append(acc)
      
      # Save model checkpoint
      print('[%s] Saving model checkpoint...' % str(epoch))
      model_param_str = get_model_dir(model_name.replace('/','-'), 'art', bs, lr, epoch)# + '_toy'
      output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
      output_dirs.append(output_dir)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)
      save_results(results, output_dir, dev_dataset_name)
      save_preds(dev_ids, labels, preds, output_dir, dev_dataset_name)
      model = model.module if hasattr(model, 'module') else model
      model.save_pretrained(output_dir)
      tokenizer.save_vocabulary(output_dir)

      if acc > best_acc:
        best_acc = acc
        best_model = model_param_str
        best_dir = output_dir

      print('[%s] Finished epoch.' % str(epoch))

    all_losses.append(loss_values)
    all_val_accs.append(acc_values)
    param_combos.append((bs, lr))
    combo_names.append('bs=%s, lr=%s' % (str(bs), str(lr)))

print('Finished grid search! :)')
print('Best validation accuracy %s from model %s.' % (best_acc, best_model))

Beginning grid search for ART over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05


Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForMultipleChoice: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for pre

[0] Beginning epoch...
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:00s.
[0] Validation results:
{
  accuracy: 
    0.7,
}


[0] Saving model checkpoint...
[0] Finished epoch.
Finished grid search! :)
Best validation accuracy 0.7 from model roberta-large_art_1_1e-05_0.


Delete non-best model checkpoints:

In [None]:
import shutil

# Delete non-best model checkpoints
for od in output_dirs:
  if od != best_dir and os.path.exists(od):
    shutil.rmtree(od)

## Coherence Checks on ART

### Load and Featurize Span Data

In [None]:
from www.dataset.featurize import add_bert_features_art, get_tensor_dataset
from www.dataset.prepro import get_art_spans
import pickle
seq_length = 128
  
merged_file = os.path.join(DRIVE_PATH, 'all_data/ART/ART_test_rand200_annotation_merged2.json')
art_anns = json.load(open(merged_file))

if debug:
  ann_ids = [ex['id'] for ex in art_anns]
  debug_ids = [ex['id'] for ex in art['val'] if ex['id'] in ann_ids][:20]
  art = [ex for ex in art_anns if ex['id'] in debug_ids]

# Make span versions of the datasets
art_spans = get_art_spans(art)

# Add BERT features
art = add_bert_features_art(art, tokenizer, seq_length, add_segment_ids=True)
art_spans = add_bert_features_art(art_spans, tokenizer, seq_length, add_segment_ids=True)

# Convert to tensor datasets
art_tensor = get_tensor_dataset(art, label_key='label', add_segment_ids=True)
art_spans_tensor = get_tensor_dataset(art_spans, label_key='label', add_segment_ids=True)

### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset.

In [None]:
probe_model = eval_model_dir
probe_model = os.path.join(DRIVE_PATH, 'saved_models', probe_model)
  
# Load the model
model = model_class.from_pretrained(probe_model)
if torch.cuda.is_available():
  model.cuda()
device = model.device 

#### Load Trained Model's Two-Story Classification Predictions

For comparison, we also need the trained model's predictions and labels on the end task, which were saved during training above.

In [None]:
from www.model.eval import load_preds
from www.utils import print_dict

preds_base = {}
preds_base['val'] = load_preds(os.path.join(probe_model, 'preds_art_val.tsv'))

### Calculate Coherence Metrics

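As ART is a multiple-choice task, we need to tune the confidence threshold $\rho$: a span prediction counts as confident only if its two softmax probabilities differ by at least $\rho$. The code below searches $\rho$ from 0.05 to 1.0 in steps of 0.05 and prints the strict and lenient coherence metrics, along with the chosen $\rho$ (`best_threshold`).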
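
As a rough illustration of how $\rho$ acts, the sketch below applies the same polarization test as `is_polarized` in the next cell to a few made-up softmax outputs. The probabilities and the helper name `confident_pred` are assumptions for illustration only; `-1` stands for "no confident choice", as in the evaluation code.

```python
def confident_pred(probs, rho):
    """Return the argmax of a 2-way softmax if |p0 - p1| >= rho, else -1."""
    p0, p1 = probs
    if abs(p0 - p1) >= rho:
        return 0 if p0 >= p1 else 1
    return -1

# Made-up softmax outputs for three spans
toy_probs = [[0.55, 0.45], [0.10, 0.90], [0.70, 0.30]]

for rho in (0.05, 0.5, 1.0):
    print(rho, [confident_pred(p, rho) for p in toy_probs])
# 0.05 [0, 1, 0]
# 0.5 [-1, 1, -1]
# 1.0 [-1, -1, -1]
```

Raising $\rho$ turns more span predictions into `-1`. At $\rho = 1.0$, only a fully saturated softmax would still count as confident.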

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate, save_results, save_preds, list_comparison
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

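def is_polarized(smax, thres):
  # A 2-way softmax output is "polarized" (confident) if the gap between its two probabilities is at least thres
  return abs(smax[0] - smax[1]) >= thres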

print('Testing model: %s.' % probe_model)

subtask = 'art'
p = 'val'
all_results = {}

p_dataset = art_spans
p_tensor_dataset = art_spans_tensor
p_sampler = SequentialSampler(p_tensor_dataset)
p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=128)
p_dataset_name = '%s_spans_%s' % (subtask, p)
p_dataset_name_co = '%s_consistent_%s' % (subtask, p)
p_dataset_name_bp = '%s_breakpoints_%s' % (subtask, p)
p_dataset_name_ev = '%s_evidence_%s' % (subtask, p)
p_dataset_name_coh = '%s_coherence_%s' % (subtask, p)
p_dataset_name_subset = '%s_rand200_%s' % (subtask, p)
p_ids = [ex['example_id'] for ex in art_spans]
p_labels = [ex['label'] for ex in art_spans]

# Get span preds and save metrics
results, preds, labels, probs = evaluate(model, p_dataloader, device, metrics, seg_mode=('roberta' not in mode), return_softmax=True)
save_results(results, probe_model, p_dataset_name)
save_preds(p_ids, labels, preds, probe_model, p_dataset_name)

# Convert substory preds into breakpoint preds for each example
ids_base = [ex['example_id'] for ex in art]

id_to_pred = {k: v for k,v in zip(p_ids, preds)}
id_to_prob = {k: v for k,v in zip(p_ids, probs)}
id_to_label = {k: v for k,v in zip(p_ids, p_labels)}

for metric_to_optimize in ['strict_coherence', 'lenient_coherence']:
  # Get results dict ready
  # all_results['acc'] = 0.0
  # all_results['acc_con'] = 0.0
  # all_results['acc_con_vbp'] = 0.0
  # all_results['acc_con_vbp_vev'] = 0.0
  # all_results['span_accuracy'] = 0.0
  all_results['lenient_coherence'] = 0.0
  all_results['strict_coherence'] = 0.0
  span_accuracy = 0.0
  span_accuracy_strict = 0.0
  no_spans = 0
  best_thres = None # Stays None if no threshold improves on the initial 0.0
  for threshold in [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:

    preds_plausible = []
    labels_plausible = []
    preds_consistent = []
    preds_breakpoint = []
    labels_breakpoint = []
    preds_evidence = []
    labels_evidence = []    
    span_accuracies = []
    span_accuracies_strict = []
    preds_coherent = []

    for i, exid in enumerate(ids_base):
      ex = art[i]
      ex['length'] = 3

      label_plausible = preds_base[p][exid]['label']
      pred_plausible = preds_base[p][exid]['pred']
      labels_plausible.append(label_plausible)
      preds_plausible.append(pred_plausible)

      # Get ground truth breakpoint and evidence
      label_breakpoint = ex['conflict_pair'][1] if ex['conflict_pair'] is not None else 0
      labels_breakpoint.append(label_breakpoint)
      if label_breakpoint > 0:
        label_ev = ex['conflict_pair'][0]
      else:
        label_ev = -1
      labels_evidence.append(label_ev)

      # Check consistency - for every span we confidently choose story X, we should also confidently choose story X for any span containing it
      pred_consistent = True
      for sp1 in range(ex['length']-1):
        if not pred_consistent:
          break

        for sp2 in range(sp1+1, ex['length']):
          if not pred_consistent:
            break

          span_pred = int(id_to_pred[exid + '-sp%s:%s' % (str(sp1), str(sp2))])
          span_prob = id_to_prob[exid + '-sp%s:%s' % (str(sp1), str(sp2))]
          span_label = max(id_to_label[exid + '-sp%s:%s' % (str(sp1), str(sp2))] - 1, -1)

          span_pred3 = span_pred
          if not is_polarized(span_prob, threshold): # If not polarized, let's say -1
            span_pred3 = -1

          pred_coherent = True
          if span_pred3 == span_label:
            span_accuracy += 1.0
            if label_plausible == pred_plausible:
              span_accuracy_strict += 1.0
          else:
            pred_coherent = False
          no_spans += 1

          if is_polarized(span_prob, threshold):
            for sp3 in range(sp1+1):
              if not pred_consistent:
                break

              for sp4 in range(sp2, ex['length']):
                if not pred_consistent:
                  break

                sspan_pred = id_to_pred[exid + '-sp%s:%s' % (str(sp3), str(sp4))]
                sspan_prob = id_to_prob[exid + '-sp%s:%s' % (str(sp3), str(sp4))]

                if not is_polarized(sspan_prob, threshold) or sspan_pred != span_pred:
                  pred_consistent = False
                  break

      preds_consistent.append(1 if pred_consistent else 0)
      span_accuracies.append(span_accuracy / no_spans)
      span_accuracies_strict.append(span_accuracy_strict / no_spans)
      preds_coherent.append(1 if pred_coherent else 0)

      # Check pred. breakpoint (verifiability) - will be first sentence where the model prediction becomes polarized, i.e., confidence > threshold
      pred_breakpoint  = 0 # For now, 0 means -1, i.e., stories are entirely plausible - this shouldn't happen but it will (inconsistent?)
      for ss in range(1, ex['length']):
        if is_polarized(id_to_prob[exid + '-sp%s:%s' % (str(0), str(ss))], threshold):
          pred_breakpoint = ss
          break
      preds_breakpoint.append(pred_breakpoint)

      # Check pred. evidence (verifiability)
      if pred_breakpoint > 0:
        pred_evidence = -1 # Does this make sense for default value?
        for ss in range(0, pred_breakpoint):
          if is_polarized(id_to_prob[exid + '-sp%s:%s' % (str(ss), str(pred_breakpoint))], threshold):
            pred_evidence = ss
      else:
        pred_evidence = -1 # This should never happen - it would be inconsistent if it did?
      preds_evidence.append(pred_evidence)

    # Calculate tiered accuracy for model
    acc = 0
    acc_con = 0
    acc_con_vbp = 0
    acc_con_vbp_vev = 0
    no_ex = len(ids_base)
    for p_plaus, l_plaus, con, p_bp, l_bp, p_ev, l_ev in zip(preds_plausible, labels_plausible, preds_consistent, preds_breakpoint, labels_breakpoint, preds_evidence, labels_evidence):
      # Accuracy
      if p_plaus == l_plaus:
        acc += 1
        
        # Consistency
        if con == 1:
          acc_con += 1
        
          # Verifiability (breakpoint)
          if p_bp == l_bp:
            acc_con_vbp += 1

            # Verifiability (evidence)
            if p_ev == l_ev:
              acc_con_vbp_vev += 1

    acc /= no_ex
    acc_con /= no_ex
    acc_con_vbp /= no_ex
    acc_con_vbp_vev /= no_ex
    span_acc = np.mean(span_accuracies)
    span_acc_strict = np.mean(span_accuracies_strict)
    coherence = np.mean(preds_coherent)
    # if coherence > all_results['coherence']: # !!!! this line is important
    # if span_acc > all_results['span_accuracy']: # !!!! this line is important
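    # Update the best threshold when span_acc_strict exceeds the stored value of metric_to_optimize.
    # Note: span_acc_strict is the quantity compared in both passes, while all_results['strict_coherence'] stores `coherence`.
    if span_acc_strict > all_results[metric_to_optimize]: # !!!! this line is important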
      # print('new best: %s' % str(acc_con_vbp_vev))
      best_thres = threshold
      
      # all_results['acc'] = acc
      # all_results['acc_con'] = acc_con
      # all_results['acc_con_vbp'] = acc_con_vbp
      # all_results['acc_con_vbp_vev'] = acc_con_vbp_vev
      # all_results['span_accuracy'] = span_acc
      all_results['lenient_coherence'] = span_acc_strict
      all_results['strict_coherence'] = coherence

      best_preds_plausible = preds_plausible
      best_preds_consistent = preds_consistent
      best_preds_breakpoint = preds_breakpoint
      best_preds_evidence = preds_evidence
      best_preds_coherent = preds_coherent
      
  all_results['best_threshold'] = best_thres
  print('\nPARTITION: %s \t METRIC: %s' % (p, metric_to_optimize))
  print('chosen threshold: %s' % str(best_thres))
  print_dict(all_results)

  # Save results
  p_dataset_name_agg = '%s_%s_%s' % (subtask, metric_to_optimize, p)
  save_results(all_results, probe_model, p_dataset_name_agg)

Testing model: drive/My Drive/Colab Notebooks/Research/TRIP_replication/saved_models/roberta-large_art_1_1e-05_0.
	Beginning evaluation...
		Running prediction...
		Computing metrics...
	Finished evaluation in 0:00:03s.

PARTITION: val 	 METRIC: strict_coherence
chosen threshold: 1.0
{
  lenient_coherence: 
    0.18246427120454473,
  strict_coherence: 
    0.15,
  best_threshold: 
    1.0,
}



PARTITION: val 	 METRIC: lenient_coherence
chosen threshold: 1.0
{
  lenient_coherence: 
    0.18246427120454473,
  strict_coherence: 
    0.15,
  best_threshold: 
    1.0,
}
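
The consistency and breakpoint logic in the cell above can be hard to follow inside the nested loops. The sketch below restates both checks as standalone functions over a toy dictionary of span outputs. The function names `is_consistent` and `predicted_breakpoint`, the `(start, end) -> (pred, probs)` layout, and the toy numbers are assumptions for illustration; they are not part of the `www` package.

```python
def is_polarized(probs, rho):
    """Same test as in the evaluation cell: probability gap of at least rho."""
    return abs(probs[0] - probs[1]) >= rho

def is_consistent(span_out, length, rho):
    """span_out maps (start, end) -> (pred, probs).

    Whenever a span is predicted confidently, every span containing it
    must be confident too and must make the same choice.
    """
    for s in range(length - 1):
        for e in range(s + 1, length):
            pred, probs = span_out[(s, e)]
            if not is_polarized(probs, rho):
                continue
            for s2 in range(s + 1):          # containing spans start at or before s
                for e2 in range(e, length):  # and end at or after e
                    pred2, probs2 = span_out[(s2, e2)]
                    if not is_polarized(probs2, rho) or pred2 != pred:
                        return False
    return True

def predicted_breakpoint(span_out, length, rho):
    """First end index whose prefix span (0, end) is confident; 0 if none is."""
    for e in range(1, length):
        if is_polarized(span_out[(0, e)][1], rho):
            return e
    return 0

toy = {
    (0, 1): (0, [0.52, 0.48]),  # not confident at rho = 0.5
    (1, 2): (1, [0.05, 0.95]),  # confident, chooses 1
    (0, 2): (1, [0.10, 0.90]),  # confident, agrees with the span it contains
}
print(is_consistent(toy, 3, 0.5))         # True
print(predicted_breakpoint(toy, 3, 0.5))  # 2
```

If the full span `(0, 2)` instead chose `0` confidently, it would contradict the contained span `(1, 2)` and the example would count as inconsistent.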




