In [None]:
import json
import numpy as np
import pandas as pd

## Cache calinet probing data

In [None]:
# Load the raw calinet probing file. The encoding is pinned to UTF-8 so that
# non-ASCII text decodes the same way regardless of the platform's default
# locale encoding.
with open('../data/calinet_probing_data_original/probing_data_trex_500each.json', 'r', encoding='utf-8') as f:
    data_calinet = json.load(f)

# One DataFrame row per entry of the top-level 'data' collection.
starter_df = pd.DataFrame(list(data_calinet['data']))

In [None]:
# starter df: preview the first rows of the frame built from the JSON's 'data' entries
starter_df.head()

In [None]:
# Look at the first sentence entry of the first row (the author notes these
# all belong to fact id 1).
# Each entry appears to hold three parts:
#   1. the start of a factual sentence involving the subject,
#   2. one candidate completion, and
#   3. a second candidate completion (presumably one true and one false).
# Plan for storage: one row per entry, laid out as
#   sentence stem | correct | incorrect
# with the <extra_id_x> markers stripped so the data stays model agnostic.
starter_df.loc[0, 'sentences'][0]

In [None]:
# Parallel lists, one per output column; the flattening loop appends one
# element to each of them for every probing sentence.
fact_ids, relations = [], []
subjects, objects = [], []
sentence_stems, correct, incorrect = [], [], []

In [None]:
# Flatten the nested probing data: every (stem, true, false) entry in a row's
# 'sentences' list becomes one element in each of the parallel container lists.
# NOTE: this cell appends to those lists, so re-run the container cell before
# re-executing it, otherwise every row is added twice.
for _, row in starter_df.iterrows():
    # grab sub<->obj once per row: the triplet is the same for every sentence
    # of the row, so normalising it inside the inner loop was repeated work
    subjects_and_objects = pd.json_normalize(row['triplet'])
    subject_label = subjects_and_objects.sub_label.values[0]
    object_label = subjects_and_objects.obj_label.values[0]

    for entry in row['sentences']:

        # minor cleanup: mark the masked slot in the stem with [BLANK] and
        # drop the <extra_id_x> markers from both candidate answers
        cleaned_stem = entry[0].replace("<extra_id_0>", "[BLANK]").strip()
        cleaned_correct = entry[1].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
        cleaned_incorrect = entry[2].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()

        # commit: one element per list so all containers stay aligned
        subjects.append(subject_label)
        objects.append(object_label)
        sentence_stems.append(cleaned_stem)
        correct.append(cleaned_correct)
        incorrect.append(cleaned_incorrect)
        fact_ids.append(row['fact_id'])
        relations.append(row['relation'])

In [None]:
# sanity check: all seven column lists must have exactly the same length
column_lengths = {
    len(column)
    for column in (sentence_stems, correct, incorrect, fact_ids, relations, subjects, objects)
}
assert len(column_lengths) == 1

In [None]:
# merge into big df: one column per container list, one row per probing sentence
trex_columns = {
    'fact_id': fact_ids,
    'relation': relations,
    'subject': subjects,
    'object': objects,
    'stem': sentence_stems,
    'true': correct,
    'false': incorrect,
}
trex_df = pd.DataFrame(trex_columns)

In [None]:
# full df: first rows of the flattened frame (one row per probing sentence)
trex_df.head()

In [None]:
# last rows of the flattened frame
trex_df.tail()

In [None]:
# (rows, columns) of the flattened frame
trex_df.shape

In [None]:
# write out initial df as JSON Lines (orient='records' + lines=True: one record per line)
trex_df.to_json(
    '../data/calinet_probing_data_original/calinet_trex_full_data.json',
    orient='records',
    lines=True,
)

In [None]:
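# put false inputs into a list
# NOTE: calinet_trex_full_data.json is written with lines=True (one JSON record
# per line), so a single json.load(f) cannot parse it. If this reload is ever
# revived, read the file line by line or use pd.read_json(..., lines=True).
# with open('../data/calinet_probing_data_original/calinet_trex_full_data.json', 'r') as f:
    # data_calinet = json.load(f)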

In [None]:
# how many stems end in "[BLANK]."? -> 50451, or about 1/3.
blank_final_count = sum(1 for stem in trex_df['stem'] if stem.endswith("[BLANK]."))
print(blank_final_count)

In [None]:
def check_for_causal_compatibility(stem):
    """Return True when the stem's blank is the last thing in the sentence.

    A stem qualifies only if it ends with the literal text "[BLANK]."
    (the blank marker followed by the closing period).
    """
    blank_suffix = "[BLANK]."
    ends_with_blank = stem.endswith(blank_suffix)
    return ends_with_blank

In [None]:
def trim_stem(stem):
    """Remove a trailing "[BLANK]." marker (and the whitespace before it).

    Args:
        stem: sentence stem as a string.

    Returns:
        The stem without its trailing "[BLANK]." and without the whitespace
        that preceded it. A stem that does not end with "[BLANK]." is
        returned unchanged, so applying the function a second time is a no-op
        rather than producing None.
    """
    blank_suffix = "[BLANK]."
    if not stem.endswith(blank_suffix):
        return stem
    # Cut exactly the marker, then strip the separating whitespace; the
    # previous fixed 9-character cut assumed exactly one character before the
    # 8-character marker and ate a real character when there was no space.
    return stem[:-len(blank_suffix)].rstrip()

In [None]:
# Keep only the rows whose stem ends with the blank. The check is applied to
# the 'stem' column directly instead of row by row (axis=1), which avoids
# building a Series object for every row of the frame.
trex_causal_df = trex_df[trex_df['stem'].apply(check_for_causal_compatibility)]

In [None]:
# Take an independent copy of the filtered frame before its 'stem' column is
# overwritten further down, so the assignment does not act on a slice of trex_df.
trex_causal_df = trex_causal_df.copy()

In [None]:
# Strip the trailing blank marker from every stem; mapping over the 'stem'
# column is equivalent to the row-wise apply but skips the per-row overhead.
trimmed_stems = trex_causal_df['stem'].apply(trim_stem)

In [None]:
# Overwrite the stems with their trimmed versions (assigned positionally as a plain list)
trex_causal_df['stem'] = trimmed_stems.tolist()

In [None]:
# Turn the frame into a list of record dicts and key each one by its position
trex_list = trex_causal_df.to_dict('records')
output_dict = dict(enumerate(trex_list))

In [None]:
# Wrap each record's single false answer in a list, matching the list-valued
# "false" field used for the ROME counterfact records later in this notebook.
for record in output_dict.values():
    # Guard against re-running the cell: an already wrapped value must not be
    # nested a second time.
    if not isinstance(record['false'], list):
        record['false'] = [record['false']]


In [None]:
# write out cleaned/formatted records (json.dump serialises the integer keys as strings)
with open("../data/calinet_input_information.json", "w") as outfile:
    json.dump(output_dict, outfile)

In [None]:
# out of curiosity: how many rows each relation contributes to the
# cleaned, 'causal friendly' set (stems that end with the blank)
trex_causal_df['relation'].value_counts()

## Cache ROME counterfact data

In [None]:
# Load the raw ROME counterfact file; UTF-8 is requested explicitly so the
# result does not depend on the platform's default locale encoding.
with open('../data/rome_counterfact_original/counterfact.json', 'r', encoding='utf-8') as f:
    data_rome = json.load(f)

In [None]:
# number of counterfact records loaded
len(data_rome)

In [None]:
# Convert every counterfact record into the stem / true / false layout,
# keyed by the record's position as a string.
data_rome_input_information = {}

for position, record in enumerate(data_rome):
    rewrite = record['requested_rewrite']
    # fill the '{}' placeholder of the prompt with the subject
    stem = rewrite['prompt'].replace('{}', rewrite['subject'])

    data_rome_input_information[str(position)] = {
        "stem": stem,
        "true": rewrite['target_true']['str'],
        "false": [rewrite['target_new']['str']],
        "case_id": record['case_id'],
    }

In [None]:
# Preview only the first few converted records; echoing the whole dict would
# flood the notebook output with every record.
{key: data_rome_input_information[key] for key in list(data_rome_input_information)[:3]}

In [None]:
# write out the converted counterfact records
with open("../data/rome_counterfact_input_information.json", "w") as outfile:
    json.dump(data_rome_input_information, outfile)

## Combine the two datasets

In [None]:
# Reload both converted datasets from disk.
# NOTE: this rebinds data_calinet and data_rome, which held the raw inputs
# earlier in the notebook, to the converted {key: record} dicts.
with open('../data/calinet_input_information.json', 'r') as calinet_file:
    data_calinet = json.load(calinet_file)

with open('../data/rome_counterfact_input_information.json', 'r') as rome_file:
    data_rome = json.load(rome_file)


In [None]:
# Concatenate both datasets into one dict with consecutive string keys,
# calinet records first, then ROME. Every record is tagged (in place) with
# the dataset it came from. Despite its name, mixed_df is a plain dict.
mixed_itr = 0
mixed_df = {}

for source_name, source_records in (
    ('calinet_input_information', data_calinet),
    ('rome_counterfact_input_information', data_rome),
):
    for record in source_records.values():
        record['dataset_original'] = source_name
        mixed_df[str(mixed_itr)] = record
        mixed_itr += 1


In [None]:
# Number of records in the combined dict; len() replaces the manual counting loop.
itrs = len(mixed_df)

In [None]:
# total number of combined records
itrs

In [None]:

# write out the combined calinet + ROME records
with open("../data/calibragpt_full_input_information.json", "w") as outfile:
    json.dump(mixed_df, outfile)