In [7]:
import warnings
warnings.filterwarnings('ignore')

In [8]:
import os
import json
import sys
import torch
import random
import numpy as np
import spacy
#!pip install jsonlines

In [9]:
seed_val = 595  # one seed shared by every RNG below, for reproducibility
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed_val)

In [13]:
from www.utils import print_dict

partitions = ['train', 'dev', 'test']
subtasks = ['cloze', 'order']

# All partitions live in one JSON file for now (partition -> example id -> example);
# it can be split into per-partition files later.
data_file = './all_data/www.json'
with open(data_file, 'r') as f:
  dataset = json.load(f)

# Show a handful of preprocessed dev examples.
print('Preprocessed examples:')
dev_ids = list(dataset['dev'])
for ex_idx in (0, 1, 5, 10):
  ex = dataset['dev'][dev_ids[ex_idx]]
  print_dict(ex)

Preprocessed examples:
{
  story_id: 
    13,
  worker_id: 
    A32W24TWSWXW,
  type: 
    None,
  idx: 
    None,
  aug: 
    False,
  actor: 
    John,
  location: 
    kitchen,
  objects: 
    cabinet, counter, knife, pan, potato, pizza,
  sentences: 
    [
      John was getting the snacks ready for the party.
      John opened the cabinet, took out a pan and put it on the counter.
      John opened the fridge and got out the pizza.
      John put the pizza on the pan and put them into the oven.
      John took a knife and cut the hot pizza in eight slices.
    ],
  length: 
    5,
  example_id: 
    13,
  plausible: 
    True,
  breakpoint: 
    -1,
  confl_sents: 
    [],
  confl_pairs: 
    [],
  states: 
    [
      {'h_location': [['John', 0]], 'conscious': [['John', 2]], 'wearing': [['John', 0]], 'h_wet': [['John', 0]], 'hygiene': [['John', 0]], 'location': [['snacks', 0], ['party', 0]], 'exist': [['snacks', 4], ['party', 2]], 'clean': [['snacks', 0], ['party', 0]], 'power': 

In [14]:
# Split the flat WWW data into per-subtask lists. Every typed (implausible) example is
# followed by the plausible story it was derived from, looked up by its story_id.
cloze_dataset = {p: [] for p in dataset}
order_dataset = {p: [] for p in dataset}

for p, examples in dataset.items():
  for ex in examples.values():
    # Untyped examples (type None — the plausible originals in the output above) are
    # only added as partners of an implausible example.
    if ex['type'] is None:
      continue

    ex_plaus = examples[str(ex['story_id'])]

    if ex['type'] == 'cloze':
      # For every implausible story, add a copy of its corresponding plausible story
      cloze_dataset[p].extend([ex, ex_plaus])

    # Exclude augmented ordering examples from dev and test, since the breakpoints aren't always accurate in those
    elif ex['type'] == 'order' and (p == 'train' or not ex['aug']):
      order_dataset[p].extend([ex, ex_plaus])

In [15]:
from www.utils import print_dict
import json
from collections import Counter

# Pre-built two-story examples; the file holds a [cloze, order] pair of partition dicts.
data_file = './all_data/www_2s_new.json'
with open(data_file, 'r') as f:
  cloze_dataset_2s, order_dataset_2s = json.load(f)

# Check label balance of each cloze partition.
for p, examples in cloze_dataset_2s.items():
  label_dist = Counter(ex['label'] for ex in examples)
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())
print_dict(cloze_dataset_2s['train'][0])

Cloze label distribution (train):
[(1, 400), (0, 399)]
Cloze label distribution (dev):
[(0, 161), (1, 161)]
Cloze label distribution (test):
[(1, 176), (0, 175)]
{
  example_id: 
    0-C0,
  stories: 
    [
      {'story_id': 0, 'worker_id': 'A1F01FVEPYCPHO', 'type': 'cloze', 'idx': 0, 'aug': False, 'actor': 'Tom', 'location': 'kitchen', 'objects': 'dustbin, microwave, pan, plate, cereal, soup', 'sentences': ['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.'], 'length': 5, 'example_id': '0-C0', 'plausible': False, 'breakpoint': 4, 'confl_sents': [3], 'confl_pairs': [[3, 4]], 'states': [{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin', 

In [None]:
from www.dataset.prepro import get_tiered_data, balance_labels
from collections import Counter
# Tiered pipeline input: the two-story cloze data.
# Set DEBUG_SUBSET_SIZE to a small int (e.g. 5) to debug on a few examples per partition.
DEBUG_SUBSET_SIZE = None

if DEBUG_SUBSET_SIZE:
    # Build a new dict: assigning into an alias of cloze_dataset_2s would truncate it too.
    tiered_dataset = {k: v[:DEBUG_SUBSET_SIZE] for k, v in cloze_dataset_2s.items()}
else:
    tiered_dataset = cloze_dataset_2s

maxStoryLength=168  # NOTE(review): not read in this cell
tiered_dataset = get_tiered_data(tiered_dataset)

### Object Abstraction

In [None]:
# Reload a fresh copy of the two-story data for the object-abstraction variant.
# `data_file` is the www_2s_new.json path set in the earlier loading cell.
with open(data_file, 'r') as fh:
    new_cloze_dataset_2s, order_dataset_2s = json.load(fh)

maxStoryLength = 168
abstract_dataset = new_cloze_dataset_2s

debug_small = False  # True -> keep only the first 5 examples per partition
if debug_small:
    for part in abstract_dataset:
        abstract_dataset[part] = abstract_dataset[part][:5]

abstract_dataset = get_tiered_data(abstract_dataset)

In [None]:
import pickle

# SECURITY NOTE: pickle.load can execute arbitrary code — only load a trusted file here.
# The abstraction cell reads abstract_dict[entity]['parent'] from this mapping.
abstract_tree_path = "./abstract_tree.pkl"
with open(abstract_tree_path, "rb") as tree_file:
    abstract_dict = pickle.load(tree_file)

In [None]:
from tqdm import tqdm

In [None]:
import re


def _abstract_parent(entity):
    """Return the short parent-class name of `entity` from abstract_dict, or None.

    None means "leave unchanged": the entity is not in the tree, or its parent is
    the root 'entity.n.01', which is too generic to be useful.
    """
    node = abstract_dict.get(entity)
    if node is None or node['parent'] == 'entity.n.01':
        return None
    # Parent names appear to be synset ids like 'x.n.01'; keep only the part before the first dot.
    return node['parent'].split(".")[0]


def _abstract_sentence(sentence, entity_list):
    """Replace whole-word mentions of each entity in `sentence` with its parent class."""
    # Longest entities first, so a multi-word entity is replaced before a shorter one inside it.
    for entity in sorted(entity_list, key=len, reverse=True):
        parent_class = _abstract_parent(entity)
        if parent_class is None:
            continue
        # Whole-word match ('pan' must not hit 'pancake' or 'Japan'); a plural 's'/'es'
        # suffix is allowed and kept after the replacement.
        pattern = r'(?<!\w)' + re.escape(entity) + r'(?=(?:e?s)?(?!\w))'
        sentence = re.sub(pattern, lambda _m: parent_class, sentence)
    return sentence


for key in abstract_dataset.keys():
    for data in tqdm(abstract_dataset[key]):
        for sample in data['stories']:
            entity_list = [entity['entity'] for entity in sample['entities']]
            abstract_sentences = [_abstract_sentence(sent, entity_list) for sent in sample['sentences']]
            sample['sentences'] = abstract_sentences
            for ent in sample['entities']:
                ent['sentences'] = abstract_sentences
                parent_class = _abstract_parent(ent['entity'])
                if parent_class is not None:
                    ent['entity'] = parent_class

### Synonym Replacement

In [None]:
# Fresh copy of the two-story data for the synonym-replacement variant
# (`data_file` still points at www_2s_new.json from the earlier loading cell).
with open(data_file, 'r') as src:
    new_cloze_dataset_2s, order_dataset_2s = json.load(src)

replace_dataset = new_cloze_dataset_2s
maxStoryLength = 168

use_debug_subset = False  # True -> keep only the first 5 examples per partition
if use_debug_subset:
    for split_name in replace_dataset:
        replace_dataset[split_name] = replace_dataset[split_name][:5]

replace_dataset = get_tiered_data(replace_dataset)

In [None]:
# Compatibility shim, presumably for the nlpaug imports below: NumPy 1.24 removed the
# deprecated aliases np.int / np.float. Those aliases were the Python builtins int / float
# (not np.int64 / np.float64), so restore exactly that, and only when they are missing.
if not hasattr(np, 'int'):
    np.int = int
if not hasattr(np, 'float'):
    np.float = float
# NOTE(review): only `naw` is used in the visible cells; nac / nas / nafc appear unused.
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import nlpaug.flow as nafc
import spacy
# English spaCy pipeline; synonym_replacement uses it for tokens, POS tags and lemmas.
nlp = spacy.load("en_core_web_sm")

In [None]:
def synonym_replacement(sentence, entity_list, aug):
    """Replace verbs, adjectives and adverbs in `sentence` with synonyms from `aug`.

    Tokens matching any word of the protected `entity_list` are left untouched so the
    tracked entities stay recognisable. Past-tense / past-participle verbs (VBD/VBN)
    are lemmatised before augmentation.

    Args:
        sentence: raw sentence string.
        entity_list: strings that must not be replaced; multi-word or comma-separated
            entries are split into their individual words.
        aug: nlpaug augmenter; ``augment`` is assumed to return a list of strings
            (older nlpaug versions return a plain string, which is also accepted).

    Returns:
        The augmented sentence, rebuilt with spaCy's original inter-token whitespace
        so punctuation and contractions are not detached.
    """
    protected = set(entity_list)
    for term in entity_list:
        protected.update(term.replace(',', ' ').split())

    doc = nlp(sentence)
    pieces = []
    for token in doc:
        word = token.text
        if token.pos_ in ('VERB', 'ADJ', 'ADV') and token.text not in protected:
            source = token.lemma_ if token.tag_ in ('VBD', 'VBN') else token.text
            result = aug.augment(source)
            if isinstance(result, str):
                word = result
            elif result:
                word = result[0]
        pieces.append(word + token.whitespace_)
    return ''.join(pieces)


def _protected_terms(story):
    """Collect the strings synonym replacement must not touch in one story."""
    objects = story['objects']
    if isinstance(objects, str):
        # 'objects' is stored as one comma-separated string, e.g. 'pan, knife'.
        objects = objects.split(',')
    objects = [obj.strip() for obj in objects if obj.strip()]
    return ([entity['entity'] for entity in story['entities']]
            + [story['actor'], story['location']] + objects)


def _augment_story(story, aug):
    """Synonym-augment every sentence of `story` in place, keeping entity views in sync."""
    protected = _protected_terms(story)
    new_sentences = [synonym_replacement(sentence, protected, aug) for sentence in story['sentences']]
    story['sentences'] = new_sentences
    for entity in story['entities']:
        entity['sentences'] = new_sentences


def augment_story_data(dataset, aug):
    """Synonym-augment every story of every example in `dataset`, in place.

    Each story is processed independently, so stories of different lengths are
    augmented in full (no truncation to the shorter one).

    Returns:
        The same `dataset` object, for convenience.
    """
    for story_pair in tqdm(dataset):
        for story in story_pair['stories']:
            _augment_story(story, aug)
    return dataset

In [None]:
# WordNet synonym augmenter; nlpaug's per-word augmentation probability is `aug_p`
# (SynonymAug has no `p` argument).
# NOTE(review): synonym_replacement passes one token at a time, and nlpaug's default
# aug_min=1 presumably forces at least one replacement per call, so aug_p likely has
# little effect at that granularity — confirm.
aug = naw.SynonymAug(aug_p=0.4)

In [None]:
# Only the training split is synonym-augmented; augment_story_data edits it in place and returns it.
syno_dataset = {'train': augment_story_data(replace_dataset['train'], aug)}

### GITA Dataset

In [None]:
# JSON is UTF-8 by spec, and the translation model below is Italian->English, so these
# stories presumably contain accented characters: don't rely on the platform default encoding.
with open('./all_data/GITA_test.json', 'r', encoding='utf-8') as f:
    gita_dataset = json.load(f)

In [None]:
from transformers import MarianMTModel, MarianTokenizer

# Italian -> English MarianMT checkpoint used to translate the GITA stories.
# NOTE(review): the generic names `model` / `tokenizer` are reused later — the BERT
# featurization cell passes `tokenizer`, which would be this Marian tokenizer unless rebound.
model_name = 'Helsinki-NLP/opus-mt-it-en'
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

In [None]:
import copy


def translate_text(text, model, tokenizer, max_length=50):
    """Translate one string with a seq2seq translation model.

    Args:
        text: source-language string.
        model: model exposing ``generate``.
        tokenizer: matching tokenizer exposing ``encode`` / ``decode``.
        max_length: cap on the generated length; longer translations are cut off.

    Returns:
        The decoded first generated sequence, without special tokens.
    """
    inputs = tokenizer.encode(text, return_tensors="pt", padding=True)
    translated = model.generate(inputs, max_length=max_length)
    return tokenizer.decode(translated[0], skip_special_tokens=True)


def translate_sample(sample, model, tokenizer, max_length=50):
    """Return a translated deep copy of one story sample; `sample` is not modified.

    Translates 'sentences', 'location' and each comma-separated entry of 'objects',
    and renames matching object names inside 'states' (lists of [name, value] pairs)
    to their translations. Other fields (e.g. 'actor') are copied unchanged.
    """
    trans_sentences = [translate_text(sentence, model, tokenizer, max_length) for sentence in sample['sentences']]
    trans_location = translate_text(sample['location'], model, tokenizer, max_length)

    # Strip once here so the same names are used as dict keys and as lookup keys.
    objects = [obj.strip() for obj in sample['objects'].split(',') if obj.strip()]
    trans_objects = {obj: translate_text(obj, model, tokenizer, max_length) for obj in objects}

    trans_sample = copy.deepcopy(sample)
    trans_sample['sentences'] = trans_sentences
    trans_sample['location'] = trans_location
    trans_sample['objects'] = ', '.join(trans_objects[obj] for obj in objects)
    for state in trans_sample['states']:
        for entries in state.values():
            for entry in entries:
                name = entry[0].strip()
                if name in trans_objects:
                    entry[0] = trans_objects[name]
    return trans_sample

In [None]:
# Translate every GITA test story, keeping the original example ids as keys.
translated_gita_dataset = {
    'test': {
        key: translate_sample(sample, model, tokenizer)
        for key, sample in gita_dataset['test'].items()
    }
}

In [None]:
# Persist the translated split for reuse.
translated_path = './all_data/translated_gita.json'
with open(translated_path, 'w') as out_f:
    json.dump(translated_gita_dataset, out_f, indent=4)

In [None]:
# Pair each plausible GITA story "<id>" with its implausible cloze variant "<id>-C0",
# in the same two-story format as cloze_dataset_2s.
NUM_GITA_STORIES = 117  # plausible story ids 0..116 are expected in the test split
# NOTE(review): translated_gita_dataset (English) is built above, but this pairs the raw
# gita_dataset stories; switch the source here if the translations were intended.
gita_source = gita_dataset['test']

gita_dataset_2s = {'test': []}
for exid in range(NUM_GITA_STORIES):
    ex_plaus = gita_source[str(exid)]
    implaus_id = f'{exid}-C0'
    if implaus_id not in gita_source:
        continue
    ex_implaus = gita_source[implaus_id]

    # Some annotations nest the conflicting sentences one level deeper; flatten that case.
    # The emptiness check avoids an IndexError on an empty list.
    confl_sents = ex_implaus['confl_sents']
    if confl_sents and isinstance(confl_sents[0], list):
        confl_sents = confl_sents[0]

    # With several pairs, keep a single pair built from the first element of the first two.
    confl_pairs = ex_implaus['confl_pairs']
    if len(confl_pairs) > 1:
        confl_pairs = [[confl_pairs[0][0], confl_pairs[1][0]]]

    gita_dataset_2s['test'].append({
        'example_id': implaus_id,
        'stories': [ex_plaus, ex_implaus],
        'length': 5,
        'label': 0,  # the plausible story is always first
        'breakpoint': ex_implaus['breakpoint'],
        'confl_sents': confl_sents,
        'confl_pairs': confl_pairs,
    })

In [None]:
# Hand-corrected annotations for specific GITA examples. The indices are positions in
# gita_dataset_2s['test'] as built by the previous cell, so they are only valid if that
# cell ran exactly once immediately before this one.
# NOTE(review): not idempotent — running this cell twice pops two more examples.
gita_dataset_2s['test'][2]['confl_pairs'] = [[3,4]]
gita_dataset_2s['test'][10]['confl_pairs'] = [[1,2]]
# Pop the higher index first so index 69 still refers to the same example.
gita_dataset_2s['test'].pop(74)
gita_dataset_2s['test'].pop(69)

### Knowledge Graph Integration

In [18]:
# NOTE(review): nx, conceptnet5 and plt are not referenced in the visible cells;
# ConceptNet is queried over HTTP with requests instead.
import networkx as nx
import conceptnet5
import requests
import matplotlib.pyplot as plt
# Reloads the same spaCy pipeline as the synonym-replacement section; used by extract_entities.
nlp = spacy.load("en_core_web_sm")

In [20]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# COMET-BART checkpoint used by generate_comet_knowledge for commonsense inferences.
comet_checkpoint = "mismayil/comet-bart-ai2"
comet_tokenizer = AutoTokenizer.from_pretrained(comet_checkpoint)
comet_model = AutoModelForSeq2SeqLM.from_pretrained(comet_checkpoint)

In [24]:
from urllib.parse import quote

# spaCy entity labels kept as entities, in addition to noun-chunk heads.
# NOTE(review): "MISC" is not an OntoNotes label, so en_core_web_sm never produces it.
_ENTITY_LABELS = {"PERSON", "GPE", "ORG", "PRODUCT", "FAC", "MISC"}


def extract_entities(text):
    """Return candidate entities in `text`: selected named entities plus noun-chunk heads.

    Duplicates are removed while keeping first-occurrence order, so the result (and
    anything built from it) is the same on every run.
    """
    doc = nlp(text)
    entities = [ent.text for ent in doc.ents if ent.label_ in _ENTITY_LABELS]
    entities.extend(chunk.root.text for chunk in doc.noun_chunks if chunk.root.pos_ == "NOUN")
    return list(dict.fromkeys(entities))


def get_conceptnet_relations(entity, min_weight=1, timeout=10):
    """Query the public ConceptNet API for English edges touching `entity`.

    Args:
        entity: surface string; normalised to a ConceptNet term (lower-case,
            spaces -> underscores, URL-quoted) before querying.
        min_weight: keep only edges whose weight is strictly greater than this.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        List of (start_label, relation_label, end_label) tuples for edges whose two
        ends are both English concepts; [] if the request or JSON decoding fails.
    """
    term = quote(entity.strip().lower().replace(" ", "_"), safe="")
    url = f"http://api.conceptnet.io/c/en/{term}"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Error querying ConceptNet API: {e}")
        return []

    relations = []
    for edge in data.get("edges", []):
        if edge.get("weight", 0) <= min_weight:
            continue
        start, end = edge["start"], edge["end"]
        if start["@id"].startswith("/c/en/") and end["@id"].startswith("/c/en/"):
            relations.append((start["label"], edge["rel"]["label"], end["label"]))
    return relations


def generate_comet_knowledge(input_text, category="xAttr", max_length=50):
    """Generate one COMET inference for `input_text` under relation `category` (5-beam search).

    NOTE(review): COMET-ATOMIC-2020 checkpoints are commonly prompted as
    "<head> <relation> [GEN]"; confirm the prompt format this checkpoint expects.
    """
    prompt = f"{input_text} {category}"
    inputs = comet_tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = comet_model.generate(inputs["input_ids"], max_length=max_length, num_beams=5)
    return comet_tokenizer.decode(outputs[0], skip_special_tokens=True)


def augment_sentence_with_knowledge(sentence, use_comet=False):
    """Append ConceptNet (and optionally COMET) facts about the sentence's entities.

    Returns `sentence` followed by " (<facts>)"; when no facts are found the sentence
    is returned unchanged instead of with an empty "()".
    """
    conceptnet_info = []
    comet_info = []

    for entity in extract_entities(sentence):
        for start_label, _relation, end_label in get_conceptnet_relations(entity):
            # Edges come back in both directions; name the concept on the other end
            # rather than the queried entity itself.
            related = start_label if end_label.lower() == entity.lower() else end_label
            conceptnet_info.append(f"{entity} is related to {related}.")

        if use_comet:
            xAttr_knowledge = generate_comet_knowledge(entity, category="xAttr")
            xEffect_knowledge = generate_comet_knowledge(entity, category="xEffect")
            oEffect_knowledge = generate_comet_knowledge(entity, category="oEffect")
            comet_info.append(f"For {entity}: xAttr = {xAttr_knowledge}, xEffect = {xEffect_knowledge}, oEffect = {oEffect_knowledge}.")

    # Order-preserving de-duplication; a set would make the ordering differ between runs.
    unique_info = list(dict.fromkeys(conceptnet_info + comet_info))
    if not unique_info:
        return sentence
    return sentence + " (" + " ".join(unique_info) + ")"


def augment_sentences_with_knowledge(sentences, use_comet=False):
    """Apply augment_sentence_with_knowledge to each sentence, preserving order."""
    return [augment_sentence_with_knowledge(sentence, use_comet) for sentence in sentences]

In [25]:
def add_description(dataset, partitions=('train',), use_comet=True):
    """Return a copy of `dataset` whose story sentences carry appended knowledge text.

    Args:
        dataset: dict mapping partition name -> list of examples; each example has a
            'stories' list whose items have a 'sentences' list.
        partitions: partitions to augment (by default only 'train').
        use_comet: forwarded to augment_sentences_with_knowledge.

    Returns:
        A deep copy of `dataset` with augmented sentences; the input is not modified
        (a shallow dict.copy() would have shared, and mutated, the nested stories).
    """
    import copy

    newdataset = copy.deepcopy(dataset)
    for part in partitions:
        for example in newdataset[part]:
            for story in example['stories']:
                story['sentences'] = augment_sentences_with_knowledge(story['sentences'], use_comet=use_comet)
    return newdataset

In [None]:
# NOTE(review): `kg_dataset` is not defined in any visible earlier cell, so this line raises
# NameError on a fresh kernel. add_description expects a dict with a 'train' list of
# examples that have 'stories' — presumably one of the two-story datasets; confirm and
# initialise it explicitly before this call.
kg_dataset = add_description(kg_dataset)

In [27]:
import json

# Persist the knowledge-augmented dataset.
kg_output_path = "./all_data/kg_aug_data.json"
with open(kg_output_path, "w") as out_file:
    json.dump(kg_dataset, out_file, indent=4)

### Process data for model training

In [None]:
# NOTE(review): this rebinds get_tiered_data / balance_labels to the prepro_for_aug versions,
# shadowing the www.dataset.prepro ones imported earlier; re-running earlier cells after
# this one will silently use these versions.
from www.dataset.prepro_for_aug import get_tiered_data, balance_labels
from www.dataset.featurize_for_aug import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter

# Alias, not a copy: if get_tiered_data mutates its input in place, gita_dataset_2s changes too.
gita_tiered_dataset = gita_dataset_2s
gita_tiered_dataset = get_tiered_data(gita_tiered_dataset)

In [None]:
# Training = original tiered train split plus the tiered GITA examples;
# dev and test are the original splits, unchanged.
aug_tiered_dataset = {
    'train': tiered_dataset['train'] + gita_tiered_dataset['test'],
    'dev': tiered_dataset['dev'],
    'test': tiered_dataset['test'],
}

In [None]:
# NOTE(review): train_spans is not read in this cell.
train_spans = False
seq_length = 16 # Max sequence length to pad to

# NOTE(review): `tokenizer` is not rebound in any visible cell after the translation
# section, so here it is the MarianTokenizer for Helsinki-NLP/opus-mt-it-en.
# add_bert_features_tiered presumably expects the tokenizer of the model being trained
# (e.g. a BERT-family tokenizer) — confirm and load it explicitly before this call.
# Also not idempotent: aug_tiered_dataset is overwritten with its featurized version.
aug_tiered_dataset = add_bert_features_tiered(aug_tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

# Convert each partition to a tensor dataset; max_story_length presumably bounds the
# number of sentences per story — TODO confirm against get_tensor_dataset_tiered.
aug_tiered_tensor_dataset = {}
max_story_length = 7
for p in aug_tiered_dataset:
    aug_tiered_tensor_dataset[p] = get_tensor_dataset_tiered(aug_tiered_dataset[p], max_story_length, add_segment_ids = True)