## Global Config

In [2]:
import os

# Data source switches. NOTE: download_from_remote is recomputed in the Data
# section from what is actually stored on disk.
download_from_remote = True
load_from_local = True
save_to_local = True

# Where the downloaded datasets are saved / reloaded from. These folders are
# NOT pre-created: an existing-but-empty folder would look like a local copy
# to the Data cell, and save_to_disk creates them when the data is saved.
local_dataset = "./dataset/fever_dataset"
local_adversarial = "./dataset/adversarial_dataset"

# Exploration switches and output folder for the saved statistics
print_statistics = False
save_info = False
info = "./info/"
# Pure-Python equivalent of `mkdir -p`, portable and valid outside IPython
os.makedirs(info, exist_ok=True)

## Imports

In [72]:
# general imports
import os
import re
import string
import random
from tqdm import tqdm

# visualization and statistics
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

# data manipulation
import numpy as np
import pandas as pd
from typing import Dict
from datasets import(
  Dataset, 
  load_dataset, 
  load_from_disk, 
  concatenate_datasets,
) 

# nltk
import nltk
from nltk.corpus import wordnet as wn
# WordNet data for the synonym/hypernym augmentation. One call is enough:
# when the package is present it only reports "already up-to-date".
nltk.download('wordnet')

# spacy
import spacy
# Load the SpaCy model (used for NER on SRL Location spans)
nlp_spacy = spacy.load("en_core_web_sm")

# huggingface
from transformers import DataCollatorWithPadding
from transformers import (
    AutoTokenizer,
    set_seed,
)

# set seeds for every RNG source used in this notebook
SEED = 42
set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package wordnet to /home/leeoos/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Data

In [83]:
# Download the data from the Hugging Face Hub, or reuse the local copy saved by
# a previous run. download_from_remote (config cell) is recomputed here from
# what is actually stored on disk.

def _has_saved_dataset(path):
    """Return True if `path` holds a saved DatasetDict, not just an empty folder.

    Looks for the `dataset_dict.json` index that `DatasetDict.save_to_disk`
    writes at the top of its target folder (per the datasets documentation).
    """
    return os.path.isfile(os.path.join(path, "dataset_dict.json"))

# An empty folder (e.g. one pre-created with `mkdir -p`) must not count as a
# local copy, otherwise load_from_disk fails on a fresh checkout.
download_from_remote = not (_has_saved_dataset(local_dataset) and _has_saved_dataset(local_adversarial))

if download_from_remote:
    print("Downloading data from remote repository")

    # load chunk of FEVER dataset
    fever_dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli", trust_remote_code=True)

    # load adversarial test set
    adversarial_testset = load_dataset("iperbole/adversarial_fever_nli", trust_remote_code=True)

    # structure of the dataset
    print(fever_dataset)

    if save_to_local:
        print(f"Save data in local {local_dataset}")
        fever_dataset.save_to_disk(local_dataset)
        print(f"Save adversarial dataset in {local_adversarial}")
        adversarial_testset.save_to_disk(local_adversarial)

elif load_from_local:
    print("Load data from local repository")
    fever_dataset = load_from_disk(local_dataset)
    adversarial_testset = load_from_disk(local_adversarial)

else:
    # Otherwise the cell would fail below with an opaque NameError
    raise RuntimeError(
        "Local data found but load_from_local is False: nothing to load. "
        "Enable load_from_local or delete the local dataset folders to re-download."
    )

print("Done!")

print(f"Train ssplit length: {len(fever_dataset['train'])}")
# Ids appear to be strings (the augmentation cell calls int(max_id) and stores
# str ids), so a plain max() would compare them lexicographically
# ("99998" > "100000"). Compare them as integers instead.
max_id = max(int(sample_id) for sample_id in fever_dataset['train']['id'])
print(f"Max id in train split: {max_id}")

Load data from local repository
Done!
Train ssplit length: 51086
Max id in train split: 99998


In [68]:
# Display the DatasetDict: split names, features and number of rows
fever_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 51086
    })
    validation: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2288
    })
    test: Dataset({
        features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
        num_rows: 2287
    })
})

## Exploration

In [5]:
#@title Tokenization function for datapoint visualization

# DistilBERT (uncased) tokenizer used to encode premise/hypothesis pairs
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Collator built on the same tokenizer; presumably pads each batch at collation
# time (see transformers docs). Not used in the visible part of this notebook.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# String label -> integer class id. 'NOT ENOUGH INFO' has no class id and is
# mapped to None.
label_map = {
    'ENTAILMENT': 0,
    'NEUTRAL': 1,
    'CONTRADICTION': 2,
    'NOT ENOUGH INFO': None
}

def tokenize_function(examples):
    """Tokenize premise/hypothesis pairs and convert their labels to class ids.

    `examples` is presumably a batch dict (as passed by
    `Dataset.map(..., batched=True)`) whose 'premise', 'hypothesis' and 'label'
    values are lists. Labels are converted with the module-level `label_map`
    ('NOT ENOUGH INFO' becomes None); an unknown label raises KeyError.

    Returns:
        The tokenizer encoding with an extra 'label' entry holding the ids.
    """
    # Map labels first so an unknown label fails before any tokenization work.
    # The ids are returned rather than written back into examples['label'],
    # so the caller's batch is not mutated.
    labels = [label_map[label] for label in examples['label']]
    encoded = tokenizer(examples['premise'], examples['hypothesis'], padding=True, truncation=True)
    encoded['label'] = labels
    return encoded

In [6]:
#@title Exploration utils

def plot_labels_distribution(target_set, title=''):
    """Bar-plot how many samples of `target_set` carry each label.

    Args:
        target_set: dataset split supporting column access, i.e.
            ``target_set['label']`` yields the label of every sample.
        title: name of the split, inserted in the plot title.

    Bars follow the order ENTAILMENT, NEUTRAL, CONTRADICTION. Any other label
    value present in the data (e.g. 'NOT ENOUGH INFO') is appended after them,
    and labels absent from the data get no bar. Each bar is tagged with the
    exact label it counts.
    """
    canonical_order = [
        'ENTAILMENT',
        'NEUTRAL',
        'CONTRADICTION',
    ]

    label_counts = Counter(target_set['label'])

    # Look each count up by its label, so a bar's name and height always match
    # regardless of the order in which labels first appear in the data.
    ordered_labels = [label for label in canonical_order if label in label_counts]
    ordered_labels += sorted(
        (label for label in label_counts if label not in canonical_order), key=str
    )

    plt.bar(
        [str(label) for label in ordered_labels],
        [label_counts[label] for label in ordered_labels],
    )
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.title(f'Distribution of labels in {title}')
    plt.show()
    print()
        

def plot_lengths_distribution(target_set, title='', compare_length=False):
    """Plot word-count histograms of the premises and hypotheses of `target_set`.

    Args:
        target_set: iterable of samples with 'premise' and 'hypothesis' strings.
        title: name of the split, inserted in the figure title.
        compare_length: when True, also draw a log-log scatter plot of premise
            length against hypothesis length.

    Lengths are whitespace-token counts (``str.split()``); the histogram
    frequencies are drawn on a log y-axis.
    """
    premise_lengths = [len(sample['premise'].split()) for sample in target_set]
    hypothesis_lengths = [len(sample['hypothesis'].split()) for sample in target_set]

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # one (lengths, colour, panel title) spec per subplot, left to right
    panel_specs = (
        (premise_lengths, 'blue', 'Premise Length Distribution'),
        (hypothesis_lengths, 'green', 'Hypothesis Length Distribution'),
    )
    for panel_ax, (lengths, colour, panel_title) in zip(axes, panel_specs):
        sns.histplot(lengths, bins=50, kde=True, ax=panel_ax, color=colour, log_scale=(False, True))
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Number of Words')
        panel_ax.set_ylabel('Frequency')

    fig.suptitle(f'Premise and Hypothesis Length Distribution in {title}')
    plt.show()

    if not compare_length:
        return

    # premise vs hypothesis length, both axes logarithmic
    plt.figure(figsize=(8, 6))
    plt.scatter(premise_lengths, hypothesis_lengths, alpha=0.5, s=1)
    plt.title('Premise vs Hypothesis Length')
    plt.xlabel('Premise Length')
    plt.ylabel('Hypothesis Length')
    plt.yscale('log')
    plt.xscale('log')
    plt.grid(True, which="both", ls="--")
    plt.show()


def plot_vocb_distribution():
    """Placeholder for a most-common-words plot of premises vs hypotheses.

    Currently a no-op: the draft below is commented out and refers to
    `word_tokenize`, `premises` and `hypotheses`, none of which is defined in
    this function. TODO: take the split as a parameter and implement.
    """
    # Tokenize words
    # premise_words = [word for text in premises for word in word_tokenize(text)]
    # hypothesis_words = [word for text in hypotheses for word in word_tokenize(text)]

    # # Compute word frequencies
    # premise_word_freq = Counter(premise_words)
    # hypothesis_word_freq = Counter(hypothesis_words)

    # # Plotting the most common words
    # premise_common_words = premise_word_freq.most_common(20)
    # hypothesis_common_words = hypothesis_word_freq.most_common(20)

    # fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    # axes[0].barh([word[0] for word in premise_common_words], [word[1] for word in premise_common_words])
    # axes[0].set_title('Premise Common Words')
    # axes[1].barh([word[0] for word in hypothesis_common_words], [word[1] for word in hypothesis_common_words])
    # axes[1].set_title('Hypothesis Common Words')
    # plt.show()
    ...

# Prepare data for plotting
# lengths_and_labels = [(length, item['label']) for length, item in zip(premise_lengths, train_dataset)]
# premise_df = pd.DataFrame(lengths_and_labels, columns=['length', 'label'])

# lengths_and_labels = [(length, item['label']) for length, item in zip(hypothesis_lengths, train_dataset)]
# hypothesis_df = pd.DataFrame(lengths_and_labels, columns=['length', 'label'])

# # Plot premise lengths across classes
# plt.figure(figsize=(12, 6))
# sns.boxplot(x='label', y='length', data=premise_df)
# plt.title('Premise Length Distribution Across Classes')
# plt.show()

# # Plot hypothesis lengths across classes
# plt.figure(figsize=(12, 6))
# sns.boxplot(x='label', y='length', data=hypothesis_df)
# plt.title('Hypothesis Length Distribution Across Classes')
# plt.show()

# from sklearn.feature_extraction.text import TfidfVectorizer
# from sklearn.metrics.pairwise import cosine_similarity
# import numpy as np

# # Create combined list of premises and hypotheses
# texts = premises + hypotheses

# # Compute TF-IDF vectors
# tfidf = TfidfVectorizer().fit_transform(texts)

# # Compute cosine similarity between each premise and its corresponding hypothesis
# cosine_sim = [cosine_similarity(tfidf[i], tfidf[i + len(premises)])[0][0] for i in range(len(premises))]

# # Plot cosine similarity distribution
# plt.figure(figsize=(8, 6))
# plt.hist(cosine_sim, bins=50, alpha=0.5, color='purple')
# plt.title('Cosine Similarity Distribution Between Premise and Hypothesis')
# plt.xlabel('Cosine Similarity')
# plt.ylabel('Frequency')
# plt.show()


In [117]:
# Statistics about the regular dataset (enabled by print_statistics)
if print_statistics:
    plot_labels_distribution(target_set=fever_dataset['train'], title="Fever Train Set")
    plot_lengths_distribution(
        target_set=fever_dataset['train'],
        title="Fever Train Set",
        compare_length=True,
    )

In [118]:
# print(fever_dataset['train'][0]['srl'].keys())
# print(fever_dataset['train'][0]['srl']['hypothesis']['tokens'])
# print(fever_dataset['train'][0]['srl']['hypothesis']['annotations'])

# for token in fever_dataset['train'][0]['srl']['premise']['tokens']:
#     print(token)

dict_keys(['premise', 'hypothesis'])
[{'index': 0, 'rawText': 'Roman'}, {'index': 1, 'rawText': 'Atwood'}, {'index': 2, 'rawText': 'is'}, {'index': 3, 'rawText': 'a'}, {'index': 4, 'rawText': 'content'}, {'index': 5, 'rawText': 'creator'}, {'index': 6, 'rawText': '.'}]
[{'tokenIndex': 2, 'verbatlas': {'frameName': 'COPULA', 'roles': [{'role': 'Theme', 'score': 1.0, 'span': [0, 2]}, {'role': 'Attribute', 'score': 1.0, 'span': [3, 6]}]}, 'englishPropbank': {'frameName': 'be.01', 'roles': [{'role': 'ARG1', 'score': 1.0, 'span': [0, 2]}, {'role': 'ARG2', 'score': 1.0, 'span': [3, 6]}]}}]


In [None]:
# WSD exploration: count how often each POS tag occurs among the hypothesis
# WSD annotations of the first `sample_range` training samples.

sample_range = 10  # len(fever_dataset['train']) for the full split
loop = tqdm(range(sample_range))

pos_info = {}

for i in loop:
    data = fever_dataset['train'][i]
    sample_id = data['id']

    hypothesis = data['hypothesis']
    hyp_wsd = data['wsd']['hypothesis']

    for hyp_wsd_dict in hyp_wsd:
        pos = hyp_wsd_dict['pos']
        pos_info[pos] = pos_info.get(pos, 0) + 1

pos_info

In [7]:
# Print every SRL annotation of the first training premise, one field per
# line (tokenIndex, verbatlas, englishPropbank), with a blank line between them
for annotation in fever_dataset['train'][0]['srl']['premise']['annotations']:
    for key, value in annotation.items():
        print(f"{key}:\t{value}")
    print()

tokenIndex:	4
verbatlas:	{'frameName': 'COPULA', 'roles': [{'role': 'Theme', 'score': 1.0, 'span': [3, 4]}, {'role': 'Attribute', 'score': 1.0, 'span': [5, 22]}]}
englishPropbank:	{'frameName': 'be.01', 'roles': [{'role': 'ARG1', 'score': 1.0, 'span': [3, 4]}, {'role': 'ARG2', 'score': 1.0, 'span': [5, 22]}]}

tokenIndex:	6
verbatlas:	{'frameName': 'KNOW', 'roles': [{'role': 'Theme', 'score': 1.0, 'span': [3, 4]}, {'role': 'Attribute', 'score': 1.0, 'span': [5, 6]}, {'role': 'Topic', 'score': 1.0, 'span': [7, 22]}]}
englishPropbank:	{'frameName': 'know.01', 'roles': [{'role': 'ARG1', 'score': 1.0, 'span': [3, 4]}, {'role': 'ARGM-MNR', 'score': 1.0, 'span': [5, 6]}, {'role': 'ARG2', 'score': 1.0, 'span': [7, 22]}]}

tokenIndex:	13
verbatlas:	{'frameName': 'RECORD', 'roles': [{'role': 'Location', 'score': 1.0, 'span': [8, 10]}, {'role': 'Agent', 'score': 1.0, 'span': [12, 13]}, {'role': 'Theme', 'score': 1.0, 'span': [14, 18]}, {'role': 'Time', 'score': 1.0, 'span': [18, 22]}]}
englishPr

In [8]:
# Peek at the raw SRL structure (tokens + annotations) of the first two training samples
fever_dataset['train'][0:2]['srl']

[{'premise': {'tokens': [{'index': 0, 'rawText': 'Roman'},
    {'index': 1, 'rawText': 'Atwood'},
    {'index': 2, 'rawText': '.'},
    {'index': 3, 'rawText': 'He'},
    {'index': 4, 'rawText': 'is'},
    {'index': 5, 'rawText': 'best'},
    {'index': 6, 'rawText': 'known'},
    {'index': 7, 'rawText': 'for'},
    {'index': 8, 'rawText': 'his'},
    {'index': 9, 'rawText': 'vlogs'},
    {'index': 10, 'rawText': ','},
    {'index': 11, 'rawText': 'where'},
    {'index': 12, 'rawText': 'he'},
    {'index': 13, 'rawText': 'posts'},
    {'index': 14, 'rawText': 'updates'},
    {'index': 15, 'rawText': 'about'},
    {'index': 16, 'rawText': 'his'},
    {'index': 17, 'rawText': 'life'},
    {'index': 18, 'rawText': 'on'},
    {'index': 19, 'rawText': 'a'},
    {'index': 20, 'rawText': 'daily'},
    {'index': 21, 'rawText': 'basis'},
    {'index': 22, 'rawText': '.'},
    {'index': 23, 'rawText': 'His'},
    {'index': 24, 'rawText': 'vlogging'},
    {'index': 25, 'rawText': 'channel'},
    {

In [120]:
#@title SRL exploration: collect frame-names and roles 

def get_srl_info(dataset):
  """Collect SRL statistics from the premise annotations of `dataset`.

  Args:
    dataset: indexable collection of samples (e.g. a `datasets.Dataset`
      split); each sample must provide sample['srl']['premise'] with
      'tokens' and 'annotations' lists.

  Returns:
    A 5-tuple (all count mappings are `collections.Counter`, a dict subclass):
      - VerbAtlas frame name -> occurrence count
      - PropBank frame name -> occurrence count
      - set of VerbAtlas role labels
      - set of PropBank role labels
      - predicate token text -> occurrence count
  """
  va_frame_counts = Counter()  # VerbAtlas frame name -> count
  pb_frame_counts = Counter()  # PropBank frame name -> count
  verb_counts = Counter()      # predicate raw text -> count
  va_roles = set()             # VerbAtlas role labels
  pb_roles = set()             # PropBank role labels

  for i in tqdm(range(len(dataset))):
    premise_srl = dataset[i]['srl']['premise']
    tokens = premise_srl['tokens']

    for annotation in premise_srl['annotations']:
      # tokenIndex points at the predicate token inside `tokens`
      verb_counts[tokens[annotation['tokenIndex']]['rawText']] += 1

      verbatlas = annotation['verbatlas']
      propbank = annotation['englishPropbank']

      va_frame_counts[verbatlas['frameName']] += 1
      pb_frame_counts[propbank['frameName']] += 1

      va_roles.update(role['role'] for role in verbatlas['roles'])
      pb_roles.update(role['role'] for role in propbank['roles'])

  return va_frame_counts, pb_frame_counts, va_roles, pb_roles, verb_counts
  

In [135]:
# Collect SRL statistics over the whole train split, then turn each frequency
# mapping into a list of (name, count) pairs sorted by decreasing count.
(
    set_of_va_frames,
    set_of_pb_frames,
    set_of_va_roles,
    set_of_pb_roles,
    verbs_freqs,
) = get_srl_info(fever_dataset['train'])

verbs_freqs = sorted(verbs_freqs.items(), reverse=True, key=lambda pair: pair[1])
set_of_pb_frames = sorted(set_of_pb_frames.items(), reverse=True, key=lambda pair: pair[1])
set_of_va_frames = sorted(set_of_va_frames.items(), reverse=True, key=lambda pair: pair[1])

# VerbAtlas inventory
print(f"Number of Verb Atlas frames: {len(set_of_va_frames)}")
print(f"Number of Verb Atlas roles: {len(set_of_va_roles)}")

# PropBank inventory
print(f"Number of Propbank frames: {len(set_of_pb_frames)}")
print(f"Number of Propbank roles: {len(set_of_pb_roles)}")

100%|██████████| 51086/51086 [01:13<00:00, 690.78it/s]

Number of Verb Atlas frames: 388
Number of Verb Atlas roles: 30
Number of Propbank frames: 1932
Number of Propbank roles: 40





In [144]:
#@title Save srl info

def _save_lines(file_name, items, description):
  """Write `items` to `info + file_name`, one str() per line, announcing the path first.

  `info` is the output folder defined in the config cell.
  """
  path = info + file_name
  print(f"Saving all {description} into {path}")
  with open(path, "w") as out_file:
    for elem in items:
      out_file.write(f"{elem}\n")


if save_info:
  # Role collections are unordered sets: sort them so the files are reproducible.
  _save_lines("va_roles.txt", sorted(set_of_va_roles), "Verb Atlas roles of the dataset")
  # VerbAtlas frames (the PropBank ones go to pb_frames.txt)
  _save_lines("va_frames.txt", set_of_va_frames, "Verb Atlas frames of the dataset")
  _save_lines("pb_roles.txt", sorted(set_of_pb_roles), "Propbank roles of the dataset")
  _save_lines("pb_frames.txt", set_of_pb_frames, "Propbank frames of the dataset")
  # verb frequencies, not frames
  _save_lines("verbs_freqs.txt", verbs_freqs, "verbs frequencies count")


Saving all Verb Atlas roles of the dataset into ./info/va_roles.txt
Saving all Verb Atlas frames of the dataset into ./info/va_frames.txt
Saving all Propbank roles of the dataset into ./info/pb_roles.txt
Saving all Propbank frames of the dataset into ./info/pb_frames.txt
Saving all verbs frequencies count into ./info/verbs_freqs.txt



## Augmentation
The following prompt can be used to fix the grammar of a modified hypothesis. For example, when only a "not" has been inserted, the sentence is rewritten in a proper negative form:

"Correct the grammar in the following inputs, rephrase the input if necessary to make them more accurate. Provide just the correct version, no explanation."

Ask for:
symmetric relation --> Marry
antisymmetric relation --> Kill

born
sell


In [39]:
#@title Augmentation utils

def join_strings_smartly(words):
    """Join a list of tokens back into a sentence.

    A space is inserted between consecutive tokens, except:
    - before closing punctuation (. , ; : ! ? ));
    - before a token that starts with an apostrophe ("'s", "'re");
    - after a bare "'" token ("'" + "s" -> "'s");
    - between two tokens that both contain a dot;
    - after a token containing "(".

    Returns "" for an empty list.
    """
    punctuation = {'.', ',', ';', ':', '!', '?', ')'}
    if not words:
        return ""

    result = words[0]
    prev = result

    for word in words[1:]:
        glue = (
            word in punctuation
            # only a lone apostrophe glues to the next token; testing
            # `"'" in prev` also glued the word following "'s"
            # ("John" "'s" "car" -> "John'scar")
            or prev == "'"
            or word.startswith("'")
            or ("." in prev and "." in word)
            or "(" in prev
        )
        result += word if glue else " " + word
        prev = word

    return result

def get_synset_from_id(synset_id):
    """Resolve a WSD synset id to an NLTK WordNet synset.

    All digits of `synset_id` form the WordNet offset and its last character
    is used as the part-of-speech letter (id format assumed to be like
    'wn:02084071n' — TODO confirm against the dataset).

    Returns:
        The matching synset, or None when `synset_id` is 'O' (presumably the
        tag for tokens without a sense) or cannot be resolved.
    """
    if synset_id == 'O':
        return None
    try:
        offset = int(''.join(filter(str.isdigit, synset_id)))
        pos = synset_id[-1]
        return wn.synset_from_pos_and_offset(pos, offset)
    except Exception:
        # Best effort: unparsable ids and unknown offsets yield None. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # still stop the augmentation loop.
        print("exception")
        return None
    

def get_related_word(synset, pos):
    """Collect hypernym and synonym lemma names for a WordNet synset.

    Args:
        synset: WordNet synset (anything exposing `hypernyms()` and `lemmas()`,
            whose hypernyms expose `lemma_names()`).
        pos: coarse POS tag of the source word; only 'NOUN', 'VERB', 'ADJ' and
            'ADV' are handled.

    Returns:
        {'hypernyms': sorted lemma names of all direct hypernyms,
         'synonyms': lemma names of `synset` other than its first lemma's name},
        or None if `pos` is unsupported, the synset has no hypernyms, or it has
        no lemmas.
    """
    # Only membership was ever checked against the POS mapping, so a set suffices.
    supported_pos = {'NOUN', 'VERB', 'ADJ', 'ADV'}
    if pos not in supported_pos:
        return None

    hypernyms = synset.hypernyms()
    if not hypernyms:
        return None
    hypernym_words = set()
    for hypernym in hypernyms:
        hypernym_words.update(hypernym.lemma_names())

    lemmas = synset.lemmas()
    if not lemmas:
        return None
    # Every lemma whose name equals the first lemma's name is excluded.
    first_name = lemmas[0].name()

    return {
        # sorted(): list(set) order changes between interpreter runs (string
        # hash randomization), which would make results non-reproducible
        'hypernyms': sorted(hypernym_words),
        'synonyms': [lemma.name() for lemma in lemmas if lemma.name() != first_name],
    }

In [81]:
#@title Augmentation by Synonyms and Hypernyms
# For each of the first `sample_range` training samples, build one new sample
# with the same premise and label, whose hypothesis has every WordNet-resolved
# word replaced by one of its synonyms (rotating through the synonym list and
# skipping a synonym once it has been used more than 10 times).

sample_range = 10  # len(fever_dataset['train']) for the full split
loop = tqdm(range(sample_range))

new_samples = {
    'id': [],
    'premise': [],
    'hypothesis': [],
    'label': [],
    # augmented samples carry no WSD / SRL annotations
    'wsd': [None for i in range(sample_range)],
    'srl': [None for i in range(sample_range)]
}

syn_dict = dict()  # synonym -> number of times it has been used
syn_idx = 0        # rotating index into each word's synonym list
# New ids start right AFTER the largest existing train id; starting at max_id
# itself gave the first augmented sample a duplicate id.
progressive_id = int(max_id) + 1

for i in loop:

    data = fever_dataset['train'][i]

    premise = data['premise']
    label = data['label']
    hyp_wsd = data['wsd']['hypothesis']

    new_samples['id'].append(str(progressive_id))
    new_samples['premise'].append(premise)
    new_samples['label'].append(label)
    progressive_id += 1

    new_hypotesys = []

    for hyp_wsd_dict in hyp_wsd:

        word = hyp_wsd_dict['text']
        pos = hyp_wsd_dict['pos']
        offset = hyp_wsd_dict['wnSynsetOffset']
        synset = get_synset_from_id(offset)

        related_words = get_related_word(synset, pos) if synset else None

        # keep the original word when WordNet offers no usable synonym
        if not related_words or not related_words['synonyms']:
            new_hypotesys.append(word)
            continue

        synonyms = related_words['synonyms']
        synonym = synonyms[syn_idx % len(synonyms)]

        # move on to the next synonym if this one is already over-used
        if syn_dict.get(synonym, 0) > 10:
            syn_idx += 1
            synonym = synonyms[syn_idx % len(synonyms)]

        syn_dict[synonym] = syn_dict.get(synonym, 0) + 1
        syn_idx += 1

        # lemma names use '_' in place of spaces
        new_hypotesys.append(synonym.replace('_', ' '))

    new_samples['hypothesis'].append(join_strings_smartly(new_hypotesys))

100%|██████████| 10/10 [00:00<00:00, 301.98it/s]


In [85]:
#@title Add new samples to the original dataset
 
# Wrap the augmented samples in a Dataset and append them to the train split
# (the original fever_dataset['train'] is left unchanged).
augmentation = Dataset.from_dict(new_samples)
print(f"Augmentation type: {type(augmentation)}")

augmented_dataset = concatenate_datasets([fever_dataset['train'], augmentation])
print(f"Augmented dataset type: {type(augmented_dataset)}")
print(f"Train split length: {len(fever_dataset['train'])}")
print(f"Train split augmentated: {len(augmented_dataset)}")

Augmentation type: <class 'datasets.arrow_dataset.Dataset'>
Augmented dataset type: <class 'datasets.arrow_dataset.Dataset'>
Train split length: 51086
Train split augmentated: 51096


In [91]:
# Collect GPE entities (per SpaCy NER) found inside VerbAtlas `Location` roles
# of the premise SRL annotations, printing every matching sample.
sample_range = 10
# sample_range = len(fever_dataset['train'])

locations = set()

for i in range(sample_range):

    data = fever_dataset['train'][i]

    sentence = data['premise']
    hypothesis = data['hypothesis']

    tokens = data['srl']['premise']['tokens']
    annotations = data['srl']['premise']['annotations']

    for annotation in annotations:
        for role in annotation['verbatlas']['roles']:
            if role['role'] != 'Location':
                continue

            # SRL spans are end-exclusive: e.g. the COPULA Theme "Roman Atwood"
            # (tokens 0 and 1) is annotated as span [0, 2]. The text is rebuilt
            # per role so several Location roles are not concatenated together.
            start, end = role['span']
            text_location = " ".join(
                item['rawText'] for item in tokens if start <= item['index'] < end
            )
            if not text_location:
                continue

            for ent in nlp_spacy(text_location).ents:
                if ent.label_ == "GPE":
                    locations.add(ent.text)

                    # print info related to location
                    print(sentence)
                    print(tokens)
                    print(annotation['verbatlas'])
                    print(hypothesis)

John Wick: Chapter 2 . The film premiered in Los Angeles on January 30 , 2017 , and was theatrically released in the United States on February 10 , 2017 .
[{'index': 0, 'rawText': 'John'}, {'index': 1, 'rawText': 'Wick'}, {'index': 2, 'rawText': ':'}, {'index': 3, 'rawText': 'Chapter'}, {'index': 4, 'rawText': '2'}, {'index': 5, 'rawText': '.'}, {'index': 6, 'rawText': 'The'}, {'index': 7, 'rawText': 'film'}, {'index': 8, 'rawText': 'premiered'}, {'index': 9, 'rawText': 'in'}, {'index': 10, 'rawText': 'Los'}, {'index': 11, 'rawText': 'Angeles'}, {'index': 12, 'rawText': 'on'}, {'index': 13, 'rawText': 'January'}, {'index': 14, 'rawText': '30'}, {'index': 15, 'rawText': ','}, {'index': 16, 'rawText': '2017'}, {'index': 17, 'rawText': ','}, {'index': 18, 'rawText': 'and'}, {'index': 19, 'rawText': 'was'}, {'index': 20, 'rawText': 'theatrically'}, {'index': 21, 'rawText': 'released'}, {'index': 22, 'rawText': 'in'}, {'index': 23, 'rawText': 'the'}, {'index': 24, 'rawText': 'United'}, {'in

In [None]:
# Select training samples whose premise or hypothesis mentions the word "born"
# (candidates for relation-specific augmentation). A word-boundary,
# case-insensitive regex avoids substring hits such as "Osborne" or "stubborn".
_BORN_PATTERN = re.compile(r"\bborn\b", flags=re.IGNORECASE)


def mentions_born(example):
    """Return True if the example's premise or hypothesis contains the word 'born'."""
    return bool(
        _BORN_PATTERN.search(example['premise'])
        or _BORN_PATTERN.search(example['hypothesis'])
    )


filtered_dataset = fever_dataset['train'].filter(mentions_born)