In [None]:
import json
import numpy as np
import pandas as pd

## Cache calinet probing data

In [None]:
# Load the raw CaliNet T-REx probing file. The encoding is pinned to UTF-8 (the
# JSON interchange encoding) instead of relying on the platform default.
with open('../data/calinet_probing_data_original/probing_data_trex_500each.json', 'r', encoding='utf-8') as f:
    data_calinet = json.load(f)

# One row per record under the top-level 'data' key; built once the file is closed.
starter_df = pd.DataFrame(list(data_calinet['data']))

In [None]:
# starter df: preview the first rows of the raw probing data
starter_df.head()

In [None]:
# Inspect the first sentence entry of the first row (all of that row's sentences
# belong to fact id 1).
# Each entry appears to be a triple: the start of a factual sentence involving the
# subject, followed by two candidate completions -- presumably one true and one
# false (TODO confirm against the CaliNet data description).
# Plan: store these as
#     sentence stem | correct | incorrect
# and strip out the <extra_id_x> parts to keep the data model agnostic.
starter_df['sentences'][0][0]

In [None]:
# Parallel accumulators for the cleaned data: one list per output column,
# each of which receives one value per probing sentence.
sentence_stems, correct, incorrect = [], [], []
fact_ids, relations = [], []
subjects, objects = [], []

In [None]:
# Flatten starter_df into the parallel lists: one element per probing sentence.
# The accumulators are (re)created here so that re-running this cell does not
# append a second copy of every sentence.
sentence_stems, correct, incorrect = [], [], []
fact_ids, relations, subjects, objects = [], [], [], []

for _, row in starter_df.iterrows():
    sentence_list = row['sentences']
    if not sentence_list:
        # nothing to emit for this fact
        continue

    # grab sub<->obj once per fact: the triplet is identical for all of its sentences
    subjects_and_objects = pd.json_normalize(row['triplet'])
    subject_label = subjects_and_objects.sub_label.values[0]
    object_label = subjects_and_objects.obj_label.values[0]

    for entry in sentence_list:
        # minor cleanup: mark the blank in the stem and strip the sentinel tokens
        # from the two candidate completions
        cleaned_stem = entry[0].replace("<extra_id_0>", "[BLANK]").strip()
        cleaned_correct = entry[1].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
        cleaned_incorrect = entry[2].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()

        # commit
        subjects.append(subject_label)
        objects.append(object_label)
        sentence_stems.append(cleaned_stem)
        correct.append(cleaned_correct)
        incorrect.append(cleaned_incorrect)
        fact_ids.append(row['fact_id'])
        relations.append(row['relation'])

In [None]:
# sanity check: every accumulator must hold the same number of values
column_lengths = [
    len(column)
    for column in (sentence_stems, correct, incorrect, fact_ids, relations, subjects, objects)
]
assert all(length == column_lengths[0] for length in column_lengths)

In [None]:
# merge into big df: one named column per accumulator list
trex_columns = {
    'fact_id': fact_ids,
    'relation': relations,
    'subject': subjects,
    'object': objects,
    'stem': sentence_stems,
    'true': correct,
    'false': incorrect,
}
trex_df = pd.DataFrame(trex_columns)

In [None]:
# full df: first rows of the flattened one-row-per-sentence frame
trex_df.head()

In [None]:
# last rows of the flattened frame
trex_df.tail()

In [None]:
# (number of sentences, number of columns)
trex_df.shape

In [None]:
# write out initial df in JSON Lines form (orient='records', lines=True:
# one JSON record per line)
trex_df.to_json('../data/calinet_probing_data_original/calinet_trex_full_data.json', orient='records', lines=True)

In [None]:
# put false inputs into a list
# with open('../data/calinet_probing_data_original/calinet_trex_full_data.json', 'r') as f:
    # data_calinet = json.load(f)

In [None]:
# Count the stems whose blank comes last, i.e. that end in "[BLANK].".
# (Recorded result: 50451, or about 1/3 of the stems.)
c = sum(1 for candidate in trex_df['stem'] if candidate.endswith("[BLANK]."))
print(c)

In [None]:
def check_for_causal_compatibility(stem, suffix="[BLANK]."):
    """Return True when the blank is the final element of the stem.

    A stem is usable by a left-to-right (causal) model only if everything
    that has to be predicted comes at the very end of the sentence.

    Args:
        stem: Sentence stem to test. Non-string values (e.g. a missing value
            in a DataFrame column) are reported as not compatible.
        suffix: Text the stem must end with; defaults to the blank marker
            followed by the closing period.

    Returns:
        bool: True if ``stem`` is a string ending in ``suffix``.
    """
    return isinstance(stem, str) and stem.endswith(suffix)

In [None]:
def trim_stem(stem, suffix="[BLANK]."):
    """Remove a trailing blank marker (and the whitespace before it) from a stem.

    The suffix is removed by its actual length and the preceding whitespace is
    stripped separately, so a stem without a space before the marker does not
    lose a real character.

    Args:
        stem: Sentence stem, e.g. ``"Paris is the capital of [BLANK]."``.
        suffix: Trailing text to remove; defaults to the blank marker followed
            by the closing period.

    Returns:
        str: The stem without the suffix and without trailing whitespace. A stem
        that does not end in ``suffix`` is returned unchanged.
    """
    if not stem.endswith(suffix):
        return stem
    # slice by explicit end index so an empty suffix cannot wipe the stem
    return stem[:len(stem) - len(suffix)].rstrip()
In [None]:
# Keep only the rows whose stem passes the causal-compatibility check.
is_causal_row = trex_df['stem'].map(check_for_causal_compatibility)
trex_causal_df = trex_df[is_causal_row]

In [None]:
# Work on an independent copy of the filtered rows; later cells assign columns on it.
trex_causal_df = trex_causal_df.copy()

In [None]:
# Strip the trailing blank marker from every remaining stem.
trimmed_stems = trex_causal_df['stem'].map(trim_stem)

In [None]:
# Replace the stem column with the trimmed stems (assigned by position via a plain list).
trex_causal_df['stem'] = list(trimmed_stems)

In [None]:
# Number of distinct fact_ids among the causal-compatible rows.
# (Author's note: only about 20% of the calinet data is 'unique' knowledge,
# since paraphrases of the same fact were used to calibrate.)
len(trex_causal_df['fact_id'].unique())

In [None]:
# before sampling, attach arbitrary counter ID, to then track who gets removed
# (sized from the frame itself, so it stays valid if the number of rows changes)
trex_causal_df['calibra_id'] = range(len(trex_causal_df))

In [None]:
# Keep exactly one randomly chosen row per fact_id; the fixed random_state is passed
# to every group's sample() call so the selection is repeatable.
# NOTE(review): do not replace this with groupby(...).sample() without checking --
# presumably that would pick different rows and change the downstream files.
trex_causal_subset = trex_causal_df.groupby('fact_id').apply(lambda x: x.sample(1, random_state=42)).reset_index(drop=True)

In [None]:
# the sampled subset must contain exactly one row for every distinct fact_id
assert(trex_causal_subset.shape[0] == len(trex_causal_df['fact_id'].unique()))


In [None]:
# first rows of the one-row-per-fact subset
trex_causal_subset.head()


In [None]:
# last rows of the one-row-per-fact subset
trex_causal_subset.tail()


In [None]:
# Collect every row that the per-fact sampling dropped:
#   removed_ids:          calibra_id (as str) -> fact_id (as int) of each dropped row
#   removed_counterfacts: fact_id (as str)    -> 'false' values of its dropped rows
# The dropped rows are selected with one vectorized membership test instead of
# re-filtering the whole frame for every calibra_id.
kept_calibra_ids = trex_causal_subset['calibra_id']
dropped_rows = trex_causal_df[~trex_causal_df['calibra_id'].isin(kept_calibra_ids)]

removed_ids = {}
removed_counterfacts = {}
for c_id, fact_id, counterfact in zip(dropped_rows['calibra_id'],
                                      dropped_rows['fact_id'],
                                      dropped_rows['false']):
    removed_ids[str(c_id)] = int(fact_id)
    removed_counterfacts.setdefault(str(fact_id), []).append(counterfact)

# did we remove as many rows as eq to the difference between the full calinet dataset row number and the unique count?
assert(len(removed_ids) == trex_causal_df.shape[0] - len(trex_causal_df['fact_id'].unique()))

In [None]:
# Peek at the first 15 facts: these removed counterfacts are essentially extra
# false completions we can test against, so they are still worth keeping.
for shown, (fact_key, counterfacts) in enumerate(removed_counterfacts.items(), start=1):
    print(fact_key, counterfacts)
    if shown == 15:
        break

In [None]:
# the temporary calibra_id counter has served its purpose; remove the column
trex_causal_subset = trex_causal_subset.drop(columns=['calibra_id'])


In [None]:
# Shape of the one-row-per-fact subset.
# There are some fact_ids that had only 1 row to begin with, since rows were
# already pulled out by the left-to-right (blank-at-the-end) requirement.
trex_causal_subset.shape

In [None]:
# number of fact_ids that lost at least one row to the sampling
len(removed_counterfacts)

In [None]:
# For every fact that lost rows to the sampling, combine the counterfacts of the
# dropped rows with the counterfact of the row that was kept.
# The kept counterfacts are looked up from one table instead of filtering the frame
# once per fact, and a new list is built per fact so removed_counterfacts is left
# untouched (re-running the cell no longer appends the kept value again).
kept_false_by_fact = dict(zip(trex_causal_subset['fact_id'], trex_causal_subset['false']))

full_falses = {}
for fact_key, dropped_falses in removed_counterfacts.items():
    full_falses[fact_key] = dropped_falses + [kept_false_by_fact[int(fact_key)]]

print(len(full_falses))

In [None]:
# show the first entry only (nothing is printed for an empty dict)
first_item = next(iter(full_falses.items()), None)
if first_item is not None:
    print(first_item[0], first_item[1])

In [None]:
def replace_false_column(fact_id, false_val, full_false_dict=full_falses):
    """Return the list of counterfacts to store for one row.

    Looks up ``str(fact_id)`` in ``full_false_dict``; when the fact has an entry
    there, that list is returned, otherwise the row's own counterfact is wrapped
    in a single-element list.
    """
    return full_false_dict.get(str(fact_id), [false_val])

In [None]:
# For every sampled row, fetch the list of counterfacts that belongs to its fact_id.
replaced_falses = [
    replace_false_column(fact_id, false_val)
    for fact_id, false_val in zip(trex_causal_subset['fact_id'], trex_causal_subset['false'])
]


In [None]:
# one list of counterfacts per row of the subset
len(replaced_falses)


In [None]:
# first six counterfact lists
replaced_falses[:6]


In [None]:
# from here on the 'false' column holds a list of counterfacts per row
trex_causal_subset['false'] = replaced_falses


In [None]:
# first rows, now with list-valued 'false'
trex_causal_subset.head()

In [None]:
# last rows, now with list-valued 'false'
trex_causal_subset.tail()

In [None]:
# Re-key the subset's records by their position: {0: record_0, 1: record_1, ...}
trex_list = trex_causal_subset.to_dict('records')
output_dict = dict(enumerate(trex_list))

In [None]:
# Remove repeated counterfacts within each record and count the resulting
# (fact, counterfact) pairs. dict.fromkeys keeps the first occurrence of each
# value in its original position, so the written file is identical from run to
# run (a set would order the strings differently between interpreter sessions).
num_pairs = 0
for record in output_dict.values():
    record['false'] = list(dict.fromkeys(record['false']))
    num_pairs += len(record['false'])

In [None]:
# total number of (fact, counterfact) pairs after de-duplication
num_pairs

In [None]:
# persist the cleaned/formatted records (the position-keyed dict) as JSON
with open("../data/calinet_input_information.json", "w") as outfile:
    json.dump(output_dict, outfile)

In [None]:
# out of curiosity: how often each relation template occurs in the cleaned,
# 'causal friendly' set
trex_causal_df['relation'].value_counts()

## Cache ROME counterfact data

In [None]:
# Load the raw ROME counterfact file. The encoding is pinned to UTF-8 (the JSON
# interchange encoding) instead of relying on the platform default.
with open('../data/rome_counterfact_original/counterfact.json', 'r', encoding='utf-8') as f:
    data_rome = json.load(f)

In [None]:
# number of ROME counterfact records
len(data_rome)

In [None]:
# Reduce every ROME record to the fields used downstream, keyed by its position
# in the file (as a string).
data_rome_input_information = {}

for position, record in enumerate(data_rome):
    rewrite = record['requested_rewrite']
    # fill the '{}' placeholder of the prompt with the subject
    stem = rewrite['prompt'].replace('{}', rewrite['subject'])

    data_rome_input_information[str(position)] = {
        "stem": stem,
        "true": rewrite['target_true']['str'],
        "false": [rewrite['target_new']['str']],
        "case_id": record['case_id'],
    }

In [None]:
#data_rome_input_information

In [None]:
# persist the reduced ROME records as JSON
with open("../data/rome_counterfact_input_information.json", "w") as outfile:
    json.dump(data_rome_input_information, outfile)

## Combine the two datasets

In [None]:
# Reload both cached input files. From here on data_calinet and data_rome hold
# the cleaned, string-keyed dictionaries rather than the raw datasets.
with open('../data/calinet_input_information.json', 'r') as calinet_file:
    data_calinet = json.load(calinet_file)

with open('../data/rome_counterfact_input_information.json', 'r') as rome_file:
    data_rome = json.load(rome_file)


In [None]:
# Concatenate both datasets into one dict with fresh consecutive string keys
# ("0", "1", ...), CaliNet entries first. Every entry is tagged with the name
# of the dataset it came from.
mixed_df = {}

sources = (
    (data_calinet, 'calinet_input_information'),
    (data_rome, 'rome_counterfact_input_information'),
)
for source_entries, source_name in sources:
    for entry in source_entries.values():
        entry['dataset_original'] = source_name
        # keys are unique, so the current size is the next free counter value
        mixed_df[str(len(mixed_df))] = entry

# running counter, left at the total number of combined entries
mixed_itr = len(mixed_df)


In [None]:
# number of combined entries
itrs = len(mixed_df)

print(f'The number of items in mixed_df is {itrs}')

In [None]:
# check for duplicate stem, fact, counterfact pairs across the dataset:

# one single-element list per (stem, fact, counterfact) combination, holding the
# three parts joined by spaces
pairs_list = [
    [entry['stem'] + ' ' + entry['true'] + ' ' + counterfact]
    for entry in mixed_df.values()
    for counterfact in entry['false']
]

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')
# num duplicates = total minus the number of distinct trios
distinct_trios = np.unique(np.array(pairs_list))
num_dup = len(pairs_list) - len(distinct_trios)
print(f'The number of duplicated [stem + fact + counterfact] trios in mixed_df is {num_dup}')


In [None]:
# locate the duplicated trios: print the key, a running number and the trio for
# every repeat (this cell only reports them; nothing is removed here)

pairs_list = []
seen_trios = set()  # constant-time membership instead of rescanning pairs_list
dup_itr = 1
for x, y in mixed_df.items():
    for counterfact in y['false']:
        pairs = [y['stem'] + ' ' + y['true'] + ' ' + counterfact]
        if pairs[0] in seen_trios:
            print(x, dup_itr, pairs)
            dup_itr += 1
        seen_trios.add(pairs[0])
        pairs_list.append(pairs)


In [None]:
# hard code remove the above (the duplicates printed by the previous cell)
# NOTE(review): '10107' and '25020' are the positional keys assigned when the two
# datasets were combined, so they are only valid for the exact input files and
# ordering of that run -- re-check the printed duplicates whenever the data changes.
mixed_df['10107']['false'] = ['Bulgaria']
mixed_df['10107']
del mixed_df['25020']


In [None]:
# check for duplicates again

# one single-element list per (stem, fact, counterfact) combination
pairs_list = [
    [entry['stem'] + ' ' + entry['true'] + ' ' + counterfact]
    for entry in mixed_df.values()
    for counterfact in entry['false']
]

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')
# num duplicates = total minus the number of distinct trios
distinct_trios = np.unique(np.array(pairs_list))
num_dup = len(pairs_list) - len(distinct_trios)
print(f'The number of duplicated [stem + fact + counterfact] trios in mixed_df is {num_dup}')

In [None]:
# check for duplicates in stem + fact pairs -> combine duplicates to one item
#   x_list:         keys of every entry whose stem + fact was already seen earlier
#   pairs_dup_list: the repeated stem + fact texts, in the same order as x_list
#   pairs_list:     all stem + fact texts
# A set provides the "already seen" test, instead of rescanning pairs_list for
# every entry.

x_list = []
pairs_list = []
pairs_dup_list = []
seen_pairs = set()
for x, y in mixed_df.items():
    pairs = [y['stem'] + ' ' + y['true']]
    if pairs[0] in seen_pairs:
        x_list.append(x)
        pairs_dup_list.append(pairs)

    seen_pairs.add(pairs[0])
    pairs_list.append(pairs)

print(f'The number of [stem + fact] pairs in mixed_df is {len(pairs_list)}')
# num duplicates = total minus the number of distinct pairs
num_dup = len(pairs_list) - len(seen_pairs)
print(f'The number of duplicated [stem + fact] pairs in mixed_df is {num_dup}')

In [None]:
# Merge every repeated stem + fact entry (keys in x_list) into the first entry
# with the same stem + fact, then delete the repeats.

# first key seen for each stem + fact text; x_list only holds later occurrences,
# so these surviving keys are never deleted below
first_key_by_pair = {}
for key, entry in mixed_df.items():
    first_key_by_pair.setdefault(entry['stem'] + ' ' + entry['true'], key)

for x in x_list:
    element = mixed_df[x]
    survivor = first_key_by_pair[element['stem'] + ' ' + element['true']]
    # extend the survivor's counterfacts with the repeat's counterfacts and keep
    # each value once; dict.fromkeys preserves first-seen order, so the result
    # does not vary between runs
    merged_falses = mixed_df[survivor]['false'] + element['false']
    mixed_df[survivor]['false'] = list(dict.fromkeys(merged_falses))


for x in x_list:
    del mixed_df[x]


In [None]:
# check for duplicates in stem + fact pairs again
# should be 0
# (a set provides the "already seen" test, instead of rescanning pairs_list for
# every entry)

x_list = []
pairs_list = []
pairs_dup_list = []
seen_pairs = set()
for x, y in mixed_df.items():
    pairs = [y['stem'] + ' ' + y['true']]
    if pairs[0] in seen_pairs:
        x_list.append(x)
        pairs_dup_list.append(pairs)

    seen_pairs.add(pairs[0])
    pairs_list.append(pairs)

print(f'The number of [stem + fact] pairs in mixed_df is {len(pairs_list)}')
# num duplicates = total minus the number of distinct pairs
num_dup = len(pairs_list) - len(seen_pairs)
print(f'The number of duplicated [stem + fact] pairs in mixed_df is {num_dup}')

In [None]:
# final numbers
# one single-element list per (stem, fact, counterfact) combination
pairs_list = [
    [entry['stem'] + ' ' + entry['true'] + ' ' + counterfact]
    for entry in mixed_df.values()
    for counterfact in entry['false']
]

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')

x_list = []
# one single-element list per (stem, fact) combination, i.e. per entry
pairs_list = [[entry['stem'] + ' ' + entry['true']] for entry in mixed_df.values()]

print(f'The number of [stem + fact] pairs in mixed_df is {len(pairs_list)}')

In [None]:
# final count of duplicates fact_id and case_id
# Entries that carry a 'case_id' are counted under case ids, all others under
# fact ids. The key is tested explicitly instead of using a bare except, so
# unrelated errors are no longer swallowed.
case_id_list = []
fact_id_list = []

for entry in mixed_df.values():
    if 'case_id' in entry:
        case_id_list.append(entry['case_id'])
    else:
        fact_id_list.append(entry['fact_id'])

print(f'The number of duplicated case_ids is {len(case_id_list) - len(set(case_id_list))}')
print(f'The number of duplicated fact_ids is {len(fact_id_list) - len(set(fact_id_list))}')


In [None]:
# persist the combined, de-duplicated dict as JSON

with open("../data/calibragpt_full_input_information.json", "w") as outfile:
    json.dump(mixed_df, outfile)

In [None]:
# round-trip check: read the combined file back into mixed_df

with open("../data/calibragpt_full_input_information.json", "r") as infile:
    mixed_df = json.load(infile)

In [None]:
# final numbers test 2 (same counts, computed on the reloaded file)
# one single-element list per (stem, fact, counterfact) combination
pairs_list = [
    [entry['stem'] + ' ' + entry['true'] + ' ' + counterfact]
    for entry in mixed_df.values()
    for counterfact in entry['false']
]

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')

x_list = []
# one single-element list per (stem, fact) combination, i.e. per entry
pairs_list = [[entry['stem'] + ' ' + entry['true']] for entry in mixed_df.values()]

print(f'The number of [stem + fact] pairs in mixed_df is {len(pairs_list)}')

In [None]:
# update mixed_df to have all info for rome then write that out.
# From here on mixed_df is a DataFrame with one row per entry (the dict keys become
# the index after the transpose) instead of a dict.
# NOTE(review): because the name is rebound, this cell cannot be re-run on its own
# output; reload the JSON in the cell above first.
mixed_df = pd.DataFrame.from_dict(mixed_df).T

In [None]:
# get rome info to look at: reload the raw ROME file (data_rome was overwritten
# with the reduced records earlier). The encoding is pinned to UTF-8 instead of
# relying on the platform default.
with open('../data/rome_counterfact_original/counterfact.json', 'r', encoding='utf-8') as f:
    data_rome = json.load(f)

# one row per raw ROME record; built once the file is closed
rome_df = pd.DataFrame.from_dict(data_rome)


In [None]:
# 3/20 data frame cleanup
# Lookup tables from ROME case_id to subject, true object and relation id.
# They are keyed by each record's own case_id (not by its row position), because
# the cell that fills the combined frame looks entries up by case_id.
rome_subjects = {}
rome_objects = {}
rome_relations = {}

for case_id, rewrite in zip(rome_df['case_id'], rome_df['requested_rewrite']):
    rome_subjects[case_id] = rewrite['subject']
    rome_objects[case_id] = rewrite['target_true']['str']
    rome_relations[case_id] = rewrite['relation_id']

# one entry per record; this also fails if two records share a case_id
assert(len(rome_subjects) == len(rome_objects) == len(rome_relations) == rome_df.shape[0])

In [None]:
# Per-row values for the final columns, collected in mixed_df row order.
# CaliNet rows already carry subject/object/relation; ROME rows get them from
# the lookup tables via their case_id. Ids are prefixed with the dataset name.
subjects, objects, ids, relations = [], [], [], []

for _, record in mixed_df.iterrows():
    origin = record['dataset_original']
    if origin == 'calinet_input_information':
        subjects.append(record['subject'])
        objects.append(record['object'])
        relations.append(record['relation'])
        ids.append('calinet_' + str(record['fact_id']))
    elif origin == 'rome_counterfact_input_information':
        case_id = record['case_id']
        subjects.append(rome_subjects[case_id])
        objects.append(rome_objects[case_id])
        relations.append(rome_relations[case_id])
        ids.append('rome_' + str(case_id))

assert len(subjects) == len(objects) == len(ids) == len(relations)

In [None]:
# subject for every row (CaliNet and ROME), in mixed_df row order
mixed_df['subject'] = subjects

In [None]:
# object for every row (CaliNet and ROME), in mixed_df row order
mixed_df['object'] = objects

In [None]:
# relation for every row (CaliNet and ROME), in mixed_df row order
mixed_df['relation'] = relations

In [None]:
# dataset-prefixed id ('calinet_<fact_id>' or 'rome_<case_id>') for every row
mixed_df['dataset_id'] = ids

In [None]:
# the per-dataset id columns and the origin tag are now folded into dataset_id
mixed_df = mixed_df.drop(columns=['fact_id', 'case_id', 'dataset_original'])

In [None]:
# every remaining cell must be populated: fail if any value in the frame is null
assert(not mixed_df.isnull().values.any())

In [None]:
# write to file as .csv (the index is not written)
# NOTE(review): the 'false' column holds Python lists; presumably they end up in the
# CSV as their string representation and must be parsed when read back -- confirm.
mixed_df.to_csv('../data/calibragpt_full_input_information_3_20_23.csv', index=False)