## Global Config

In [2]:
# download configuratiins
download_from_remote = True
load_from_local = True
save_to_local = True
local_dataset = "./dataset/fever_dataset"
local_adversarial = "./dataset/adversarial_dataset"
%mkdir -p {local_dataset}
%mkdir -p {local_adversarial}

# data analytics
test_tokenizer = False
print_statistics = False

# dataset exploration
general_structure = True
show_ids = False
show_labels = False
wsd_exploration = False
srl_exploration = True

# semantic roles
save_info = False
info = "./info/"
%mkdir -p {info}


# augmentation
syn_hyp_augment = True
save_graph = False
load_graph = True
graph = "./relational_graph"
%mkdir -p {graph}

## Imports

In [3]:
# general imports
import os
import re
import json
import string
import random
from tqdm import tqdm

# visualization and statistics
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

# data manipulation
import numpy as np
import pandas as pd
from typing import Dict
from datasets import(
  Dataset, 
  load_dataset, 
  load_from_disk, 
  concatenate_datasets,
) 

# nltk
import nltk
from nltk.corpus import wordnet as wn
nltk.download('wordnet')
nltk.download('wordnet')

# spacy
import spacy 
# Load the SpaCy model
nlp_spacy = spacy.load("en_core_web_sm")

# huggingface
from transformers import DataCollatorWithPadding
from transformers import (
    AutoTokenizer,
    set_seed,
)

# set seeds
set_seed(42)
np.random.seed(42)
random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Data

In [5]:
#@title Download data or just load from local 

download_from_remote = not(os.path.exists(local_dataset) and os.path.exists(local_adversarial))

if download_from_remote:
    print("Downloading data from remote repository")

    # load chunk of FEVER datyaset
    fever_dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli", trust_remote_code=True)

    # load adversarial 
    adversarial_testset = load_dataset("iperbole/adversarial_fever_nli", trust_remote_code=True)

    # structure of the dataset
    print(fever_dataset)

    if save_to_local:
        print(f"Save data in local {local_dataset}")
        fever_dataset.save_to_disk(local_dataset)
        print(f"Save adversarial dataset in {local_adversarial}")
        adversarial_testset.save_to_disk(local_adversarial)

elif load_from_local:
    print(f"Load data from local repository")
    fever_dataset = load_from_disk(local_dataset)
    adversarial_testset = load_from_disk(local_adversarial)
    
print("Done!")

print(f"Train ssplit length: {len(fever_dataset['train'])}")
max_id = max(fever_dataset['train']['id'])
print(f"Max id in train split: {max_id}")

Load data from local repository
Done!
Train ssplit length: 51086
Max id in train split: 99998


## Exploration

In [6]:
#@title Tokenization function for datapoint visualization

if test_tokenizer:
        
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    label_map = {
        'ENTAILMENT': 0,
        'NEUTRAL': 1,
        'CONTRADICTION': 2,
        'NOT ENOUGH INFO': None
    }

    def tokenize_function(examples):
        examples['label'] = [label_map[label] for label in examples['label']]
        return tokenizer(examples['premise'], examples['hypothesis'], padding=True, truncation=True)


    # tokenized = fever_dataset.map(tokenize_function, batched=True)

In [7]:
#@title Exploration utils

def pretty_print_dict(d, indent=0):
    for key, value in d.items():
        print(' ' * indent + str(key) + ':', end=' ')
        if isinstance(value, dict):
            print()
            pretty_print_dict(value, indent + 4)
        else:
            print(value)


def plot_labels_distribution(target_set, title=''):

    labels = [
        'ENTAILMENT',
        'NEUTRAL',
        'CONTRADICTION',
    ]

    label_counts = {}
    for label in target_set['label']:
        if label not in label_counts:
            label_counts[label] = 0
        label_counts[label] += 1

    plt.bar(labels, label_counts.values())
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.title(f'Distribution of labels in {title}')
    plt.show()
    print()
        

def plot_lengths_distribution(target_set, title='', compare_length=False):
    # extract premises and hypotheses
    premises = [item['premise'] for item in target_set]
    hypotheses = [item['hypothesis'] for item in target_set]

    # compute lengths
    premise_lengths = [len(premise.split()) for premise in premises]
    hypothesis_lengths = [len(hypothesis.split()) for hypothesis in hypotheses]

    # plotting length distributions
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    sns.histplot(premise_lengths, bins=50, kde=True, ax=axes[0], color='blue', log_scale=(False, True))
    axes[0].set_title('Premise Length Distribution')
    axes[0].set_xlabel('Number of Words')
    axes[0].set_ylabel('Frequency')

    sns.histplot(hypothesis_lengths, bins=50, kde=True, ax=axes[1], color='green', log_scale=(False, True))
    axes[1].set_title('Hypothesis Length Distribution')
    axes[1].set_xlabel('Number of Words')
    axes[1].set_ylabel('Frequency')

    fig.suptitle(f'Premise and Hypothesis Length Distribution in {title}')
    plt.show()

    if compare_length:
        # plot premise vs hypothesis length scatter plot
        plt.figure(figsize=(8, 6))
        plt.scatter(premise_lengths, hypothesis_lengths, alpha=0.5, s=1)
        plt.title('Premise vs Hypothesis Length')
        plt.xlabel('Premise Length')
        plt.ylabel('Hypothesis Length')
        plt.yscale('log')
        plt.xscale('log')
        plt.grid(True, which="both", ls="--")
        plt.show()

In [8]:
#@title Statistics about the regular dataset
if print_statistics:
  plot_labels_distribution(fever_dataset['train'], title="Fever Train Set")
  plot_lengths_distribution(fever_dataset['train'], title="Fever Train Set", compare_length=True)

In [9]:
#@title Dataset structure

if general_structure:
  print(f"Datasets structure: {fever_dataset}")

if show_ids:
  print(f"Train IDs: {fever_dataset['train']['id']}")

if show_labels:
  print(f"Train labels: {fever_dataset['train']['label']}")


if wsd_exploration:

  print(f"\nWSD srtructure: ")
  pretty_print_dict(fever_dataset['train'][42]['wsd'])

  sample_range = len(fever_dataset['train'])
  loop = tqdm(range(sample_range))

  wsd_info = dict()
  wsd_info['pos'] = dict()

  for i in loop:

    data =  fever_dataset['train'][i]
    sample_id = data['id']

    hypothesis = data['hypothesis']
    hyp_wsd = data['wsd']['hypothesis']

    for hyp_wsd_dict in hyp_wsd:
      
      pos  = hyp_wsd_dict['pos']
      wsd_info['pos'][pos] = 1 if pos not in wsd_info['pos'] else wsd_info['pos'][pos] + 1

  print("POS:")
  pretty_print_dict(wsd_info)

  with open(info + "wsd_pos.txt", "w") as wsd_pos:
    print(f"Saving allPOS of the dataset into {info + 'wsd_pos.txt'}")
    for elem in wsd_info['pos']: wsd_pos.write(f"{elem}\n")



if srl_exploration:
  print("\nSRL structure: ")
  print(fever_dataset['train'][42]['srl'].keys())
  # print(fever_dataset['train'][42]['srl']['hypothesis']['tokens'])
  # print(fever_dataset['train'][42]['srl']['hypothesis']['annotations'])

  print("Tokens: ")
  for token in fever_dataset['train'][42]['srl']['premise']['tokens']:
      print(token)

  print("\nAnnotations: ")
  for annotation in fever_dataset['train'][42]['srl']['premise']['annotations']:
      for key, value in annotation.items():
          print(f"{key}:\t{value}")
      print()

Datasets structure: DatasetDict({
    train: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 51086
    })
    validation: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2288
    })
    test: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2287
    })
})

SRL structure: 
dict_keys(['premise', 'hypothesis'])
Tokens: 
{'index': 0, 'rawText': 'Deadpool'}
{'index': 1, 'rawText': '('}
{'index': 2, 'rawText': 'film'}
{'index': 3, 'rawText': ')'}
{'index': 4, 'rawText': '.'}
{'index': 5, 'rawText': 'It'}
{'index': 6, 'rawText': 'is'}
{'index': 7, 'rawText': 'the'}
{'index': 8, 'rawText': 'eighth'}
{'index': 9, 'rawText': 'installment'}
{'index': 10, 'rawText': 'and'}
{'index': 11, 'rawText': 'a'}
{'index': 12, 'rawText': 'spin'}
{'index': 13, 'rawText': '-'}
{'index': 14, 'rawText': 'off'}
{'index': 15, 'rawText': 'in'}
{'index': 16,

In [10]:
#@title SRL exploration: collect frame-names and roles 

def get_srl_info(dataset):

  sample_range = len(dataset)
  loop = tqdm(range(sample_range))

  verbs_freqs = dict()
  set_of_va_frames = dict() # verb atlas frames
  set_of_pb_frames = dict() # prop bank frames
  set_of_va_roles = set() # verb atlas frames
  set_of_pb_roles = set() # prop bank frames

  for i in loop:

    data =  dataset[i]
    sample_id = data['id']
    
    tokens = data['srl']['premise']['tokens']
    annotations = data['srl']['premise']['annotations']

    for annotation in annotations:

      token_index = annotation['tokenIndex']
      verb = tokens[token_index]['rawText']
      verbs_freqs[verb] =  1 if verb not in  verbs_freqs else verbs_freqs[verb] + 1

      vb_frame = annotation['verbatlas']['frameName']
      pb_frame = annotation['englishPropbank']['frameName']

      set_of_va_frames[vb_frame] = 1 if vb_frame not in set_of_va_frames else set_of_va_frames[vb_frame] + 1 
      set_of_pb_frames[pb_frame] = 1 if pb_frame not in set_of_pb_frames else set_of_pb_frames[pb_frame] + 1 

      va_roles = annotation['verbatlas']['roles']
      pb_roles = annotation['englishPropbank']['roles']

      for role in va_roles:
        set_of_va_roles.add(role['role'])

      for role in pb_roles:
        set_of_pb_roles.add(role['role'])

  return set_of_va_frames, set_of_pb_frames, set_of_va_roles, set_of_pb_roles, verbs_freqs
  

In [11]:
if save_info:

  set_of_va_frames, set_of_pb_frames, set_of_va_roles, set_of_pb_roles, verbs_freqs = get_srl_info(fever_dataset['train'])

  verbs_freqs = sorted(verbs_freqs.items(), key=lambda item: item[1], reverse=True)
  set_of_pb_frames = sorted(set_of_pb_frames.items(), key=lambda item: item[1], reverse=True)
  set_of_va_frames = sorted(set_of_va_frames.items(), key=lambda item: item[1], reverse=True)

  print(f"Number of Verb Atlas frames: {len(set_of_va_frames)}")
  print(f"Number of Verb Atlas roles: {len(set_of_va_roles)}")

  print(f"Number of Propbank frames: {len(set_of_pb_frames)}")
  print(f"Number of Propbank roles: {len(set_of_pb_roles)}")

  with open(info + "va_roles.txt", "w") as va_roles:
    print(f"Saving all Verb Atlas roles of the dataset into {info + 'va_roles.txt'}")
    for elem in set_of_va_roles: va_roles.write(f"{elem}\n")

  with open(info + "va_frames.txt", "w") as va_frames:
    print(f"Saving all Verb Atlas frames of the dataset into {info + 'va_frames.txt'}")
    for elem in set_of_va_frames: va_frames.write(f"{elem}\n")


  with open(info + "pb_roles.txt", "w") as pb_roles:
    print(f"Saving all Propbank roles of the dataset into {info + 'pb_roles.txt'}")
    for elem in set_of_pb_roles: pb_roles.write(f"{elem}\n")


  with open(info + "pb_frames.txt", "w") as pb_frames:
    print(f"Saving all Propbank frames of the dataset into {info + 'pb_frames.txt'}")
    for elem in set_of_pb_frames: pb_frames.write(f"{elem}\n")

  with open(info + "verbs_freqs.txt", "w") as verbs:
    print(f"Saving all verbs frequencies count into {info + 'verbs_freqs.txt'}")
    for elem in verbs_freqs: verbs.write(f"{elem}\n")


## Augmentation
The following prompt can be used to correct the grammar of a modified hypotesis so if just a not is added the phrase is transformed in negative form:

"Correct the grammar in the following inputs, rephrase the input if necessary to make them more accurate. Provide just the correct version, no explanation."

Ask for:
symmetric relation --> Marry


In [23]:
#@title Augmentation utils

def get_sentence_from_span(tokens, span_begin, span_end):
    sentence = ""
    try: 
      sentence = " ".join([tokens[index]['rawText']  for index in range(span_begin, span_end + 1)]) # if tokens[index]['rawText'] in names])
    except:
      sentence = " ".join([tokens[index]['rawText'] for index in range(span_begin, span_end)]) #if tokens[index]['rawText'] in names])

    return sentence


def join_strings_smartly(words):
    """ Joins a list of words smartly:
    - Adds spaces between words when appropriate.
    - Avoids adding spaces before punctuation.
    """
    punctuation = {'.', ',', ';', ':', '!', '?',')'}
    result = words[0]
    prev = result

    for word in words[1:]:
      if word in punctuation or \
        ("'" in prev) or \
        word.startswith("'") or \
        ("." in prev and "." in word) or \
        ("(" in prev) :
        # add word without space before
        result += word
      else:
        # add with space before
        result += " " + word
      # keep track of previous word  
      prev = word

    return result

def get_synset_from_id(synset_id):
    if synset_id == 'O':
        return None
    try:
        offset = int(''.join(filter(str.isdigit, synset_id)))
        pos = synset_id[-1]
        synset = wn.synset_from_pos_and_offset(pos, offset)
        return synset
    except:
        print("exception")
        return None
    

def get_related_word(synset, pos): 
    info = dict()

    # Map POS tags to WordNet POS tags
    pos_map = {
        'NOUN': wn.NOUN,
        # 'VERB': wn.VERB,
        'ADJ': wn.ADJ,
        'ADV': wn.ADV
    }
    
    if pos not in pos_map:
        return None 
    
    # get hypernyms
    hypernyms = synset.hypernyms()
    if not hypernyms:
        return None
    hypernyms = synset.hypernyms()
    # hypernym_words = set()
    # for hypernym in hypernyms:
    #     hypernym_words.update(hypernym.lemma_names())
    info['hypernyms'] = [hypernyms[0].lemma_names()[0]] #list(hypernym_words)

    # get synonyms 
    synonyms = synset.lemmas()
    if not synonyms:
        return None
    info['synonyms'] = [synonym.name() for synonym in synonyms if synonym.name() != synset.lemmas()[0].name()]

    return info


def extract_names(wsd_data):
    names = []
    current_name = []

    for entry in wsd_data:
        if entry['pos'] == 'PROPN':
            current_name.append(entry['text'])
        else:
            if current_name:
                names.append(' '.join(current_name))
                current_name = []

    # catch any remaining name at the end
    if current_name:
        names.append(' '.join(current_name))
    
    return names


def extract_partial_match_name(text, name_list):
    # tokenize the input text
    words = re.findall(r'\b\w+\b', text)
    names = []
    
    # iterate through each name in the name list
    for name in name_list:
        name_parts = name.split()

        # check for partial match in the tokenized words
        for i in range(len(words) - len(name_parts) + 1):
            if words[i:i + len(name_parts)] == name_parts[:len(words[i:i + len(name_parts)])]:
              names.append(name)
              
    if names: return names
    return None


def extract_dates(text):
    # regular expression patterns for different date formats
    patterns = [
        r'\b(\d{4})\b',  # matches a 4-digit year
        r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',  # matches dates like dd-mm-yyyy, dd/mm/yyyy, dd-mm-yy, dd/mm/yy
        r'\b(\d{1,2} [A-Za-z]+ \d{4})\b',  # matches dates like 1 January 2020
        r'\b([A-Za-z]+ \d{1,2}, \d{4})\b'  # matches dates like January 1, 2020
    ]

    months_list = [
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December"
    ]

    pattern = re.compile(r'(?:(?P<day>\d{1,2})(?:st|nd|rd|th)?[ ,]*)?(?:(?P<month>[A-Za-z]+)[ ,]*)?(?:(?P<year>\d{4}))?')
    matches = pattern.findall(text)
    day, month, year = None, None, None

    for match in matches:
        if match[0]:
            day = match[0]
        if match[1]:
            month = match[1]
        if match[2]:
            year = match[2]
        
        # more check to avoid error!
        if day and not (1 <= int(day) <= 31):
            day = None
        if month and month not in months_list:
            month = None
            
    return day, month, year

def extract_locations(sentence):
    locations = set()
    entities =  nlp_spacy(sentence).ents
    if entities:
        found_target = False
        for ent in entities:
            if ent.label_ == "GPE":
                found_target = True
                locations.add(ent.text)
    return locations


movie_titles = [
    "John Wick: Chapter 2",
    "On the Road (film)",
    "Brave",
    "Penny Dreadful",
    "Snooki & Jwoww",
    "Sons of Anarchy",
    "The Sopranos",
    "Thor: The Dark World",
    "Winter Passing",
    "Teen Wolf",
    "Captain America: The Winter Soldier",
    "The Belko Experiment",
    "Fantastic Beasts and Where to Find Them",
    "A Monster Calls",
    "The Fate of the Furious",
    "The Wolf of Wall Street (2013 film)",
    "Tropico",
    "Rescue Me",
    "The Night Of",
    "Lipstick Under My Burkha",
    "The Promise",
    "Sleeping Beauty",
    "Mad Men",
    "Avatar",
    "Zootopia",
    "Spider-Man",
    "Enemy",
    "Room",
    "Cloud Atlas",
    "Kong: Skull Island",
    "Rick and Morty",
    "Ink Master",
    "Frenemies",
    "Persuasion (2007 film)",
    "Ballet Shoes",
    "The Great Buck Howard",
    "Schindler's List",
    "Iron Man",
    "The Illusionist",
    "The Messenger",
    "The Suite Life Movie",
    "Wild",
    "The Ren & Stimpy Show",
    "Johnny Mnemonic",
    "Denial (2016 film)",
    "Black Sails",
    "The Breakfast Club",
    "Modern Family",
    "Interstellar",
    "To the Bone",
    "Prison Break",
    "Game of Thrones (season 3)",
    "Line of Duty",
    "In the Heart of the Sea",
    "Oz the Great and Powerful",
    "Attack on Titan",
    "Her",
    "The Carmichael Show",
    "The Leftovers",
    "Short Term 12",
    "Stephanie Daley",
    "Fargo",
    "BoJack Horseman",
    "New Girl",
    "Glee",
    "Turn: Washington's Spies",
    "X-Men: Days of Future Past",
    "Miss Peregrine's Home for Peculiar Children",
    "Ghostbusters",
    "Famous in Love",
    "Legion",
    "Spider-Man 3",
    "Deadpool (film)",
    "Doctor Who",
    "San Junipero",
    "Outlander (TV series)",
    "Elementary",
    "The Shield",
    "Broadchurch",
    "Fairy Tail",
    "Major Barbara",
    "American Horror Story",
    "Trolls",
    "Harry Potter",
    "Whiplash",
    "Split (2016 American film)",
    "How to Be",
    "Cars Toons",
    "Little Miss Sunshine",
    "Naruto",
    "Horrible Bosses",
    "There Will Be Blood",
    "Beauty and the Beast",
    "Grey's Anatomy",
    "Futurama",
    "The Strain",
    "The Avengers",
    "The Vampire Diaries",
    "True Detective",
    "Teen Wolf's sixth season",
    "Interview",
    "The Supernatural",
    "The Girl on the Train",
    "Interstellar",
    "Guardians of the Galaxy",
    "Miss Peregrine's Home for Peculiar Children",
    "Love & Friendship",
    "Goliyon Ki Raasleela Ram-Leela"
]

def get_movie_title(sentence):
    for title in movie_titles:
        if re.search(r'\b' + re.escape(title) + r'\b', sentence):
            return title

# add new samples to the dataset
# filtered_dataset = fever_dataset['train'].filter(lambda example: 'is a ' in example['premise'].lower()) # or 'is a' in example['hypothesis'].lower())

In [24]:
#@title Graph utils

def get_sentences_by_id(dataset, id):
  index = dataset['id'].index(str(id))
  print(dataset[index]['premise'])
  print(dataset[index]['hypothesis'])


def format_date(dates):
  select_date = None
  for date in dates:
    if not date: continue
    if not date[2]: continue
    if date[2]: select_date = date[2]
    if date[1] and date[2]: select_date = date[1] + ", " + date[2]
    if date[0] and date[1] and date[2]:
      select_date = date[1] + " " + date[0] + ", " + date[2]
      break
  return select_date


def get_all_locations(graph, frame):
  ...
  all_locations = set()
  for key in graph[frame].keys():
    for location in graph[frame][key]['locations']:
      all_locations.add(location)
  return list(all_locations)


def get_all_dates(graph, frame):
  ...
  all_dates = set()
  for key in graph[frame].keys():
    date = format_date( graph[frame][key]['date'])
    all_dates.add(date)
  return list(all_dates)


location_hierarchy = {
   "United States": [
      "California", 
      "Texas", 
      "Illinois", 
      "District of Columbia",
      "New York",
      "New York City",
      "Washington",
      "D.C.",
      "Chicago",
      "Austin",
      "Hollywood",
      "Los Angeles"
    ],
   "the United States": [
      "California", 
      "Texas", 
      "Illinois", 
      "District of Columbia",
      "New York",
      "New York City",
      "Washington",
      "D.C.",
      "Chicago",
      "Austin",
      "Hollywood",
      "Los Angeles"
    ],
   "US": [
      "California", 
      "Texas", 
      "Illinois", 
      "District of Columbia",
      "New York",
      "New York City",
      "Washington",
      "D.C.",
      "Chicago",
      "Austin",
      "Hollywood",
      "Los Angeles"
    ],
    "California": ["Hollywood", "Los Angeles"],
    "Texas": ["Austin"],
    "Illinois": ["Chicago"],
    "District of Columbia": ["Washington", "D.C."],
    "New York": ["New York City"],
    "Canada": ["Alberta", "Edmonton"],
    "Japan": ["Tokyo"],
    "Australia":  ["Sydney"],
    "Germany": ["Berlin"],
    "Belgium": ["Belgium"],
    "Philippines": ["Philippines"],
    "United Kingdom": ["London"]
}

def is_contained(location, container, location_hierarchy):
  
  if container in location_hierarchy:
    if location in location_hierarchy[container]: 
      return True

  return False

def get_random_exclusive_element(main_list, exclusion_list):
  exclusion_set = set(exclusion_list)
  filtered_list = [item for item in main_list if item not in exclusion_set]
  
  if not filtered_list: return None  
  choice = random.choice(filtered_list)
  check = True
  while check:
    choice = random.choice(filtered_list)
    for location in exclusion_list:
      # print("choice ", choice)
      # print("location ", location)
      # print("is contained ", is_contained(choice, location, location_hierarchy))
      # print("contains ", is_contained(location, choice, location_hierarchy))
      # print()
      related = is_contained(choice, location, location_hierarchy) or is_contained(location, choice, location_hierarchy) or choice == 'Netflix' or choice == 'Spike'
      if not related: 
        check = False
        break
    
  return choice

# print(is_contained("the United States", "Tokyo", location_hierarchy))  # True
# print(is_contained("Tokyo", "the United States", location_hierarchy))  # True

def get_all_spouses(graph, frame):
  ...
  all_spouses = set()
  for key in graph[frame].keys():
    for spouse in graph[frame][key]['spouses']:
      all_spouses.add(spouse)
  return list(all_spouses)



### subsection

In [25]:
#@title Build relational graph

# dataset = filtered_dataset
dataset = fever_dataset['train']

sample_range = len(dataset)
loop = tqdm(range(sample_range))

accetable_verbs = ['marry.01', 'premiere.01']

relational_graph = dict()
old_max_span = 1_000_000

for i in loop:

  data =  dataset[i]
  sample_id = data['id']
  premise = data['premise']
  hypothesis = data['hypothesis']
  # print(f"premise: {premise}")
  
  # srl info
  tokens = data['srl']['premise']['tokens']
  annotations = data['srl']['premise']['annotations']

  # wsd info
  wsd = data['wsd']['premise']
  proper_nouns = extract_names(wsd)

  hp_wsd = data['wsd']['hypothesis']
  hp_proper_nouns = extract_names(hp_wsd)

  # hp_wsd = data['wsd']['hypothesis']
  # hp_proper_nouns = extract_names(hp_wsd)
  # print(proper_nouns)

  for annotation in annotations:

    token_index = annotation['tokenIndex']
    verb = tokens[token_index]['rawText']

    # frame = annotation['verbatlas']['frameName']
    frame = annotation['englishPropbank']['frameName']

    if verb[0].isupper(): continue  # because usually capital letter verbs are movie titles
    if frame not in accetable_verbs: continue
    if frame not in relational_graph: relational_graph[frame] = dict()
    # print(premise)

    # roles = annotation['verbatlas']['roles']
    roles = annotation['englishPropbank']['roles']

    try:
      span_begin = roles[0]['span'][0]
      span_end = roles[-1]['span'][1]

    except:
      # print(roles)
      continue

    if span_begin > 0 and span_begin < old_max_span: 
      span_begin = 0
      old_max_span = span_end

    sentence = get_sentence_from_span(tokens, span_begin, span_end)


    if frame in ['premiere.01'] and ("premiered in " in premise or "premiered at " in premise or "premiered on " in premise):

      # from sentence extract movie title
  
      title = get_movie_title(hypothesis)
      if not title: continue

      # print(sentence)
      span_1_l, span_2_l, span_1_d, span_2_d  = None, None, None, None
      for role in roles:
        if role['span'][1] > span_end : 
          # print(sentence)
          # print(span_end)
          # print(role)
          break

        if role['role'] in ['ARGM-LOC', 'R-ARGM-LOC', 'C-ARGM-LOC']: 
          span_1_l = role['span'][0]
          span_2_l = role['span'][1]

        if role['role'] in ['ARGM-TMP', 'R-ARGM-TMP']: 
          span_1_d = role['span'][0]
          span_2_d = role['span'][1] 

      # from sentence extrat location and date
      if  span_1_l and span_2_l:
        location_sentence = get_sentence_from_span(tokens, span_1_l, span_2_l)
        locations = extract_locations(location_sentence)
      else: 
        locations = set()
      
      locations = list(locations)
    
      if span_1_d and span_2_d:
        date_sentence =  get_sentence_from_span(tokens, span_1_d, span_2_d)
        date = extract_dates(date_sentence)
      else:
        date = []
        
      # add nodes to the graph
      if title not in relational_graph[frame]: 
        relational_graph[frame][title] = dict()
        relational_graph[frame][title]['locations'] = locations
        relational_graph[frame][title]['date'] = [date]
        relational_graph[frame][title]['id'] = [sample_id]

      else: 
        if set(locations).intersection(set(relational_graph[frame][title]['locations'])) == set(): 
          relational_graph[frame][title]['locations'] += locations

        if date not in relational_graph[frame][title]['date']:
          relational_graph[frame][title]['date'].append(date)

        if sample_id not in relational_graph[frame][title]['id']:
          relational_graph[frame][title]['id'].append(sample_id)


    if frame in  ['marry.01']: 
      
      # print(sentence)
      # print(hypothesis)
      main = extract_partial_match_name(hypothesis, hp_proper_nouns)
      targets = extract_partial_match_name(sentence, proper_nouns)

      if main: main = main[0]
      # print(hp_proper_nouns)
      # print(main)

      if not targets: continue
      targets = list(set(targets))

      date = []
      if "married in" in premise:
        span_1_d, span_2_d  = None, None
        for role in roles:
          if role['span'][1] > span_end : break

          if role['role'] in ['ARGM-TMP', 'R-ARGM-TMP']: 
            span_1_d = role['span'][0]
            span_2_d = role['span'][1] 

        if span_1_d and span_2_d:
          date_sentence =  get_sentence_from_span(tokens, span_1_d, span_2_d)
          date = extract_dates(date_sentence)

      if main in targets: 
        targets.remove(main)

      if main not in relational_graph[frame]:
         relational_graph[frame][main] = dict()
         relational_graph[frame][main]['spouses'] = targets
         relational_graph[frame][main]['date'] = [date]
         relational_graph[frame][main]['id'] = [sample_id]

      else:
      
        if targets not in relational_graph[frame][main]['spouses']:
          relational_graph[frame][main]['date'].append(date)

        if date not in relational_graph[frame][main]['date']:
          relational_graph[frame][main]['date'].append(date)

        if sample_id not in relational_graph[frame][main]['id']:
          relational_graph[frame][main]['id'].append(sample_id)


100%|██████████| 51086/51086 [01:18<00:00, 651.12it/s]


In [26]:
#@title Save/Load relational graph

if save_graph:
    with open(f"{graph}/relational_graph.json", "w") as fp:
        json.dump(relational_graph, fp, indent=4)

if load_graph:
    with open(f"{graph}/relational_graph.json", "r") as f:
        relational_graph = json.load(f)


In [27]:
pretty_print_dict(relational_graph)

premiere.01: 
    John Wick: Chapter 2: 
        locations: ['Los Angeles']
        date: [['30', 'January', '2017']]
        id: ['138117', '48551', '9366', '138368', '61964', '97247', '9367']
    Brave: 
        locations: []
        date: [['10', 'June', '2012']]
        id: ['209148', '209129', '209162']
    Penny Dreadful: 
        locations: []
        date: [['9', None, None], ['11', 'May', '2014']]
        id: ['113668', '88185', '145475', '33511', '135923', '40973', '136982', '63647', '53104', '149528', '147369']
    Snooki & Jwoww: 
        locations: []
        date: [['22', 'October', '2013'], ['5', 'November', '2014'], ['8', 'January', '2013']]
        id: ['183071', '183054', '183052', '183061', '183044', '183068', '183064', '183060']
    Sons of Anarchy: 
        locations: []
        date: [['3', 'September', '2008'], ['9', 'September', '2014'], ['9', 'December', '2014']]
        id: ['57', '13852', '108672', '73620', '99584', '51560', '58', '61340', '116506', '82239']


In [29]:
#@title Relational graph augmentation

# dataset = filtered_dataset
dataset = fever_dataset['train']
new_hypotesys = []

curr_id = max_id

all_movie_locations = get_all_locations(relational_graph, "premiere.01")
# print(f"All movies locations: {all_movie_locations}")
all_movie_dates = get_all_dates(relational_graph, "premiere.01")
# print(f"All movies dates: {all_movie_dates}")


for frame in relational_graph.keys():

  if frame == "premiere.01":

    # synset = get_synset_from_id('1718331v')
    contraddictions = []
    neutrals = []

    # go over all the titles in the graph
    for title, infos in relational_graph[frame].items():
      locations = relational_graph[frame][title]['locations']
      dates =  relational_graph[frame][title]['date']
      ids =  relational_graph[frame][title]['id']

      if locations: select_location = locations[random.randint(0, len(locations)-1)]
      # print(dates)
      date = format_date(dates)
      select_date = date

      contraddiction = title + " "
      if len(locations) >= 1 and date:
        ... # choose randomly what to change
        # print(title)

        if random.randint(1, 100) <= 33: 
          # change date by randomly adding or subtracting 1 to 10 years
          sign = 1 if random.randint(1, 100) <= 50 else -1
          new_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
          # print(date)
          # print(new_date)
          select_date = new_date

        elif 33 <= random.randint(1, 100) <= 66:
          # change location
          new_location = get_random_exclusive_element(all_movie_locations, locations)
          # print(f"curr locations: {locations}")
          # print(f"new location = {new_location}\n")
          select_location = new_location

        else:
          # change both
          sign = 1 if random.randint(1, 100) <= 50 else -1
          select_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
          select_location = get_random_exclusive_element(all_movie_locations, locations)

        contraddiction += "premiered in " + select_location + " on " + select_date + "."


      elif len(locations) >= 1 and not date:
        ... # change only location
        new_location = get_random_exclusive_element(all_movie_locations, locations)
        contraddiction += "premiered in " + new_location + "."

      elif len(locations) <= 1 and date:
        ... # change only date
        sign = 1 if random.randint(1, 100) <= 50 else -1
        new_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
        contraddiction += "premiered on " + new_date + "."

      else:
        ... # nothing to do 
        continue

      if contraddiction: contraddictions.append(contraddiction)

      neutral = ""
      sign = 1 if random.randint(1, 100) <= 50 else -1
      new_date = None
      if date: new_date = date[:-4] + str(int(date[-4:]) + random.randint(1, 2))
      if not locations: locations = ['unk']

      new_location = get_random_exclusive_element(all_movie_locations, locations)

      if new_date and new_location:
        neutral = title + " was also released at " + new_location + " on " + new_date
      else:
        neutral = title + " was also released at " + new_location

      if neutral: neutrals.append(neutral)

print(f"\nContraddictions {len(contraddictions)}")
print("----------------------")
for sentence in contraddictions:
  print(sentence)

print(f"\nNeutrals {len(neutrals)}")
print("----------------------")
for sentence in neutrals:
  print(sentence)


Contraddictions 88
----------------------
John Wick: Chapter 2 premiered in New York City on January 30, 2022.
Brave premiered on June 10, 2015.
Penny Dreadful premiered on May 11, 2012.
Snooki & Jwoww premiered on October 22, 2008.
Sons of Anarchy premiered on September 3, 2013.
The Sopranos premiered in the United States on January 10, 2003.
Thor: The Dark World premiered in Hollywood on October 22, 2008.
The Ren & Stimpy Show premiered on 1989.
Winter Passing premiered on 2000.
Teen Wolf premiered on November 15, 2018.
Captain America: The Winter Soldier premiered in Los Angeles on March 13, 2009.
The Belko Experiment premiered on 2019.
Fantastic Beasts and Where to Find Them premiered in New York City on 2013.
A Monster Calls premiered on September 10, 2013.
Prison Break premiered on April 4, 2013.
The Strain premiered in Philippines on July 13, 2016.
Tropico premiered on December, 2009.
Rescue Me premiered on 2003.
Lipstick Under My Burkha premiered in Berlin.
The Promise premier

### Spouses

In [None]:
# dataset = filtered_dataset
dataset = fever_dataset['train']
new_hypotesys = []

curr_id = max_id

# print(f"All movies locations: {all_movie_locations}")
all_movie_dates = get_all_dates(relational_graph, "marry.01")
# print(f"All movies dates: {all_movie_dates}")

for frame in relational_graph.keys():

  if frame == "marry.01":

    # synset = get_synset_from_id('1718331v')
    contraddictions = []
    neutrals = []

    # go over all the titles in the graph
    for title, infos in relational_graph[frame].items():
      spouses = relational_graph[frame][title]['spouses']
      dates =  relational_graph[frame][title]['date']
      ids =  relational_graph[frame][title]['id']

      if locations: select_location = locations[random.randint(0, len(locations)-1)]
      # print(dates)
      date = format_date(dates)
      select_date = date

      contraddiction = title + " "
      if len(locations) >= 1 and date:
        ... # choose randomly what to change
        # print(title)

        if random.randint(1, 100) <= 33: 
          # change date by randomly adding or subtracting 1 to 10 years
          sign = 1 if random.randint(1, 100) <= 50 else -1
          new_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
          # print(date)
          # print(new_date)
          select_date = new_date

        elif 33 <= random.randint(1, 100) <= 66:
          # change location
          new_location = get_random_exclusive_element(all_movie_locations, locations)
          # print(f"curr locations: {locations}")
          # print(f"new location = {new_location}\n")
          select_location = new_location

        else:
          # change both
          sign = 1 if random.randint(1, 100) <= 50 else -1
          select_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
          select_location = get_random_exclusive_element(all_movie_locations, locations)

        contraddiction += "premiered in " + select_location + " on " + select_date + "."


      elif len(locations) >= 1 and not date:
        ... # change only location
        new_location = get_random_exclusive_element(all_movie_locations, locations)
        contraddiction += "premiered in " + new_location + "."

      elif len(locations) <= 1 and date:
        ... # change only date
        sign = 1 if random.randint(1, 100) <= 50 else -1
        new_date = date[:-4] + str(int(date[-4:]) + (sign * random.randint(1, 5)))
        contraddiction += "premiered on " + new_date + "."

      else:
        ... # nothing to do 
        continue

      if contraddiction: contraddictions.append(contraddiction)

      neutral = ""
      sign = 1 if random.randint(1, 100) <= 50 else -1
      new_date = None
      if date: new_date = date[:-4] + str(int(date[-4:]) + random.randint(1, 2))
      if not locations: locations = ['unk']

      new_location = get_random_exclusive_element(all_movie_locations, locations)

      if new_date and new_location:
        neutral = title + " was also released at " + new_location + " on " + new_date
      else:
        neutral = title + " was also released at " + new_location

      if neutral: neutrals.append(neutral)

In [36]:
#@title Augmentation by Synonyms and Hypernyms

if syn_hyp_augment:
  
  sample_range = 10 #len(fever_dataset['train'])
  loop = tqdm(range(sample_range))

  new_samples = {
      'id': [],
      'premise': [],
      'hypothesis': [],
      'label': [],
      'wsd': [None for i in range(sample_range)],
      'srl': [None for i in range(sample_range)]
  }

  syn_dict = dict()
  syn_idx = 0
  progressive_id = int(max_id)

  for i in loop:

    data =  fever_dataset['train'][i]
    sample_id = data['id']

    premise = data['premise']
    hypothesis = data['hypothesis']
    label = data['label']
    hyp_wsd = data['wsd']['hypothesis']

    new_samples['id'].append(str(progressive_id))
    new_samples['premise'].append(premise)
    new_samples['label'].append(label)
    progressive_id += 1

    hypothesis = data['hypothesis']
    new_hypotesys = []

    for hyp_wsd_dict in hyp_wsd:

      # substitute word with 50% probability
      skip = False
      if random.randint(1, 100) <= 50: 
        skip = True

      word = hyp_wsd_dict['text']
      pos = hyp_wsd_dict['pos']
      offset = hyp_wsd_dict['wnSynsetOffset']
      synset = get_synset_from_id(offset)

      related_words = None
      if synset: related_words = get_related_word(synset, pos)

      if related_words and not skip:
        hypernyms = related_words['hypernyms']
        synonyms = related_words['synonyms']

        chosen_subs = synonyms if random.randint(1, 100) <= 75 else hypernyms

        if not chosen_subs : 
          new_hypotesys.append(word)
          continue

        synonym = chosen_subs[syn_idx % (len(chosen_subs))]

        if synonym in syn_dict and syn_dict[synonym] > 10:
          syn_idx += 1
          synonym = chosen_subs[syn_idx % (len(chosen_subs))]

        syn_dict[synonym] = 1 if  synonym not in syn_dict else syn_dict[synonym] + 1

        syn_idx += 1

        if '_' in synonym: synonym = synonym.replace('_', ' ')
        new_hypotesys.append(synonym)

      else: new_hypotesys.append(word)

    if not new_hypotesys: 
      new_samples['hypothesis'].append(hypothesis)
    else:
      new_hypotesys = join_strings_smartly(new_hypotesys)
      new_samples['hypothesis'].append(new_hypotesys)
    # print(f"new: {new_hypotesys} \n")

  # print(new_samples)

100%|██████████| 10/10 [00:00<00:00, 308.51it/s]


In [34]:
new_samples['hypothesis']

['Roman Atwood is a content person.',
 'The Boston Celtics play their home games at TD Garden.',
 'There is a picture called The Hunger Games.',
 'Ryan Seacrest is a person.',
 'Stranger than Fiction is a film.',
 'Selena recorded music.',
 'Selena tape music.',
 'Selena recorded music.',
 'Selena recorded music.',
 'John Wick: Chapter 2 was theatrically released in the Oregon.']

In [85]:
#@title Add new samples to the original dataset
 
augmentation = Dataset.from_dict(new_samples)
print(f"Augmentation type: {type(augmentation)}")

augmented_dataset = concatenate_datasets([fever_dataset['train'], augmentation])
print(f"Augmented dataset type: {type(augmented_dataset)}")
print(f"Train split length: {len(fever_dataset['train'])}")
print(f"Train split augmentated: {len(augmented_dataset)}")

Augmentation type: <class 'datasets.arrow_dataset.Dataset'>
Augmented dataset type: <class 'datasets.arrow_dataset.Dataset'>
Train split length: 51086
Train split augmentated: 51096
