In [None]:
# Import Statements
import json
import os
import pickle
import random
import re

import accelerate
import adapters
import datasets
import evaluate
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import Dataset, load_dataset
from fuzzywuzzy import fuzz
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BartTokenizer, BartForSequenceClassification, DataCollatorWithPadding, DataCollatorForSeq2Seq
from transformers import AdamW, Seq2SeqTrainer, Seq2SeqTrainingArguments, get_scheduler, Trainer, TrainingArguments, GenerationConfig, set_seed

In [None]:
# Random Seed Function
def set_logging_and_seed(seed=42):
    """Configure transformers logging and seed every RNG used in the notebook.

    Verbosity is first set to info and then to level 30 (logging.WARNING), after
    which a "WARN" record is emitted on the "transformers" logger. The seed is
    applied to `random`, NumPy, torch (CPU, plus all CUDA devices when available)
    and finally through `transformers.set_seed`.

    Args:
        seed: Integer seed shared by all libraries (default 42).
    """
    # Logging output settings
    hf_logging = transformers.utils.logging
    hf_logging.set_verbosity_info()
    hf_logger = hf_logging.get_logger("transformers")
    hf_logging.set_verbosity(30)
    hf_logger.warning("WARN")

    # Random seed settings
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    set_seed(seed)
    return

In [None]:
# Test CUDA Availability
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
!nvidia-smi

In [None]:
# Load Model/Tokenizer
# Reset logging and re-seed all RNGs before the model is instantiated.
set_logging_and_seed()

# Selects which of the three branches below runs; "pier" selects the adapter-fusion branch.
model_name = "mistralai/Mistral-7B-v0.1"

# Select the configuration
if (model_name == "mistralai/Mistral-7B-v0.1"):
    final_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    
    # Freeze every parameter under `final_model.model`.
    # NOTE(review): this is presumably the backbone, which would leave the
    # classification head as the only trainable part — confirm.
    freeze = True
    if (freeze == True):
        for param in final_model.model.parameters(): param.requires_grad = False
        pass
    
    # Load tokenizer
    injected_tokenizer = AutoTokenizer.from_pretrained("mistral-tokenizer")
    
    # Add explicit pad token for Mistral
    if injected_tokenizer.pad_token is None:
        injected_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # The vocabulary grew by one token, so the embedding matrix is resized to match.
        final_model.resize_token_embeddings(len(injected_tokenizer))

    # Ensure Mistral has a valid pad token ID
    final_model.config.pad_token_id = injected_tokenizer.pad_token_id
    
elif (model_name != "pier"):
    # Any other model name is loaded directly as a BART sequence classifier.
    final_model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Load tokenizer
    injected_tokenizer = BartTokenizer.from_pretrained("bart-large-tokenizer")
else:
    final_model = BartForSequenceClassification.from_pretrained("facebook/bart-base", num_labels=2)
    
    # Initialize architecture for PIER+, then load in the weights
    adapters.init(final_model)
    # NOTE(review): both load_adapter calls and load_adapter_fusion use the same
    # placeholder path "~/path/"; they presumably point at different checkpoints — confirm.
    final_model.load_adapter("~/path/")
    final_model.set_active_adapters("non-compositional")
    final_model.load_adapter("~/path/")
    final_model.set_active_adapters("compositional")
    adapter_fusion_name = final_model.load_adapter_fusion("~/path/")
    final_model.set_active_adapters(adapter_fusion_name)
    # NOTE(review): train_adapter is called twice in a row; verify whether the
    # second call replaces or extends the first call's trainable set.
    final_model.train_adapter("compositional")
    final_model.train_adapter("non-compositional")
    
    # Load tokenizer
    injected_tokenizer = BartTokenizer.from_pretrained("bart-tokenizer")

In [None]:
# Load FLUTE
flute_train = pd.read_json(path_or_buf="~/path/", lines=True).values.tolist() # 1768
flute_test = pd.read_json(path_or_buf="~/path/", lines=True).values.tolist() # 250

# Keep columns 1 and 2 of every row plus a one-hot label: a "Contradiction" in
# the second-to-last column becomes [0.0, 1.0], anything else [1.0, 0.0].
flute_train_final = [
    [row[1], row[2], [0.0, 1.0] if row[-2] == "Contradiction" else [1.0, 0.0]]
    for row in flute_train
]

flute_test_final = [
    [row[1], row[2], [0.0, 1.0] if row[-2] == "Contradiction" else [1.0, 0.0]]
    for row in flute_test
]

In [None]:
# Load IMPLI
# Reset logging and re-seed all RNGs for this cell.
set_logging_and_seed()

# Every split is read from a tab-separated file and flattened into a list of row lists.
# Suffixes: _e = entailment, _ne = non-entailment, _adversarial_ne / _antonyms_ne = further non-entailment sets.
magpie_e = pd.read_csv("~/path/", sep="\t").values.tolist()
magpie_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
magpie_adversarial_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
pie_e = pd.read_csv("~/path/", sep="\t").values.tolist()
pie_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
pie_adversarial_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
semeval_e = pd.read_csv("~/path/", sep="\t").values.tolist()
semeval_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
semeval_adversarial_ne = pd.read_csv("~/path/", sep="\t").values.tolist()
# The manual splits are read with header=None, so their first line is treated as data rather than column names.
manual_e = pd.read_csv("~/path/", sep="\t", header=None).values.tolist()
manual_ne = pd.read_csv("~/path/", sep="\t", header=None).values.tolist()
manual_antonyms_ne = pd.read_csv("~/path/", sep="\t", header=None).values.tolist()

def label_cleaning(e, ne, adversarial_ne):
    """Rewrite every row in place as [row[0], row[1], one-hot label].

    Rows of `e` receive [1.0, 0.0]; rows of `ne` and `adversarial_ne` receive
    [0.0, 1.0]. Any further columns of the original rows are dropped.

    Returns:
        The same three list objects, in the order they were passed.
    """
    labelled_groups = (
        (e, (1.0, 0.0)),
        (ne, (0.0, 1.0)),
        (adversarial_ne, (0.0, 1.0)),
    )
    for rows, one_hot in labelled_groups:
        for position, row in enumerate(rows):
            # list(...) gives every row its own label object.
            rows[position] = [row[0], row[1], list(one_hot)]
    return e, ne, adversarial_ne

# Reduce every row to [column 0, column 1, one-hot label]: the first argument's rows
# get [1.0, 0.0], the second and third arguments' rows get [0.0, 1.0].
magpie_e, magpie_ne, magpie_adversarial_ne = label_cleaning(magpie_e, magpie_ne, magpie_adversarial_ne)
pie_e, pie_ne, pie_adversarial_ne = label_cleaning(pie_e, pie_ne, pie_adversarial_ne)
semeval_e, semeval_ne, semeval_adversarial_ne = label_cleaning(semeval_e, semeval_ne, semeval_adversarial_ne)
manual_e, manual_ne, manual_antonyms_ne = label_cleaning(manual_e, manual_ne, manual_antonyms_ne)

def length_indicators(data, max_len=64, min_len=None, invalid_labels=None, verbose=False):
    """Drop rows whose premise/hypothesis length or label is not acceptable.

    Lengths are counted in space-separated words. A row is removed when either
    text has `max_len` or more words, when `min_len` is given and either text
    has `min_len` or fewer words, or when its label (`row[2]`) is listed in
    `invalid_labels`.

    Args:
        data: List of rows `[premise, hypothesis, label, ...]`; filtered IN PLACE.
        max_len: Word count at or above which a text is rejected.
        min_len: Optional word count at or below which a text is rejected.
        invalid_labels: Labels whose rows are removed (None means no label filter).
        verbose: When True, print how many rows were removed.

    Returns:
        The same `data` list object after filtering.
    """
    banned_labels = [] if invalid_labels is None else invalid_labels

    def _bad_length(sentence):
        word_count = len(sentence.split(" "))
        return word_count >= max_len or (min_len is not None and word_count <= min_len)

    # The label check reads the row being filtered, not an unrelated global list.
    invalid_indices = [
        position
        for position, row in enumerate(data)
        if _bad_length(row[0]) or _bad_length(row[1]) or row[2] in banned_labels
    ]

    # Delete from the back so the remaining indices stay valid.
    for index in reversed(invalid_indices):
        del data[index]

    if verbose:
        print(f"length_indicators: removed {len(invalid_indices)} rows, {len(data)} remain")

    return data

mlin = None # Toggle
# Filter every automatically built split by length. No `verbose` argument is passed:
# `verbosity` is not defined anywhere in this notebook, so forwarding it would fail.
magpie_e = length_indicators(magpie_e, min_len=mlin, invalid_labels=[])
magpie_ne = length_indicators(magpie_ne, min_len=mlin, invalid_labels=[])
magpie_adversarial_ne = length_indicators(magpie_adversarial_ne, min_len=mlin, invalid_labels=[])
pie_e = length_indicators(pie_e, min_len=mlin, invalid_labels=[])
pie_ne = length_indicators(pie_ne, min_len=mlin, invalid_labels=[])
pie_adversarial_ne = length_indicators(pie_adversarial_ne, min_len=mlin, invalid_labels=[])
semeval_e = length_indicators(semeval_e, min_len=mlin, invalid_labels=[])
semeval_ne = length_indicators(semeval_ne, min_len=mlin, invalid_labels=[])
semeval_adversarial_ne = length_indicators(semeval_adversarial_ne, min_len=mlin, invalid_labels=[])

In [None]:
# Load FigurativeNarrativeBenchmark
set_logging_and_seed()

fnb_train = pd.read_json(path_or_buf="~/path/", lines=True).values.tolist()
# Built-in open() does not expand "~", so the home directory is resolved explicitly.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(os.path.expanduser("~/path/"), "rb") as file: fnb_test = pickle.load(file)

In [None]:
# Manually Define Text Correction (Toggle Whether to Use)
def reformat_text(text):
    """Normalise tokenisation artefacts, capitalise, and terminate a sentence.

    A fixed sequence of literal substring replacements is applied, the result is
    stripped, its first character is upper-cased, and a "." is appended unless
    the text already ends in ".", "!" or "?".

    Args:
        text: Input string.

    Returns:
        The cleaned string. An empty input, or one that is empty after cleaning
        and stripping, is returned as "".
    """
    if (text == ""): return text

    # Applied strictly in this order: later rules see the output of earlier ones
    # (e.g. " - " is first joined to "-" and every "-" is later turned into a space).
    replacements = (
        ("  </s>", ". </s>"),
        ("_", " "),
        ("  ", ", "),
        (" i ", " I "),
        (" , ", ", "),
        (" - ", "-"),
        ("We 'll", "We'll"),
        (",I", ", I"),
        ("I 'd", "I'd"),
        ("I' ll", "I'll"),
        ("you 'll", "you'll"),
        ("ca n't", "can't"),
        ("they 'd", "they'd"),
        ("( ", "("),
        (" )", ")"),
        (" ; ", "; "),
        ("Let 's", "Let's"),
        ("‘ ", ""),
        ("-", " "),
        (" 's", "'s"),
        (" 'll", "'ll"),
        ("..", "."),
        (" . ", "."),
        (" .", "."),
        (" ?", "?"),
    )
    for old, new in replacements:
        text = text.replace(old, new)

    # Strip first so that leading whitespace cannot hide the character to capitalise,
    # and bail out before indexing text[-1] on an empty string.
    text = text.strip()
    if (text == ""): return text

    # Capitalise sentence
    text = text[0].upper() + text[1:]

    # Add punctuation
    if (text[-1] not in ".!?"): text = text + "."

    return text

In [None]:
# FigurativeNarrativeBenchmark Auxiliary Methods
# Reset logging verbosity and re-seed the Python, NumPy and torch RNGs (default seed 42).
set_logging_and_seed()

def fnb_remove_contexts(premise, amount_keep=None, total_removal=False, remove_idiom=False):
    """Shrink a premise to a smaller context around its `<b>...</b>` idiom.

    Modes, checked in this order:
      * `total_removal=True`: return only the text between the first "<b>" and
        the following "</b>"; without a "<b>" tag, return the last ". "-separated
        sentence.
      * `amount_keep` given: keep that fraction of the words, counted from the
        end of the premise.
      * default: return the last ". "-separated sentence that contains "<b>",
        or the last sentence when none does.

    Args:
        premise: Premise text, optionally containing `<b>...</b>` markup.
        amount_keep: Fraction of words to keep, or None.
        total_removal: Keep only the marked idiom.
        remove_idiom: Accepted for interface compatibility; not used here.

    Returns:
        The reduced premise string.
    """
    if (total_removal == True):
        if ("<b>" not in premise): return premise.split(". ")[-1] # The last sentence, which usually contains the IE
        return premise.split("<b>")[1].split("</b>")[0]

    # Perform a percentage removal of words starting from the beginning
    if (amount_keep != None):
        all_words = premise.split(" ")
        keep = int(amount_keep * len(all_words))
        # all_words[-0:] is the WHOLE list, so "keep nothing" needs its own case.
        kept_words = all_words[-keep:] if keep != 0 else []
        return " ".join(kept_words)

    # Default: remove all sentences except for the one containing the idiom
    sentences = premise.split(". ")
    idx = -1

    # Determine which sentence the idiom is in (the last match wins)
    for i, sentence in enumerate(sentences):
        if ("<b>" in sentence): idx = i

    return sentences[idx]

def fnb_shuffling(premise):
    """Shuffle the words before and after the `<b>...</b>` idiom, keeping the idiom intact.

    The premise is split on single spaces. The first word carrying (part of) an
    opening tag and the first word carrying (part of) a closing tag delimit the
    idiom; the words before and after that span are shuffled independently with
    `random.shuffle`. When no closing tag is found the whole word list is shuffled.

    Args:
        premise: Premise text, optionally containing `<b>...</b>` markup.

    Returns:
        The re-joined, stripped text.
    """
    text = premise.split(" ")

    # Get the idiom's range of indices. None marks "not found yet"; using 0 for that
    # would let the closing-tag word (which also contains "b>") overwrite an idiom
    # that legitimately starts at word 0.
    start = None
    end = None

    for i, word in enumerate(text):
        if (start is None and ("<b>" in word or "b>" in word or "<b" in word)): start = i
        if (end is None and ("</b>" in word or "/b>" in word or "</b" in word)): end = i

    if (start is None): start = 0
    if (end is None): end = -1

    # Turn `end` into an exclusive bound, unless it already is the last word.
    if (end != len(text) - 1): end += 1

    start_portion = text[:start]
    end_portion = text[end:]

    # Shuffle both portions
    random.shuffle(start_portion)
    if (len(end_portion) > 0): random.shuffle(end_portion)

    # Re-append the two portions
    shuffled_text = start_portion + text[start:end] + end_portion
    return (" ".join(shuffled_text)).strip()

def fnb_random_removal(premise, amount_keep=0.9):
    """Randomly delete words outside the `<b>...</b>` idiom.

    `int(num_words * amount_keep)` words are meant to be kept. Removal indices
    are drawn with `np.random.randint` and redrawn whenever they fall inside the
    idiom span. Indices are drawn with replacement, so repeated draws can leave
    slightly more words than requested.

    Args:
        premise: Premise text, optionally containing `<b>...</b>` markup.
        amount_keep: Fraction of words to keep; None falls back to 0.9.

    Returns:
        The re-joined, stripped text with the idiom words untouched.
    """
    if (amount_keep is None): amount_keep = 0.9

    # Get number of words
    text = premise.split(" ")
    num_words = len(text)
    num_keep = int(num_words * amount_keep)
    num_remove = num_words - num_keep

    # Get the location of the idiom. None marks "not found yet"; using 0 for that would
    # let the closing-tag word (which also contains "b>") overwrite a start at word 0.
    start = None
    end = None
    for i, word in enumerate(text):
        if (start is None and ("<b>" in word or "b>" in word or "<b" in word)): start = i
        if (end is None and ("</b>" in word or "/b>" in word or "</b" in word)): end = i

    if (start is None): start = 0
    # Exclusive upper bound of the protected span; the closing word is protected even
    # when it is the last word. With no closing tag nothing is protected.
    stop = 0 if end is None else end + 1

    # Every word is part of the idiom: nothing may be removed (and redrawing would never end).
    if (start == 0 and stop >= num_words): return " ".join(text).strip()

    # Get the set of indices to remove
    idx = set()
    for _ in range(num_remove):
        removal_idx = np.random.randint(num_words)
        while (removal_idx >= start and removal_idx < stop):
            removal_idx = np.random.randint(num_words)
        idx.add(removal_idx)

    text = [word for i, word in enumerate(text) if i not in idx]
    return " ".join(text).strip()

def fnb_retrieve_real_idiom(premise):
    """Return the text between the first "<b>" and the next "</b>" in `premise`.

    Without a "<b>" tag, the last ". "-separated sentence is returned instead.
    """
    if ("<b>" in premise):
        after_open_tag = premise.split("<b>")[1]
        return after_open_tag.split("</b>")[0]

    # The last sentence, which usually contains the IE
    return premise.split(". ")[-1]

def generate_gibberish(cutoff=15, s="abcdefghijklmnopqrstuvwxyz"):
    """Build a random string of characters drawn from `s`.

    The length comes from `np.random.randint(cutoff)`, i.e. 0 to cutoff - 1, so
    the result can be empty. Characters are picked with `random.choice`.
    """
    length = np.random.randint(cutoff)
    return "".join(random.choice(s) for _ in range(length))

# Replace the IE of every test premise with a randomly generated string.
# NOTE: this block is ACTIVE as written. The two `# '''` lines are a toggle: turning
# them into bare `'''` wraps the loop in a string literal and disables it.
# '''
for i in range(len(fnb_test)):
    idiom_temp = fnb_retrieve_real_idiom(fnb_test[i][0])
    # str.replace("", x) inserts x between every character, so an empty idiom is skipped.
    if (idiom_temp == ""): continue
    gibb = generate_gibberish()
    fnb_test[i][0] = fnb_test[i][0].replace(idiom_temp, gibb)
    pass
# '''

def fnb_split(fnb_train, fnb_test, full_continuation=False, amount_keep=None, total_removal=False, remove_idiom=False, original=False, no_removal=False, shuffle=False, random_removal=False, move_idiom=False):
    """Turn the two-option benchmark rows into premise/hypothesis/label datasets.

    Every row contributes two examples, one per candidate continuation
    (`sample[3]` and `sample[4]`). The continuation named by `sample[5]`
    ("option1" means `sample[3]`) is labelled [1.0, 0.0], the other [0.0, 1.0].

    Training premises are only passed through `reformat_text`. Test premises are
    first transformed:
      * `no_removal=True` keeps the full premise (optionally shuffled around the
        idiom when `shuffle=True`); otherwise `fnb_remove_contexts` is applied
        with `amount_keep`, `total_removal` and `remove_idiom`;
      * `random_removal=True` then additionally drops random non-idiom words.

    `full_continuation`, `original` and `move_idiom` are accepted but not used.

    Returns:
        (train, test, test_entailment, test_non_entailment, empty_dataset), where
        the train dataset is shuffled with seed 42 and the last element is an
        empty dataset with the same three columns.
    """
    def _empty_columns():
        return {"premises" : [], "hypotheses" : [], "labels" : []}

    def _append(columns, premise, hypothesis, entailed):
        columns["premises"].append(premise)
        columns["hypotheses"].append(hypothesis)
        columns["labels"].append([1.0, 0.0] if entailed else [0.0, 1.0])

    # Add training samples: first option, then its converse
    train_fnb = _empty_columns()
    for sample in fnb_train:
        first_is_true = (sample[5] == "option1")
        premise = reformat_text(sample[0])
        _append(train_fnb, premise, reformat_text(sample[3]), first_is_true)
        _append(train_fnb, premise, reformat_text(sample[4]), not first_is_true)

    # Now do the same for the test split
    test_fnb = _empty_columns()
    test_e = _empty_columns()
    test_ne = _empty_columns()

    for sample in fnb_test:
        if (no_removal == True):
            premise = sample[0]
            if (shuffle == True): premise = fnb_shuffling(premise)
        else:
            premise = fnb_remove_contexts(sample[0], amount_keep=amount_keep, total_removal=total_removal, remove_idiom=remove_idiom)

        if (random_removal == True):
            # amount_keep defaults to None here; forwarding None would break the
            # arithmetic in fnb_random_removal, so its own default is used instead.
            if (amount_keep is None): premise = fnb_random_removal(premise)
            else: premise = fnb_random_removal(premise, amount_keep=amount_keep)

        premise = reformat_text(premise)
        first_is_true = (sample[5] == "option1")

        # Each option goes into the combined set and into its label-specific subset
        for hypothesis, entailed in ((reformat_text(sample[3]), first_is_true),
                                     (reformat_text(sample[4]), not first_is_true)):
            _append(test_fnb, premise, hypothesis, entailed)
            _append(test_e if entailed else test_ne, premise, hypothesis, entailed)

    # Convert all the data to Datasets
    train_fnb = Dataset.from_dict(train_fnb).shuffle(seed=42)
    test_fnb = Dataset.from_dict(test_fnb)
    test_e = Dataset.from_dict(test_e)
    test_ne = Dataset.from_dict(test_ne)
    test_antonyms = Dataset.from_dict(_empty_columns())

    return train_fnb, test_fnb, test_e, test_ne, test_antonyms

# Build the benchmark datasets with every test-time perturbation switched off.
(train_fnb_dataset,
 test_fnb_dataset,
 test_fnb_manual_e,
 test_fnb_manual_ne,
 test_fnb_manual_antonyms_ne) = fnb_split(
    fnb_train,
    fnb_test,
    full_continuation=False,
    amount_keep=None,
    total_removal=False,
    no_removal=False,
    shuffle=False,
    random_removal=False,
)


def tokenize_function(examples):
    """Tokenize a batch as (premise, hypothesis) pairs with truncation enabled.

    Reads the "premises" and "hypotheses" columns of `examples` and returns
    whatever `injected_tokenizer` produces for them.
    """
    return injected_tokenizer(examples["premises"], examples["hypotheses"], truncation=True)

# Tokenize every split in batches; the raw text columns are dropped, "labels" is kept.
encoded_fnb_train = train_fnb_dataset.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_fnb_test = test_fnb_dataset.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])

# Label-specific test subsets (entailment / non-entailment).
encoded_fnb_manual_e = test_fnb_manual_e.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_fnb_manual_ne = test_fnb_manual_ne.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
# NOTE(review): max_length=256 is passed without an explicit padding strategy —
# confirm it has the intended effect on batch padding.
data_collator = DataCollatorWithPadding(tokenizer=injected_tokenizer, max_length=256)

In [None]:
# IMPLI Auxiliary Methods
# Reset logging verbosity and re-seed the Python, NumPy and torch RNGs (default seed 42).
set_logging_and_seed()

def remove_contexts(t1, t2):
    """Strip the words two sentences share at their start and end.

    Both texts are split on single spaces and compared case-insensitively.
    Everything before the first differing word is cut; then everything after the
    last differing word is cut. At least one word is always kept on each side.

    When no difference is found in the forward pass (the texts agree over the
    whole shorter length), only the last word of each text is kept.

    Args:
        t1: First sentence (e.g. the premise).
        t2: Second sentence (e.g. the hypothesis).

    Returns:
        Tuple `(core_of_t1, core_of_t2)` of space-joined strings.
    """
    s1 = t1.split(" ")
    s2 = t2.split(" ")

    # Forward pass: find the first position where the sentences differ
    shared = min(len(s1), len(s2))
    l = -1 # Left cutoff
    for i in range(shared):
        if (s1[i].lower() != s2[i].lower()):
            l = i
            break

    # Cut off the common prefix
    s1 = s1[l:]
    s2 = s2[l:]

    # Backwards pass: find the first difference counted from the end.
    # None means "no difference found"; a numeric sentinel could collide with a
    # genuine mismatch position (e.g. the 10th word from the end).
    shared = min(len(s1), len(s2))
    r = None
    for i in range(1, shared):
        if (s1[-i].lower() != s2[-i].lower()):
            r = -i
            break
    if (shared == 1): r = -1

    if (r is None):
        # No trailing difference: drop the common tail but keep at least one word
        s1 = s1[:-shared + 1]
        s2 = s2[:-shared + 1]
    elif (r != -1):
        # Normal cutoff: keep everything up to and including the differing word
        s1 = s1[:r + 1]
        s2 = s2[:r + 1]

    return " ".join(s1), " ".join(s2)

# Retrieve dataframe of IMPLI dataset and corresponding model predictions, as well as other data
# Built-in open() does not expand "~", so the home directory is resolved explicitly.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(os.path.expanduser("~/path/"), "rb") as file:
    impli_df = pickle.load(file)
    
def impli_shuffling(premise, idx):
    """Shuffle the words around an idiom located by fuzzy matching.

    The idiom is read from `impli_df.iloc[idx]["idiom"]` and split on spaces. The
    premise is scanned for the first run of consecutive words that each match
    the corresponding idiom word with `fuzz.ratio >= 75` (case-insensitive).
    The words before and after that run are shuffled independently.

    Args:
        premise: Premise text.
        idx: Positional row index into `impli_df`.

    Returns:
        The shuffled, stripped text, or `premise` unchanged when the complete
        idiom is not found.
    """
    idiom = impli_df.iloc[idx]["idiom"].split(" ")
    text = premise.split(" ")
    span = None # (start, end) of the matched idiom, end exclusive

    for i in range(len(text)):
        if (fuzz.ratio(text[i].lower(), idiom[0].lower()) < 75): continue

        # The first idiom word matched at i; try to match the remaining ones in sequence
        matched = 1
        for j in range(i + 1, len(text)):
            if (matched == len(idiom)): break
            if (fuzz.ratio(text[j].lower(), idiom[j - i].lower()) < 75): break
            matched += 1

        if (matched == len(idiom)):
            span = (i, i + len(idiom))
            break

    # Only a COMPLETE match counts; a partial match must not be treated as the idiom.
    if (span is None): return premise
    start, end = span

    # Split text into idiom and front/back sections
    start_portion = text[:start]
    end_portion = text[end:]
    random.shuffle(start_portion)
    random.shuffle(end_portion)
    text = start_portion + text[start:end] + end_portion
    return " ".join(text).strip()

# Replace the IE of every manual IMPLI premise with a randomly generated string.
# NOTE: this block is ACTIVE as written. The two `# '''` lines are a toggle: turning
# them into bare `'''` wraps the block in a string literal and disables it.
# '''
# Concatenation builds a new outer list but shares the row objects, so the
# assignments below also modify manual_e, manual_ne and manual_antonyms_ne.
agg = manual_e + manual_ne + manual_antonyms_ne

# NOTE: identical to the generate_gibberish defined in the FigurativeNarrativeBenchmark cell.
def generate_gibberish(cutoff=15, s="abcdefghijklmnopqrstuvwxyz"):
    """Build a random string of 0 to cutoff - 1 characters drawn from `s`."""
    length = np.random.randint(cutoff)
    return "".join(random.choice(s) for _ in range(length))

for i in range(len(agg)):
    # The premise part that differs from the hypothesis is taken to be the idiom
    idiom_temp = remove_contexts(agg[i][0], agg[i][1])[0]
    # str.replace("", x) inserts x between every character, so an empty span is skipped.
    if (idiom_temp == ""): continue
    gibb = generate_gibberish()
    agg[i][0] = agg[i][0].replace(idiom_temp, gibb)
    pass

# '''

def train_test_split(manual_e, manual_ne, manual_antonyms_ne, magpie_e, magpie_adversarial_ne, magpie_ne, sample=True, no_context=True, shuffling=True):
    """Build the IMPLI training set and the manual test sets as Datasets.

    Training data is `magpie_e + magpie_ne + magpie_adversarial_ne`. With
    `sample=True` the entailment rows are randomly down-sampled to the number
    of non-entailment rows. The training dataset is shuffled with seed 42.

    Test data is built four times with the same transformation: once for
    `manual_e + manual_ne + manual_antonyms_ne` and once per individual list.
    Every text is passed through `reformat_text`; then
      * `no_context` truthy: `remove_contexts` strips the words shared by
        premise and hypothesis;
      * otherwise, `shuffling` truthy: `impli_shuffling` shuffles the premise
        around its idiom;
      * otherwise the texts are left as they are.

    Args:
        manual_e, manual_ne, manual_antonyms_ne: Test rows `[premise, hypothesis, label]`.
        magpie_e, magpie_adversarial_ne, magpie_ne: Training rows of the same shape.
        sample: Balance the training labels by down-sampling `magpie_e`.
        no_context: Strip shared context from test pairs.
        shuffling: Shuffle test premises (only consulted when `no_context` is falsy).

    Returns:
        (train, test_all, test_e, test_ne, test_antonyms_ne) Datasets.
    """
    def _build(rows, transform_for_test):
        columns = {"premises" : [], "hypotheses" : [], "labels" : []}
        # enumerate gives each row its own position directly; list.index would be
        # quadratic and would return the first occurrence for duplicated rows.
        for position, row in enumerate(rows):
            premise = reformat_text(row[0])
            hypothesis = reformat_text(row[1])
            if transform_for_test:
                if no_context:
                    premise, hypothesis = remove_contexts(premise, hypothesis)
                elif shuffling:
                    # NOTE(review): `position` is relative to `rows`, so the same
                    # impli_df rows are consulted for each individual list as for the
                    # start of the aggregate — confirm impli_df's row order.
                    premise = impli_shuffling(premise, position)
            columns["premises"].append(premise)
            columns["hypotheses"].append(hypothesis)
            columns["labels"].append(row[2])
        return Dataset.from_dict(columns)

    # Balance the dataset labels if needed
    if (sample):
        ne_size = len(magpie_ne) + len(magpie_adversarial_ne)
        constant = 1.0
        # random.sample raises when asked for more items than exist, so cap the draw
        take = min(len(magpie_e), int(constant * ne_size))
        idx = random.sample(range(len(magpie_e)), take)
        aggregate_train_dataset = [magpie_e[i] for i in idx] + magpie_ne + magpie_adversarial_ne
    else:
        aggregate_train_dataset = magpie_e + magpie_ne + magpie_adversarial_ne

    # Construct IMPLI training dataset
    impli_train_dataset = _build(aggregate_train_dataset, False).shuffle(seed=42)

    # Construct IMPLI test datasets: the aggregate first, then the individual ones
    aggregate_test_dataset = manual_e + manual_ne + manual_antonyms_ne
    impli_test_dataset = _build(aggregate_test_dataset, True)
    impli_test_manual_e = _build(manual_e, True)
    impli_test_manual_ne = _build(manual_ne, True)
    impli_test_manual_antonyms_ne = _build(manual_antonyms_ne, True)

    return impli_train_dataset, impli_test_dataset, impli_test_manual_e, impli_test_manual_ne, impli_test_manual_antonyms_ne

# Train on the union of the three automatically built corpora; test on the manual
# sets with their full context and without shuffling.
(train_dataset,
 test_dataset,
 test_manual_e,
 test_manual_ne,
 test_manual_antonyms_ne) = train_test_split(
    manual_e,
    manual_ne,
    manual_antonyms_ne,
    magpie_e + pie_e + semeval_e,
    magpie_adversarial_ne + pie_adversarial_ne + semeval_adversarial_ne,
    magpie_ne + pie_ne + semeval_ne,
    no_context=False,
    shuffling=False,
)

def tokenize_function(examples):
    """Tokenize a batch as (premise, hypothesis) pairs with truncation enabled.

    Reads the "premises" and "hypotheses" columns of `examples` and returns
    whatever `injected_tokenizer` produces for them.
    """
    return injected_tokenizer(examples["premises"], examples["hypotheses"], truncation=True)

# Tokenize every split in batches; the raw text columns are dropped, "labels" is kept.
encoded_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_test = test_dataset.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])

# Tokenization for individual datasets
encoded_manual_e = test_manual_e.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_manual_ne = test_manual_ne.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_manual_antonyms_ne = test_manual_antonyms_ne.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])

# NOTE(review): max_length=256 is passed without an explicit padding strategy —
# confirm it has the intended effect on batch padding.
data_collator = DataCollatorWithPadding(tokenizer=injected_tokenizer, max_length=256)

In [None]:
# FLUTE Auxiliary Methods
# Reset logging verbosity and re-seed the Python, NumPy and torch RNGs (default seed 42).
set_logging_and_seed()

def flute_construction(train, test, no_context=False):
    """Convert FLUTE rows `[premise, hypothesis, label]` into Datasets.

    Training texts are passed through `reformat_text` and the resulting dataset
    is shuffled with seed 42. Test pairs are optionally reduced with
    `remove_contexts` (when `no_context` is True) before `reformat_text`, and
    are additionally routed into a subset by label: [0.0, 1.0] goes to the
    non-entailment subset, everything else to the entailment subset.

    Returns:
        (train, test_all, test_entailment, test_non_entailment) Datasets.
    """
    def _blank():
        return {"premises" : [], "hypotheses" : [], "labels" : []}

    def _push(columns, premise, hypothesis, label):
        columns["premises"].append(premise)
        columns["hypotheses"].append(hypothesis)
        columns["labels"].append(label)

    train_columns = _blank()
    for row in train:
        _push(train_columns, reformat_text(row[0]), reformat_text(row[1]), row[2])
    flute_train = Dataset.from_dict(train_columns).shuffle(seed=42)

    flute_test = _blank()
    flute_test_e = _blank()
    flute_test_ne = _blank()

    for row in test:
        if (no_context == True):
            premise, hypothesis = remove_contexts(row[0], row[1])
        else:
            premise, hypothesis = row[0], row[1]

        premise = reformat_text(premise)
        hypothesis = reformat_text(hypothesis)

        # Every pair lands in the full test set and in exactly one label subset
        subset = flute_test_ne if row[2] == [0.0, 1.0] else flute_test_e
        for columns in (flute_test, subset):
            _push(columns, premise, hypothesis, row[2])

    flute_test = Dataset.from_dict(flute_test)
    flute_test_e = Dataset.from_dict(flute_test_e)
    flute_test_ne = Dataset.from_dict(flute_test_ne)

    return flute_train, flute_test, flute_test_e, flute_test_ne

# Build the FLUTE datasets from the labelled row lists, keeping the full context.
# Note that `flute_train` and `flute_test` are rebound from raw row lists to Datasets here.
(flute_train,
 flute_test,
 flute_test_e,
 flute_test_ne) = flute_construction(
    flute_train_final,
    flute_test_final,
    no_context=False,
)

def tokenize_function(examples):
    """Tokenize a batch as (premise, hypothesis) pairs with truncation enabled.

    Reads the "premises" and "hypotheses" columns of `examples` and returns
    whatever `injected_tokenizer` produces for them.
    """
    return injected_tokenizer(examples["premises"], examples["hypotheses"], truncation=True)

# Tokenize every split in batches; the raw text columns are dropped, "labels" is kept.
encoded_flute_train = flute_train.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_flute_test = flute_test.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])

# Tokenization for individual datasets
encoded_flute_e = flute_test_e.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
encoded_flute_ne = flute_test_ne.map(tokenize_function, batched=True, remove_columns=["premises", "hypotheses"])
# NOTE(review): max_length=256 is passed without an explicit padding strategy —
# confirm it has the intended effect on batch padding.
data_collator = DataCollatorWithPadding(tokenizer=injected_tokenizer, max_length=256)

In [None]:
# Trainer Auxiliary Methods
# Reset logging verbosity and re-seed the Python, NumPy and torch RNGs (default seed 42).
set_logging_and_seed()

def compute_metrics(eval_preds):
    """Compute accuracy from logits and one-hot labels.

    Args:
        eval_preds: Pair `(logits, labels)`; when `logits` is a tuple only its
            first element is used. Each label is a one-hot pair, so
            `abs(label[0] - 1)` maps [1.0, 0.0] to class 0 and [0.0, 1.0] to class 1.

    Returns:
        The result of the "accuracy" metric's `compute`.
    """
    accuracy = evaluate.load("accuracy")
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]

    predicted_classes = np.argmax(logits, axis=-1)
    reference_classes = [abs(one_hot[0] - 1) for one_hot in labels]
    return accuracy.compute(predictions=list(predicted_classes), references=list(reference_classes))

# Optional hyperparameter search.
# NOTE: this block is ACTIVE as written. The two `# '''` lines are a toggle: turning
# them into bare `'''` wraps the block in a string literal and disables it.
# NOTE(review): `training_args` is first assigned further down in this cell, and
# `train_split` / `validation_split` are not assigned anywhere in the visible notebook,
# so on a fresh kernel this presumably fails with NameError — confirm and define them first.
# '''
# NOTE(review): returns the already loaded `final_model` instead of building a new
# model, so every trial would continue from the previous trial's weights — confirm intent.
def model_init(trial): return final_model

trainer = Trainer(model=None,
                  args=training_args,
                  train_dataset=train_split,
                  eval_dataset=validation_split,
                  compute_metrics=compute_metrics,
                  tokenizer=injected_tokenizer,
                  model_init=model_init,
                  data_collator=data_collator)

# Search space: log-uniform learning rate in [1e-5, 5e-5] and a categorical batch size.
def optuna_hp_space(trial):
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
            "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64])}

# Runs 5 Optuna trials and maximizes the objective.
best_trial = trainer.hyperparameter_search(direction="maximize",
                                           backend="optuna",
                                           hp_space=optuna_hp_space,
                                           n_trials=5)
print(best_trial)
# '''

# The model is a sequence classifier, so the plain Trainer / TrainingArguments pair is
# used; the Seq2Seq variants only add generation-related options that are not needed here.
training_args = TrainingArguments(output_dir="~/path/",
                                  evaluation_strategy="no",
                                  num_train_epochs=5,
                                  learning_rate=2e-5,
                                  weight_decay=0.01,
                                  per_device_train_batch_size=8,
                                  per_device_eval_batch_size=8,
                                  fp16=True,
                                  push_to_hub=False)

# Fine-tune on the IMPLI training split and evaluate on the aggregated manual test split.
trainer = Trainer(model=final_model,
                  args=training_args,
                  train_dataset=encoded_train,
                  eval_dataset=encoded_test,
                  data_collator=data_collator,
                  tokenizer=injected_tokenizer,
                  compute_metrics=compute_metrics)

trainer.train()
# Last expression of the cell, so the metrics dictionary is displayed.
trainer.evaluate(encoded_test)