## Global Config

In [15]:

# download configuratiins
download_from_remote = True
load_from_local = True
save_to_local = True
local_dataset = "./dataset/fever_dataset"
local_adversarial = "./dataset/adversarial_dataset"
%mkdir -p {local_dataset}
%mkdir -p {local_adversarial}

# data analytics
print_statistics = False
save_info = False
info = "./info/"
%mkdir -p {info}

# augmentation
syn_hyp_augment = False

## Imports

In [11]:
# general imports
import os
import re
import string
import random
from tqdm import tqdm

# visualization and statistics
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

# data manipulation
import numpy as np
import pandas as pd
from typing import Dict
from datasets import(
  Dataset, 
  load_dataset, 
  load_from_disk, 
  concatenate_datasets,
) 

# nltk
import nltk
from nltk.corpus import wordnet as wn
nltk.download('wordnet')
nltk.download('wordnet')

# spacy
import spacy 
# Load the SpaCy model
nlp_spacy = spacy.load("en_core_web_sm")

# huggingface
from transformers import DataCollatorWithPadding
from transformers import (
    AutoTokenizer,
    set_seed,
)

# set seeds
set_seed(42)
np.random.seed(42)
random.seed(42)

[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Data

In [3]:
# download data or just load from local 

download_from_remote = not(os.path.exists(local_dataset) and os.path.exists(local_adversarial))

if download_from_remote:
    print("Downloading data from remote repository")

    # load chunk of FEVER datyaset
    fever_dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli", trust_remote_code=True)

    # load adversarial 
    adversarial_testset = load_dataset("iperbole/adversarial_fever_nli", trust_remote_code=True)

    # structure of the dataset
    print(fever_dataset)

    if save_to_local:
        print(f"Save data in local {local_dataset}")
        fever_dataset.save_to_disk(local_dataset)
        print(f"Save adversarial dataset in {local_adversarial}")
        adversarial_testset.save_to_disk(local_adversarial)

elif load_from_local:
    print(f"Load data from local repository")
    fever_dataset = load_from_disk(local_dataset)
    adversarial_testset = load_from_disk(local_adversarial)
    
print("Done!")

print(f"Train ssplit length: {len(fever_dataset['train'])}")
max_id = max(fever_dataset['train']['id'])
print(f"Max id in train split: {max_id}")

Load data from local repository
Done!
Train ssplit length: 51086
Max id in train split: 99998


In [4]:
fever_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 51086
    })
    validation: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2288
    })
    test: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2287
    })
})

## Exploration

In [4]:
#@title Tokenization function for datapoint visualization

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

label_map = {
    'ENTAILMENT': 0,
    'NEUTRAL': 1,
    'CONTRADICTION': 2,
    'NOT ENOUGH INFO': None
}

def tokenize_function(examples):
    examples['label'] = [label_map[label] for label in examples['label']]
    return tokenizer(examples['premise'], examples['hypothesis'], padding=True, truncation=True)

In [5]:
#@title Exploration utils

def pretty_print_dict(d, indent=0):
    for key, value in d.items():
        print(' ' * indent + str(key) + ':', end=' ')
        if isinstance(value, dict):
            print()
            pretty_print_dict(value, indent + 4)
        else:
            print(value)


def plot_labels_distribution(target_set, title=''):

    labels = [
        'ENTAILMENT',
        'NEUTRAL',
        'CONTRADICTION',
    ]

    label_counts = {}
    for label in target_set['label']:
        if label not in label_counts:
            label_counts[label] = 0
        label_counts[label] += 1

    plt.bar(labels, label_counts.values())
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.title(f'Distribution of labels in {title}')
    plt.show()
    print()
        

def plot_lengths_distribution(target_set, title='', compare_length=False):
    # Extract premises and hypotheses
    premises = [item['premise'] for item in target_set]
    hypotheses = [item['hypothesis'] for item in target_set]

    # Compute lengths
    premise_lengths = [len(premise.split()) for premise in premises]
    hypothesis_lengths = [len(hypothesis.split()) for hypothesis in hypotheses]

    # Plotting length distributions
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    sns.histplot(premise_lengths, bins=50, kde=True, ax=axes[0], color='blue', log_scale=(False, True))
    axes[0].set_title('Premise Length Distribution')
    axes[0].set_xlabel('Number of Words')
    axes[0].set_ylabel('Frequency')

    sns.histplot(hypothesis_lengths, bins=50, kde=True, ax=axes[1], color='green', log_scale=(False, True))
    axes[1].set_title('Hypothesis Length Distribution')
    axes[1].set_xlabel('Number of Words')
    axes[1].set_ylabel('Frequency')

    fig.suptitle(f'Premise and Hypothesis Length Distribution in {title}')
    plt.show()

    if compare_length:
        # Plot premise vs hypothesis length scatter plot
        plt.figure(figsize=(8, 6))
        plt.scatter(premise_lengths, hypothesis_lengths, alpha=0.5, s=1)
        plt.title('Premise vs Hypothesis Length')
        plt.xlabel('Premise Length')
        plt.ylabel('Hypothesis Length')
        plt.yscale('log')
        plt.xscale('log')
        plt.grid(True, which="both", ls="--")
        plt.show()


def plot_vocb_distribution():
    # Tokenize words
    # premise_words = [word for text in premises for word in word_tokenize(text)]
    # hypothesis_words = [word for text in hypotheses for word in word_tokenize(text)]

    # # Compute word frequencies
    # premise_word_freq = Counter(premise_words)
    # hypothesis_word_freq = Counter(hypothesis_words)

    # # Plotting the most common words
    # premise_common_words = premise_word_freq.most_common(20)
    # hypothesis_common_words = hypothesis_word_freq.most_common(20)

    # fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    # axes[0].barh([word[0] for word in premise_common_words], [word[1] for word in premise_common_words])
    # axes[0].set_title('Premise Common Words')
    # axes[1].barh([word[0] for word in hypothesis_common_words], [word[1] for word in hypothesis_common_words])
    # axes[1].set_title('Hypothesis Common Words')
    # plt.show()
    ...

# Prepare data for plotting
# lengths_and_labels = [(length, item['label']) for length, item in zip(premise_lengths, train_dataset)]
# premise_df = pd.DataFrame(lengths_and_labels, columns=['length', 'label'])

# lengths_and_labels = [(length, item['label']) for length, item in zip(hypothesis_lengths, train_dataset)]
# hypothesis_df = pd.DataFrame(lengths_and_labels, columns=['length', 'label'])

# # Plot premise lengths across classes
# plt.figure(figsize=(12, 6))
# sns.boxplot(x='label', y='length', data=premise_df)
# plt.title('Premise Length Distribution Across Classes')
# plt.show()

# # Plot hypothesis lengths across classes
# plt.figure(figsize=(12, 6))
# sns.boxplot(x='label', y='length', data=hypothesis_df)
# plt.title('Hypothesis Length Distribution Across Classes')
# plt.show()

# from sklearn.feature_extraction.text import TfidfVectorizer
# from sklearn.metrics.pairwise import cosine_similarity
# import numpy as np

# # Create combined list of premises and hypotheses
# texts = premises + hypotheses

# # Compute TF-IDF vectors
# tfidf = TfidfVectorizer().fit_transform(texts)

# # Compute cosine similarity between each premise and its corresponding hypothesis
# cosine_sim = [cosine_similarity(tfidf[i], tfidf[i + len(premises)])[0][0] for i in range(len(premises))]

# # Plot cosine similarity distribution
# plt.figure(figsize=(8, 6))
# plt.hist(cosine_sim, bins=50, alpha=0.5, color='purple')
# plt.title('Cosine Similarity Distribution Between Premise and Hypothesis')
# plt.xlabel('Cosine Similarity')
# plt.ylabel('Frequency')
# plt.show()


In [6]:
# statistics about the regular dataset
if print_statistics:
  plot_labels_distribution(fever_dataset['train'], title="Fever Train Set")
  plot_lengths_distribution(fever_dataset['train'], title="Fever Train Set", compare_length=True)

In [118]:
#@title Dataset structure

# print(fever_dataset['train'][0]['srl'].keys())
# print(fever_dataset['train'][0]['srl']['hypothesis']['tokens'])
# print(fever_dataset['train'][0]['srl']['hypothesis']['annotations'])

# for token in fever_dataset['train'][0]['srl']['premise']['tokens']:
#     print(token)

# for annotation in fever_dataset['train'][0]['srl']['premise']['annotations']:
#     for key, value in annotation.items():
#         print(f"{key}:\t{value}")
#     print()


# WSD exploration

# sample_range = 10 #len(fever_dataset['train'])
# loop = tqdm(range(sample_range))

# pos_info = dict()

# for i in loop:

#   data =  fever_dataset['train'][i]
#   sample_id = data['id']

#   hypothesis = data['hypothesis']
#   hyp_wsd = data['wsd']['hypothesis']

#   for hyp_wsd_dict in hyp_wsd:
    
#     pos  = hyp_wsd_dict['pos']
#     pos_info[pos] = 1 if pos not in pos_info else pos_info[pos] + 1

# pos_info

# fever_dataset['train'][0:2]['srl']

dict_keys(['premise', 'hypothesis'])
[{'index': 0, 'rawText': 'Roman'}, {'index': 1, 'rawText': 'Atwood'}, {'index': 2, 'rawText': 'is'}, {'index': 3, 'rawText': 'a'}, {'index': 4, 'rawText': 'content'}, {'index': 5, 'rawText': 'creator'}, {'index': 6, 'rawText': '.'}]
[{'tokenIndex': 2, 'verbatlas': {'frameName': 'COPULA', 'roles': [{'role': 'Theme', 'score': 1.0, 'span': [0, 2]}, {'role': 'Attribute', 'score': 1.0, 'span': [3, 6]}]}, 'englishPropbank': {'frameName': 'be.01', 'roles': [{'role': 'ARG1', 'score': 1.0, 'span': [0, 2]}, {'role': 'ARG2', 'score': 1.0, 'span': [3, 6]}]}}]


In [7]:
#@title SRL exploration: collect frame-names and roles 

def get_srl_info(dataset):

  sample_range = len(dataset)
  loop = tqdm(range(sample_range))

  relevant_info = dict()

  set_of_va_frames = dict() # verb atlas frames
  set_of_pb_frames = dict() # prop bank frames

  verbs_freqs = dict()

  set_of_va_roles = set() # verb atlas frames
  set_of_pb_roles = set() # prop bank frames

  for i in loop:

    data =  dataset[i]
    sample_id = data['id']
    
    tokens = data['srl']['premise']['tokens']
    annotations = data['srl']['premise']['annotations']

    for annotation in annotations:

      token_index = annotation['tokenIndex']
      verb = tokens[token_index]['rawText']
      verbs_freqs[verb] =  1 if verb not in  verbs_freqs else verbs_freqs[verb] + 1

      vb_frame = annotation['verbatlas']['frameName']
      pb_frame = annotation['englishPropbank']['frameName']

      set_of_va_frames[vb_frame] = 1 if vb_frame not in set_of_va_frames else set_of_va_frames[vb_frame] + 1 
      set_of_pb_frames[pb_frame] = 1 if pb_frame not in set_of_pb_frames else set_of_pb_frames[pb_frame] + 1 

      va_roles = annotation['verbatlas']['roles']
      pb_roles = annotation['englishPropbank']['roles']

      for role in va_roles:
        set_of_va_roles.add(role['role'])

      for role in pb_roles:
        set_of_pb_roles.add(role['role'])

  return set_of_va_frames, set_of_pb_frames, set_of_va_roles, set_of_pb_roles, verbs_freqs
  

In [8]:
if save_info:

  set_of_va_frames, set_of_pb_frames, set_of_va_roles, set_of_pb_roles, verbs_freqs = get_srl_info(fever_dataset['train'])

  verbs_freqs = sorted(verbs_freqs.items(), key=lambda item: item[1], reverse=True)
  set_of_pb_frames = sorted(set_of_pb_frames.items(), key=lambda item: item[1], reverse=True)
  set_of_va_frames = sorted(set_of_va_frames.items(), key=lambda item: item[1], reverse=True)

  print(f"Number of Verb Atlas frames: {len(set_of_va_frames)}")
  print(f"Number of Verb Atlas roles: {len(set_of_va_roles)}")

  print(f"Number of Propbank frames: {len(set_of_pb_frames)}")
  print(f"Number of Propbank roles: {len(set_of_pb_roles)}")

  with open(info + "va_roles.txt", "w") as va_roles:
    print(f"Saving all Verb Atlas roles of the dataset into {info + 'va_roles.txt'}")
    for elem in set_of_va_roles: va_roles.write(f"{elem}\n")

  with open(info + "va_frames.txt", "w") as va_frames:
    print(f"Saving all Verb Atlas frames of the dataset into {info + 'va_frames.txt'}")
    for elem in set_of_va_frames: va_frames.write(f"{elem}\n")


  with open(info + "pb_roles.txt", "w") as pb_roles:
    print(f"Saving all Propbank roles of the dataset into {info + 'pb_roles.txt'}")
    for elem in set_of_pb_roles: pb_roles.write(f"{elem}\n")


  with open(info + "pb_frames.txt", "w") as pb_frames:
    print(f"Saving all Propbank frames of the dataset into {info + 'pb_frames.txt'}")
    for elem in set_of_pb_frames: pb_frames.write(f"{elem}\n")

  with open(info + "verbs_freqs.txt", "w") as verbs:
    print(f"Saving all verbs frequencies count into {info + 'verbs_freqs.txt'}")
    for elem in set_of_pb_frames: verbs.write(f"{elem}\n")


## Augmentation
The following prompt can be used to correct the grammar of a modified hypotesis so if just a not is added the phrase is transformed in negative form:

"Correct the grammar in the following inputs, rephrase the input if necessary to make them more accurate. Provide just the correct version, no explanation."

Ask for:
symmetric relation --> Marry
antisymmetric relation --> Kill

born
sell


In [14]:
#@title Augmentation utils

def join_strings_smartly(words):
    """ Joins a list of words smartly:
    - Adds spaces between words when appropriate.
    - Avoids adding spaces before punctuation.
    """
    punctuation = {'.', ',', ';', ':', '!', '?',')'}
    result = words[0]
    prev = result

    for word in words[1:]:
      if word in punctuation or \
        ("'" in prev) or \
        word.startswith("'") or \
        ("." in prev and "." in word) or \
        ("(" in prev) :
        # add word without space before
        result += word
      else:
        # add with space before
        result += " " + word
      # keep track of previous word  
      prev = word

    return result

def get_synset_from_id(synset_id):
    if synset_id == 'O':
        return None
    try:
        offset = int(''.join(filter(str.isdigit, synset_id)))
        pos = synset_id[-1]
        synset = wn.synset_from_pos_and_offset(pos, offset)
        return synset
    except:
        print("exception")
        return None
    

def get_related_word(synset, pos): 

    info = dict()

    # Map POS tags to WordNet POS tags
    pos_map = {
        'NOUN': wn.NOUN,
        'VERB': wn.VERB,
        'ADJ': wn.ADJ,
        'ADV': wn.ADV
    }
    
    if pos not in pos_map:
        return None 
    
    # get hypernyms
    hypernyms = synset.hypernyms()
    if not hypernyms:
        return None
    hypernyms = synset.hypernyms()
    hypernym_words = set()
    for hypernym in hypernyms:
        hypernym_words.update(hypernym.lemma_names())
    info['hypernyms'] = list(hypernym_words)

    # get synonyms 
    synonyms = synset.lemmas()
    if not synonyms:
        return None
    info['synonyms'] = [synonym.name() for synonym in synonyms if synonym.name() != synset.lemmas()[0].name()]

    return info


def extract_names(wsd_data):
    names = []
    current_name = []

    for entry in wsd_data:
        if entry['pos'] == 'PROPN':
            current_name.append(entry['text'])
        else:
            if current_name:
                names.append(' '.join(current_name))
                current_name = []

    # catch any remaining name at the end
    if current_name:
        names.append(' '.join(current_name))
    
    return names


def extract_partial_match_name(text, name_list):
    # tokenize the input text
    words = re.findall(r'\b\w+\b', text)
    names = set()
    
    # tterate through each name in the name list
    for name in name_list:
        name_parts = name.split()

        # check for partial match in the tokenized words
        for i in range(len(words) - len(name_parts) + 1):
            if words[i:i + len(name_parts)] == name_parts[:len(words[i:i + len(name_parts)])]:
              names.add(name)
              
    if names: return names
    return None

In [81]:
#@title Augmentation by Synonyms and Hypernyms

if syn_hyp_augment:
  sample_range = len(fever_dataset['train'])
  loop = tqdm(range(sample_range))

  new_samples = {
      'id': [],
      'premise': [],
      'hypothesis': [],
      'label': [],
      'wsd': [None for i in range(sample_range)],
      'srl': [None for i in range(sample_range)]
  }

  syn_dict = dict()
  syn_idx = 0
  progressive_id = int(max_id)

  for i in loop:

    data =  fever_dataset['train'][i]
    sample_id = data['id']

    premise = data['premise']
    label = data['label']
    hyp_wsd = data['wsd']['hypothesis']

    new_samples['id'].append(str(progressive_id))
    new_samples['premise'].append(premise)
    new_samples['label'].append(label)
    progressive_id += 1

    hypothesis = data['hypothesis']
    new_hypotesys = []

    for hyp_wsd_dict in hyp_wsd:

      word = hyp_wsd_dict['text']
      pos = hyp_wsd_dict['pos']
      offset = hyp_wsd_dict['wnSynsetOffset']
      synset = get_synset_from_id(offset)

      related_words = None
      if synset: related_words = get_related_word(synset, pos)

      if related_words:
        hypernyms = related_words['hypernyms']
        synonyms = related_words['synonyms']

        if not synonyms : 
          new_hypotesys.append(word)
          continue

        synonym = synonyms[syn_idx % (len(synonyms))]

        if synonym in syn_dict and syn_dict[synonym] > 10:
          syn_idx += 1
          synonym = synonyms[syn_idx % (len(synonyms))]

        syn_dict[synonym] = 1 if  synonym not in syn_dict else syn_dict[synonym] + 1
        # if synonym in syn_dict 

        syn_idx += 1

        if '_' in synonym: synonym = synonym.replace('_', ' ')
        new_hypotesys.append(synonym)

      else: new_hypotesys.append(word)


    new_hypotesys = join_strings_smartly(new_hypotesys)
    new_samples['hypothesis'].append(new_hypotesys)
    # print(f"new: {new_hypotesys} \n")

  # print(new_samples)

100%|██████████| 10/10 [00:00<00:00, 301.98it/s]


In [85]:
#@title Add new samples to the original dataset
 
augmentation = Dataset.from_dict(new_samples)
print(f"Augmentation type: {type(augmentation)}")

augmented_dataset = concatenate_datasets([fever_dataset['train'], augmentation])
print(f"Augmented dataset type: {type(augmented_dataset)}")
print(f"Train split length: {len(fever_dataset['train'])}")
print(f"Train split augmentated: {len(augmented_dataset)}")

Augmentation type: <class 'datasets.arrow_dataset.Dataset'>
Augmented dataset type: <class 'datasets.arrow_dataset.Dataset'>
Train split length: 51086
Train split augmentated: 51096


In [91]:
sample_range = 10 
# sample_range = len(fever_dataset['train'])

locations = set()
# loop = tqdm(range(sample_range))

# for i in loop:
for i in range(sample_range):

    data = fever_dataset['train'][i]

    # id
    id =  data['id']

    # premise
    sentence = data['premise']
    hypothesis = data['hypothesis']

    # wsd
    wsd =  data['wsd']

    # srl
    possible_locations = []
    tokens = data['srl']['premise']['tokens']
    annotations = data['srl']['premise']['annotations']

    # print(tokens)
    # print(annotations)

    for annotation in annotations:
        # print(annotation['verbatlas']['roles'])
        token_idx = annotation['tokenIndex'] 
        roles = annotation['verbatlas']['roles']
        # index = annotation['verbatlas']

        location = False
        agent = False

        # print(index)
        text_location = ""
        for role in roles:
            if role['role'] == 'Location':
                span = role['span']

                for item in tokens:
                    if span[0] <= item['index'] <= span[1]:
                        text_location += item['rawText'] + " "

                # filtered_texts = [item['rawText'] for item in tokens if span[0] <= item['index'] <= span[1]]
                # text_location = " ".join(filtered_texts)
                # filtered_words = [word for word in filtered_texts if any(char.isupper() for char in word)]

                if text_location:
                    # possible_locations.append(text_location)
                    # print(f"sample id: {id}")
                    # print(f"text: {text_location}")

                    entities =  nlp_spacy(text_location).ents
                    # print(f"entities: {entities}")
                    if entities:
                        found_target = False
                        for ent in entities:
                            if ent.label_ == "GPE":
                                found_target = True
                                locations.add(ent.text)

                                # print info related to location
                                print(sentence)
                                print(tokens)
                                print(annotation['verbatlas'])
                                print(hypothesis)
                                # print(f"target entity: {ent.text}")
                            # else:
                            #     print("no target ")
                        # if not found_target :  print("no location target ")
                    else:
                        # print("no target ")
                        ...
                    
                    # print("----------------------------------------- \n")
    

John Wick: Chapter 2 . The film premiered in Los Angeles on January 30 , 2017 , and was theatrically released in the United States on February 10 , 2017 .
[{'index': 0, 'rawText': 'John'}, {'index': 1, 'rawText': 'Wick'}, {'index': 2, 'rawText': ':'}, {'index': 3, 'rawText': 'Chapter'}, {'index': 4, 'rawText': '2'}, {'index': 5, 'rawText': '.'}, {'index': 6, 'rawText': 'The'}, {'index': 7, 'rawText': 'film'}, {'index': 8, 'rawText': 'premiered'}, {'index': 9, 'rawText': 'in'}, {'index': 10, 'rawText': 'Los'}, {'index': 11, 'rawText': 'Angeles'}, {'index': 12, 'rawText': 'on'}, {'index': 13, 'rawText': 'January'}, {'index': 14, 'rawText': '30'}, {'index': 15, 'rawText': ','}, {'index': 16, 'rawText': '2017'}, {'index': 17, 'rawText': ','}, {'index': 18, 'rawText': 'and'}, {'index': 19, 'rawText': 'was'}, {'index': 20, 'rawText': 'theatrically'}, {'index': 21, 'rawText': 'released'}, {'index': 22, 'rawText': 'in'}, {'index': 23, 'rawText': 'the'}, {'index': 24, 'rawText': 'United'}, {'in

In [17]:
# add new samples to the dataset
filtered_dataset = fever_dataset['train'].filter(lambda example: 'is married to ' in example['premise'].lower()) # or 'is a' in example['hypothesis'].lower())

Filter: 100%|██████████| 51086/51086 [01:12<00:00, 700.84 examples/s]


In [18]:
filtered_dataset['premise']

['Ad-Rock . He is married to musician and feminist activist Kathleen Hanna .',
 'Claire Danes . She is married to actor Hugh Dancy , with whom she has one child .',
 "Claire Danes . She is married to actor Hugh Dancy , with whom she has one child . Hugh Michael Horace Dancy ( born 19 June 1975 ) is an English actor and model . He is best known for his roles as Will Graham in the television series Hannibal for which he was nominated for two Critics ' Choice Television Award for Best Actor in a Drama Series and as Prince Charmont in Ella Enchanted . In 2006 , he was nominated for a Primetime Emmy Award for his portrayal of the Earl of Essex in the Channel 4 miniseries Elizabeth I.",
 "Claire Danes . She is married to actor Hugh Dancy , with whom she has one child . Hugh Dancy . He is best known for his roles as Will Graham in the television series Hannibal for which he was nominated for two Critics ' Choice Television Award for Best Actor in a Drama Series and as Prince Charmont in Ella 

In [19]:
len(filtered_dataset['premise'])

63

In [112]:
filtered_dataset[0]['srl']['premise']

{'tokens': [{'index': 0, 'rawText': 'Ad'},
  {'index': 1, 'rawText': '-'},
  {'index': 2, 'rawText': 'Rock'},
  {'index': 3, 'rawText': '.'},
  {'index': 4, 'rawText': 'He'},
  {'index': 5, 'rawText': 'is'},
  {'index': 6, 'rawText': 'married'},
  {'index': 7, 'rawText': 'to'},
  {'index': 8, 'rawText': 'musician'},
  {'index': 9, 'rawText': 'and'},
  {'index': 10, 'rawText': 'feminist'},
  {'index': 11, 'rawText': 'activist'},
  {'index': 12, 'rawText': 'Kathleen'},
  {'index': 13, 'rawText': 'Hanna'},
  {'index': 14, 'rawText': '.'}],
 'annotations': [{'tokenIndex': 1,
   'verbatlas': {'frameName': 'AUXILIARY', 'roles': []},
   'englishPropbank': {'frameName': 'be.03', 'roles': []}},
  {'tokenIndex': 5,
   'verbatlas': {'frameName': 'COPULA',
    'roles': [{'role': 'Theme', 'score': 1.0, 'span': [4, 5]},
     {'role': 'Attribute', 'score': 1.0, 'span': [6, 14]}]},
   'englishPropbank': {'frameName': 'be.01',
    'roles': [{'role': 'ARG1', 'score': 1.0, 'span': [4, 5]},
     {'role': 

In [13]:
# dataset = filtered_dataset
dataset = fever_dataset['train']

sample_range = len(dataset)
loop = tqdm(range(sample_range))

relational_graph = dict()
old_max_span = 1_000_000

for i in loop:

  data =  dataset[i]
  sample_id = data['id']
  premise = data['premise']
  hypothesis = data['hypothesis']

  # print(f"premise: {premise}")
  
  # srl info
  tokens = data['srl']['premise']['tokens']
  annotations = data['srl']['premise']['annotations']

  # wsd info
  wsd = data['wsd']['premise']
  proper_nouns = extract_names(wsd)
  # print(proper_nouns)

  for annotation in annotations:

    token_index = annotation['tokenIndex']
    verb = tokens[token_index]['rawText']

    # frame = annotation['verbatlas']['frameName']
    frame = annotation['englishPropbank']['frameName']
    # relational_graph[pb_frame]

    # if frame not in  ['ALLY_ASSOCIATE_MARRY']: continue
    if frame not in  ['marry.01']: continue
    if frame not in relational_graph: relational_graph[frame] = dict()
    print(premise)

    # roles = annotation['verbatlas']['roles']
    roles = annotation['englishPropbank']['roles']

    try:
      span_begin = roles[0]['span'][0]
      span_end = roles[-1]['span'][1]
    except:
      # print(roles)
      continue

    if span_begin > 0 and span_begin < old_max_span: 
      span_begin = 0
      old_max_span = span_end

    sentence = ""
    try: 
      sentence = " ".join([tokens[index]['rawText']  for index in range(span_begin, span_end + 1)]) # if tokens[index]['rawText'] in names])
    except:
      sentence = " ".join([tokens[index]['rawText'] for index in range(span_begin, span_end)]) #if tokens[index]['rawText'] in names])

    # print(sentence)
    target = extract_partial_match_name(sentence, proper_nouns)
    # print(target)
    if not target: continue

    for elem in target:
      if elem not in relational_graph[frame]: relational_graph[frame][elem] = [sample_id]
      else: relational_graph[frame][elem].append(sample_id)




    # srl_roles = []
    # for role in roles:
    #   # print(role)

    #   span = role['span']
    #   role_tag = role['role']

    #   # print(role_tag)
    #   if role_tag not in accetable_roles : continue

    #   curr_roles = ""
    #   try: 
    #     curr_roles = " ".join([tokens[index]['rawText']  for index in range(span[0], span[1] + 1)]) # if tokens[index]['rawText'] in names])
    #   except:
    #     curr_roles = " ".join([tokens[index]['rawText'] for index in range(span[0], span[1])]) #if tokens[index]['rawText'] in names])

    #   # print(names)
    #   # print(curr_roles)
    #   target = extract_partial_match_name(curr_roles, proper_nouns)

    #   if target: 
    #     print(premise)
    #     print(target)
    #     print(role_tag)

    #     if 'Agent' in role_tag:
    #       if target not in relational_graph:
    #         relational_graph[target] = [(frame, 1)]
        
    #     else:
    #       if frame not in relational_graph:
    #         relational_graph[frame] = [(target, sample_id)]

    #     # print(names)
    #     # print(wsd)
    #     # print(tokens)
    #     print(annotation['verbatlas'])

    #     # print(target)
    #     # srl_roles.append((curr_roles, role_tag))
    #   else: continue


      # if frame not in relational_graph:
      #   relational_graph[frame] = [(curr_roles, role_tag, sample_id)]
      # else:
      #   relational_graph[frame].append((curr_roles, role_tag, sample_id))

  0%|          | 0/51086 [00:00<?, ?it/s]

  0%|          | 0/51086 [00:00<?, ?it/s]


NameError: name 'extract_names' is not defined

In [None]:
pretty_print_dict(relational_graph)

In [116]:
def graph_query(graph, frame, given_id, query):
   if query in 'CONTRADDICTION':
      return [name for name, ids in graph[frame].items() if given_id not in ids]
   
   elif query in 'ENTAILMENT':
      return [name for name, ids in graph[frame].items() if given_id in ids]

   else:
      print("Error: invalid query")


id = '93454'
filtered_values = graph_query(
   graph=relational_graph, 
   frame='marry.01', 
   given_id=id,
   query='C'
)
print(filtered_values)
  

['Rachel', 'Hathaway', 'Anne Hathaway', 'Wyatt Earp', 'Urilla Sutherland Earp', 'Earp', 'Millionaire', 'Peggy Sue Got', 'Paul McCartney', 'Children', 'Beatles', 'Jennifer Garner', 'Garner', 'Scott Foley', 'Ellen Pompeo', 'Chris Ivery', 'Pompeo', 'Harald', 'Harald V', 'Sonja Haraldsen', 'Norway', 'Germany', 'King Henry', 'Holy Roman Emperor Henry V', 'England', 'Empress Matilda', 'Brigham Young University', 'Mitt Romney', 'Ann Romney', 'BYU', 'Ellen DeGeneres', 'Portia de Rossi', 'Brad Pitt', 'Jennifer Aniston', 'Justin Theroux', 'Elizabeth Swann', 'World', 'Lady McCartney', 'Linda Louise', 'September', 'April', 'Ben Affleck', 'Joan Crawford', 'Crawford', 'George VI', 'Lyon', 'Elizabeth', 'Lady Elizabeth Bowes', 'Greece', 'Wales', 'Prince', 'Charles', 'Edinburgh', 'Duke', 'Philip', 'Elizabeth II', 'Denmark', 'Christie Brinkley', 'Billy Joel', 'Brinkley', 'Fearless Vampire Killers', 'August', 'Sharon Marie Tate Polanski', 'Tate', 'Roman Polanski', 'January', 'Peggy Sue', 'Chairman Alfred

In [31]:
len(relational_graph['ALLY_ASSOCIATE_MARRY'])

8