In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os, re, json
import pandas as pd

import torch, numpy as np
from transformers import AutoTokenizer

In [None]:
# Inference settings and the GPT-J tokenizer.
# NOTE(review): only the tokenizer is loaded here (no model), and none of these
# names are referenced by the later cells of this notebook.
model_name = r"EleutherAI/gpt-j-6b"
device = "cuda"
low_cpu_mem_usage = True

# Turn autograd off globally; no cell in this notebook computes gradients.
torch.set_grad_enabled(False)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the EOS token (presumably because the tokenizer defines no pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token

### Dataset Construction Code

In [None]:
def construct_single_dataset(word_set, label_rule=lambda x, n: x[n], world_size=5, train_size=500, test_size=100, choose_n=1, allow_duplicates=False, rng=None):
    """Build a "pick one word out of a list" dataset.

    Every example samples ``world_size`` words from ``word_set`` and is labelled with
    ``label_rule(sample, choose_n)``. The default rule returns the word at position
    ``choose_n`` (negative values count from the end).

    Args:
        word_set: pool of candidate words.
        label_rule: callable ``(sample, n) -> label``; the label must be an element of ``sample``.
        world_size: number of words per example.
        train_size: number of examples in the first returned split.
        test_size: number of examples in the second returned split.
        choose_n: forwarded to ``label_rule`` as ``n``; must be smaller than ``world_size``.
        allow_duplicates: sample words with replacement when True.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` for reproducible sampling.
            Defaults to numpy's global random state (the previous behaviour).

    Returns:
        ``(train, test, label_indices)``: ``train``/``test`` are lists of
        ``{'input': [words], 'output': label}`` dicts, and ``label_indices`` gives, for every
        train+test example in order, the first index of the label within its input.

    Raises:
        ValueError: if ``choose_n >= world_size`` or a label is not contained in its sample.
    """
    if choose_n >= world_size:
        raise ValueError(f"choose_n ({choose_n}) must be smaller than world_size ({world_size})")
    rng = np.random if rng is None else rng

    dataset = []
    label_indices = []
    for _ in range(train_size + test_size):
        sample = list(rng.choice(word_set, world_size, replace=allow_duplicates))
        label = label_rule(sample, choose_n)
        dataset.append({'input': sample, 'output': label})
        # With duplicates allowed, .index() reports the first occurrence of the label.
        label_indices.append(sample.index(label))

    return dataset[:train_size], dataset[train_size:], label_indices

def construct_mixed_dataset(target_word_set, distractor_word_set, world_size=5, train_size=500, test_size=100, choose_n=1, allow_duplicates=False, rng=None):
    """Build a dataset where the label is a target word hidden among distractors.

    Each example draws ``choose_n`` words from ``target_word_set`` and
    ``world_size - choose_n`` words from ``distractor_word_set``, then shuffles them.
    The label is the first target drawn.

    Args:
        target_word_set: pool the target word(s) come from.
        distractor_word_set: pool the distractors come from.
        world_size: total number of words per example.
        train_size: number of examples in the first returned split.
        test_size: number of examples in the second returned split.
        choose_n: number of target words per example; must satisfy ``1 <= choose_n < world_size``.
        allow_duplicates: sample each pool with replacement when True.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` for reproducible sampling.
            Defaults to numpy's global random state (the previous behaviour).

    Returns:
        ``(train, test, label_indices)`` in the same format as ``construct_single_dataset``.
        If a word occurs in both pools, the label may appear more than once in ``input``;
        ``label_indices`` then records its first occurrence.

    Raises:
        ValueError: if ``choose_n`` is outside ``[1, world_size)``.
    """
    # choose_n <= 0 leaves no target to use as the label, so reject it up front.
    if not 1 <= choose_n < world_size:
        raise ValueError(f"choose_n ({choose_n}) must satisfy 1 <= choose_n < world_size ({world_size})")
    rng = np.random if rng is None else rng
    n_distractors = world_size - choose_n

    dataset = []
    label_indices = []
    for _ in range(train_size + test_size):
        distractors = list(rng.choice(distractor_word_set, n_distractors, replace=allow_duplicates))
        target = list(rng.choice(target_word_set, choose_n, replace=allow_duplicates))
        sample = list(rng.permutation(distractors + target))
        label = target[0]
        dataset.append({'input': sample, 'output': label})
        label_indices.append(sample.index(label))

    return dataset[:train_size], dataset[train_size:], label_indices

def construct_rule_dataset(word_set, label_rule, choose_n=1):
    """Build one example per entry of ``word_set`` by applying ``label_rule`` to it.

    Args:
        word_set: sequence of inputs (typically single words); output order follows it.
        label_rule: callable ``(sample, n) -> label``, the same signature as the label
            rules defined in this notebook.
        choose_n: forwarded to ``label_rule`` as ``n``. The word-level rules ignore it;
            the default matches the other dataset constructors.

    Returns:
        A single list of ``{'input': sample, 'output': label}`` dicts (not a tuple).
    """
    # choose_n is an explicit parameter so label_rule never receives the
    # module-level ``choose_n`` function as its ``n`` argument.
    return [{'input': sample, 'output': label_rule(sample, choose_n)} for sample in word_set]


# Label Rules:
def choose_n(x, n):
    """Label rule: return the element of ``x`` at index ``n`` (negative ``n`` counts from the end)."""
    picked = x[n]
    return picked

def alphabetically_first(x, n):
    """Label rule: return the smallest word in ``x`` under Python's default string ordering.

    ``n`` is ignored. The comparison is by code point, so it is case-sensitive
    (uppercase letters sort before lowercase ones).
    """
    candidates = list(x)
    return min(candidates)

def alphabetically_last(x, n):
    """Label rule: return the largest word in ``x`` under Python's default string ordering.

    ``n`` is ignored. The comparison is by code point, so it is case-sensitive.
    """
    candidates = list(x)
    return max(candidates)

def most_vowels(x, n):
    """Label rule: return the word in ``x`` with the most ASCII vowels (a, e, i, o, u in either case).

    ``n`` is ignored. ``max`` keeps the first of several equally scoring words,
    so ties go to the lower index.
    """
    def vowel_count(word):
        return sum(1 for ch in word if ch in 'aeiouAEIOU')

    return max(x, key=vowel_count)

def longest_word(x, n):
    """Label rule: return the longest word in ``x``; ``n`` is ignored.

    Ties are broken by choosing the lower index: ``max`` returns the first maximal
    element. The former ``sorted(x, key=len)[-1]`` returned the *last* one, because
    sorting is stable.
    """
    return max(x, key=len)

def shortest_word(x, n):
    """Label rule: return the shortest word in ``x``; ``n`` is ignored.

    ``sorted`` is stable, so among equally short words the one with the lowest index wins.
    """
    by_length = sorted(x, key=len)
    return by_length[0]

def capitalize_first_letter(x, n):
    """Label rule: return the word's first non-whitespace character, title-cased.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    return word.title().strip()[0]

def capitalize(x, n):
    """Label rule: return the whitespace-stripped word in title case.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    stripped = word.strip()
    return stripped.title()

def len_word(x, n):
    """Label rule: return the length of the whitespace-stripped word, as a string.

    Like the other single-word rules, a list input is reduced to its first element;
    ``n`` is ignored.
    """
    if isinstance(x, list):
        x = x[0]
    return str(len(x.strip()))

def next_capital_letter(x, n):
    """Label rule: return the uppercase letter that follows the word's first letter (Z wraps to A).

    A list input is reduced to its first element; ``n`` is ignored. A non-letter first
    character still goes through the same code-point arithmetic, so its result is not a
    meaningful "next letter".
    """
    word = x[0] if isinstance(x, list) else x
    first = word.title().strip()[0]
    offset = (ord(first.upper()) - ord('A') + 1) % 26
    return chr(ord('A') + offset)

def next_letter(x, n):
    """Label rule: return the uppercase letter that follows the word's first character (Z wraps to A).

    The character is upper-cased before shifting, so the result is always uppercase.
    For ASCII words this gives the same answer as ``next_capital_letter``.
    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    first = word.strip()[0]
    return chr(ord('A') + (ord(first.upper()) - ord('A') + 1) % 26)

def capitalize_last_letter(x, n):
    """Label rule: return the word's last non-whitespace character, title-cased.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    last_char = word.strip()[-1]
    return last_char.title()

def capitalize_second_letter(x, n):
    """Label rule: return the second character of the stripped word, title-cased.

    A list input is reduced to its first element; ``n`` is ignored. Raises
    IndexError when the stripped word is shorter than two characters.
    """
    word = x[0] if isinstance(x, list) else x
    second_char = word.strip()[1]
    return second_char.title()

def lowercase_first_letter(x, n):
    """Label rule: return the word's first non-whitespace character in lowercase.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    lowered = word.lower().strip()
    return lowered[0]

def lowercase_last_letter(x, n):
    """Label rule: return the word's last non-whitespace character in lowercase.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    word = x[0] if isinstance(x, list) else x
    lowered = word.lower().strip()
    return lowered[-1]

def parens(x, n):
    """Label rule: return the whitespace-stripped word wrapped in parentheses.

    A list input is reduced to its first element; ``n`` is ignored.
    """
    if isinstance(x, list):
        x = x[0]
    return f"({x.strip()})"

### Create New Choose Item from List Datasets

In [None]:
# Load the word categories and pool every word across them into one deduplicated list.
with open('categories.json', 'r') as f:
    categories = json.load(f)
print(categories.keys())

# sorted() instead of list(set(...)): str hashing is randomized per interpreter run,
# so set iteration order (and every dataset sampled from big_list) would not be reproducible.
big_list = sorted({word for words in categories.values() for word in words})

In [None]:
# Flip to True to write every dataset to dataset_files/extractive/<name>_<world_size>.json.
SAVE_EXTRACTIVE_DATASETS = False


def _as_prompt_pairs(examples):
    """Render each example's word list as one comma-separated string and its label as a str."""
    return [{'input': ", ".join(list(w['input'])), 'output': str(w['output'])} for w in examples]


def _save_extractive(data, name, world_size):
    """Write ``data`` to dataset_files/extractive/<name>_<world_size>.json when saving is enabled."""
    if not SAVE_EXTRACTIVE_DATASETS:
        return
    out_path = os.path.join('dataset_files', 'extractive', f'{name}_{world_size}.json')
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w') as f:
        json.dump(data, f)


concept_words = categories['verb'] + categories['adjective'] + categories['preposition']

# (file stem, target word pool, distractor word pool) for construct_mixed_dataset.
mixed_specs = [
    ('object_v_concept', categories['object'], concept_words),
    ('concept_v_object', concept_words, categories['object']),
    ('fruit_v_animal', categories['fruit'], categories['animal']),
    ('color_v_animal', categories['color'], categories['animal']),
    ('animal_v_object', categories['animal'], categories['object']),
    ('verb_v_adjective', categories['verb'], categories['adjective']),
    ('adjective_v_verb', categories['adjective'], categories['verb']),
]

for world_size in [3, 5]:
    size_kwargs = dict(train_size=1000, test_size=0, world_size=world_size)

    # Indexed datasets: (file stem, extra construct_single_dataset kwargs).
    single_specs = [
        ('choose_first_of', dict(choose_n=0)),
        ('choose_middle_of', dict(choose_n=world_size // 2)),
        ('choose_last_of', dict(choose_n=-1)),
        ('alphabetically_first', dict(label_rule=alphabetically_first)),
        ('alphabetically_last', dict(label_rule=alphabetically_last)),
    ]
    for name, extra_kwargs in single_specs:
        raw, _, _ = construct_single_dataset(big_list, **extra_kwargs, **size_kwargs)
        train_dataset = _as_prompt_pairs(raw)
        _save_extractive(train_dataset, name, world_size)

    # Mixed/distractor datasets. They are generated after the indexed ones because all
    # sampling shares one RNG, so generation order affects which examples are drawn.
    for name, targets, distractors in mixed_specs:
        raw, _, _ = construct_mixed_dataset(targets, distractors, **size_kwargs)
        train_dataset = _as_prompt_pairs(raw)
        _save_extractive(train_dataset, name, world_size)

In [None]:
# DATASETS:
# word_length, next_capital_letter, capitalize_last_letter, capitalize_second_letter, lowercase_first_letter, lowercase_last_letter

# Flip to True to write each dataset to dataset_files/<relative path>.json.
SAVE_RULE_DATASETS = False

upper_big_list = [x.upper() for x in big_list]
# capitalize_second_letter indexes [1] after strip(), so keep only words that still
# have at least two characters once surrounding whitespace is removed.
big_list2 = [x for x in big_list if len(x.strip()) > 1]

# (relative output path, input words, label rule).
rule_dataset_specs = [
    ('unused/lowercase_first_letter', upper_big_list, lowercase_first_letter),
    ('unused/lowercase_last_letter', upper_big_list, lowercase_last_letter),
    ('unused/capitalize_last_letter', big_list, capitalize_last_letter),
    ('unused/next_capital_letter', big_list, next_capital_letter),
    ('unused/word_length', big_list, len_word),
    ('abstractive/capitalize_first_letter', big_list, capitalize_first_letter),
    ('abstractive/capitalize', big_list, capitalize),
    ('abstractive/parens', big_list, parens),
    ('unused/capitalize_second_letter', big_list2, capitalize_second_letter),
]

for rel_path, words, rule in rule_dataset_specs:
    # construct_rule_dataset returns one list of examples, so it must not be tuple-unpacked.
    train_dataset = construct_rule_dataset(words, label_rule=rule)
    if SAVE_RULE_DATASETS:
        out_path = os.path.join('dataset_files', f'{rel_path}.json')
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, 'w') as f:
            json.dump(train_dataset, f)

## NLP Datasets

### CONLL2003

In [None]:
from datasets import load_dataset as ld
import re
import json
import numpy as np

dataset = ld("conll2003")

# ner_tags ids of the B-/I- tags per category (presumably the HF conll2003 label order — confirm).
conll_label_dict = {'person':{"B_ind": 1, "I_ind":2}, 
                    'organization':{"B_ind": 3, "I_ind":4}, 
                    'location':{"B_ind": 5, "I_ind":6}}

# Removes the space that joining tokens leaves before "," and ".".
re_test = re.compile(r"[\s]([,\.])")

for category in conll_label_dict.keys():
    B_ind = conll_label_dict[category]['B_ind']
    I_ind = conll_label_dict[category]['I_ind']

    data_filtered = []
    # Pool the train and validation splits, in that order.
    for split in ('train', 'validation'):
        for data_point_i in dataset[split]:
            tags = data_point_i['ner_tags']
            tokens = data_point_i['tokens']
            # Keep only sentences with exactly one entity of this category (one B- tag).
            if tags.count(B_ind) != 1:
                continue
            start = tags.index(B_ind)
            # The entity is the B- token plus the I- tokens immediately following it;
            # I- tags elsewhere in the sentence do not belong to this mention.
            end = start + 1
            while end < len(tags) and tags[end] == I_ind:
                end += 1
            input_text = re_test.sub(r"\1", " ".join(tokens))
            data_filtered.append({"input": input_text, "output": " ".join(tokens[start:end])})
    # json.dump(data_filtered, open(f'dataset_files/conll2003/conll2003_{category}.json', 'w'))

### CommonsenseQA

In [None]:
from datasets import load_dataset as ld
# import pandas as pd
import numpy as np

dataset = ld("commonsense_qa", 'plain_text')


def _format_csqa_prompt(row):
    """Render a CommonsenseQA row as its question followed by one "<label>: <choice>" line per option.

    Choice labels are lowercased to match the lowercased answer key.
    """
    labels = [label.lower() for label in row['choices']['label']]
    option_lines = [f"{label}: {text}" for label, text in zip(labels, row['choices']['text'])]
    return row['question'] + '\n' + "\n".join(option_lines)


# Pool the train and validation splits; ignore_index gives a fresh 0..N-1 index.
csqa_raw = pd.concat([pd.DataFrame(dataset['train']), pd.DataFrame(dataset['validation'])], ignore_index=True)
df = pd.DataFrame({
    'input': csqa_raw.apply(_format_csqa_prompt, axis=1),
    'output': csqa_raw['answerKey'].str.lower(),
})

combined_data = df.to_dict('records')
# json.dump(combined_data, open('dataset_files/abstractive/commonsense_qa.json', 'w'))

### AG News

In [None]:
dataset = ld('ag_news')
df = pd.DataFrame(dataset['test']).rename(columns={'text': 'input', 'label': 'output'})

# int() converts the numpy integer labels to built-in ints (json cannot serialise numpy ints).
combined_data = [
    {'input': text, 'output': int(label)}
    for text, label in zip(df['input'], df['output'])
]
# json.dump(combined_data, open('dataset_files/abstractive/ag_news.json', 'w'))

## Additional Datasets Considered

### SQUAD

In [None]:
from datasets import load_dataset as ld
import pandas as pd
import numpy as np

dataset = ld('squad')
squad_val = pd.DataFrame(dataset['validation'])

# Keep only questions whose answer texts are all identical (a single unique answer).
has_unique_answer = squad_val['answers'].apply(lambda a: len(set(a['text'])) == 1)

# Build new columns on a fresh frame (not on a filtered slice) to avoid SettingWithCopyWarning.
df = (
    squad_val.loc[has_unique_answer]
    .assign(
        input=lambda d: d['context'] + '\n' + d['question'],
        output=lambda d: d['answers'].apply(lambda a: a['text'][0]),
    )[['input', 'output']]
    .reset_index(drop=True)
)

combined_data = df.to_dict('records')
# len(combined_data)
# json.dump(combined_data, open('dataset_files/abstractive/squad_val.json', 'w'))

### MultiNERD Simple NER Datasets

In [None]:
from datasets import load_dataset

dataset = load_dataset("tner/multinerd", 'en')

# tag ids of the B-/I- tags per category (presumably from the dataset's label2id — confirm
# before enabling the commented-out categories).
mn_label_dict = {'person':{"B_ind": 1, "I_ind":2}, 
                 'location':{"B_ind": 3, "I_ind":4}, 
                 'organization':{"B_ind": 5, "I_ind":6}}
                #  'animal':{"B_ind": 7, "I_ind":8},
                #  'celestial':{"B_ind": 11, "I_ind":12},
                #  'disease':{"B_ind": 13, "I_ind":14},
                #  'event':{"B_ind": 15, "I_ind":16},
                #  'food':{"B_ind": 17, "I_ind":18}}

# Removes the space that joining tokens leaves before "," and ".".
re_test = re.compile(r"[\s]([,\.])")

for category in mn_label_dict.keys():
    B_ind = mn_label_dict[category]['B_ind']
    I_ind = mn_label_dict[category]['I_ind']

    data_filtered = []

    for data_point_i in dataset['test']:
        tags = data_point_i['tags']
        tokens = data_point_i['tokens']
        # Keep only sentences with exactly one entity of this category (one B- tag).
        if tags.count(B_ind) != 1:
            continue
        start = tags.index(B_ind)
        # The entity is the B- token plus the I- tokens immediately following it;
        # I- tags elsewhere in the sentence do not belong to this mention.
        end = start + 1
        while end < len(tags) and tags[end] == I_ind:
            end += 1
        input_text = re_test.sub(r"\1", " ".join(tokens))
        data_filtered.append({"input": input_text, "output": " ".join(tokens[start:end])})

    # json.dump(data_filtered, open(f'dataset_files/multinerd/multinerd_{category}.json', 'w'))

### Next & Prev

In [None]:
# Ordered sequences used to build "next item" / "previous item" pairs.
lower = [chr(code) for code in range(ord('a'), ord('z') + 1)]
double_lower = [letter * 2 for letter in lower]

upper = [letter.upper() for letter in lower]
double_upper = [letter.upper() * 2 for letter in lower]

number_strings = ["zero", "one", "two", "three", "four", "five",
                  "six", "seven", "eight", "nine", "ten",
                  "eleven", "twelve", "thirteen", "fourteen", "fifteen",
                  "sixteen", "seventeen", "eighteen", "nineteen", "twenty"]

numbers = list(map(str, range(30)))

roman_numerals_upper = ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X',
                        'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI', 'XVII', 'XVIII', 'XIX', 'XX']
roman_numerals_lower = [numeral.lower() for numeral in roman_numerals_upper]

days = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
days_upper = [day.title() for day in days]

months = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december']
months_upper = [month.title() for month in months]


def pair_up_next(l):
    """Pair each element of ``l`` with its successor: ``{"input": l[i], "output": l[i+1]}``."""
    return [{"input": cur, "output": nxt} for cur, nxt in zip(l, l[1:])]


def pair_up_prev(l):
    """Pair each element of ``l`` with its predecessor: ``{"input": l[i+1], "output": l[i]}``."""
    return [{"input": nxt, "output": cur} for cur, nxt in zip(l, l[1:])]


# Both datasets walk the sequences in this order.
_sequences = [number_strings, numbers, lower, upper, double_upper, double_lower,
              roman_numerals_upper, roman_numerals_lower, days, months, days_upper, months_upper]

# Wrap-around pairs (end of week/year back to the start), in lower- and title-case.
_wraps = [('sunday', 'monday'), ('december', 'january'), ('Sunday', 'Monday'), ('December', 'January')]

PREV_dataset = [pair for seq in _sequences for pair in pair_up_prev(seq)]
PREV_dataset += [{'input': after, 'output': before} for before, after in _wraps]

NEXT_dataset = [pair for seq in _sequences for pair in pair_up_next(seq)]
NEXT_dataset += [{'input': before, 'output': after} for before, after in _wraps]

# json.dump(NEXT_dataset, open('dataset_files/abstractive/next_item.json','w'))
# json.dump(PREV_dataset, open('dataset_files/abstractive/prev_item.json','w'))