In [1]:
import json
import numpy as np
import pandas as pd

## Cache calinet probing data

In [2]:
# Load the raw CALINET T-REx probing file. JSON is UTF-8 by specification, so
# the encoding is passed explicitly instead of relying on the platform default.
with open('../data/calinet_probing_data_original/probing_data_trex_500each.json', 'r', encoding='utf-8') as f:
    data_calinet = json.load(f)

# The file handle is not needed once the JSON is parsed, so the frame is built
# outside the `with` block; each element of 'data' becomes one row.
starter_df = pd.DataFrame(data_calinet['data'])

In [3]:
# Preview the first rows of the raw frame; the 'triplet' and 'sentences'
# columns still hold nested structures at this point.
starter_df.head()

Unnamed: 0,fact_id,relation,triplet,sentences
0,1,P47,"{'sub_label': 'Norfolk', 'obj_label': 'Suffolk'}","[[Norfolk shares border with <extra_id_0>., <e..."
1,2,P47,"{'sub_label': 'Jordan', 'obj_label': 'Israel'}","[[<extra_id_0> shares border with Israel., <ex..."
2,3,P47,"{'sub_label': 'Kenya', 'obj_label': 'Ethiopia'}","[[Kenya shares border with <extra_id_0>., <ext..."
3,4,P47,"{'sub_label': 'Egypt', 'obj_label': 'Israel'}","[[<extra_id_0> shares border with Israel., <ex..."
4,5,P47,"{'sub_label': 'Tanzania', 'obj_label': 'Uganda'}","[[Tanzania shares border with <extra_id_0>., <..."


In [4]:
# Inspect the first sentence entry of the first row (fact id 1).
# Each entry is a three-element list:
#   [0] a sentence template in which <extra_id_0> marks the blank,
#   [1] the correct fill-in, wrapped in <extra_id_0> ... <extra_id_1>,
#   [2] an incorrect fill-in, wrapped the same way.
# Storage plan: one row per entry, laid out as
#   sentence stem | correct | incorrect
# with the <extra_id_x> sentinels stripped out to keep the data model agnostic.
starter_df['sentences'][0][0]

['Norfolk shares border with <extra_id_0>.',
 '<extra_id_0> Suffolk <extra_id_1>',
 '<extra_id_0> Upper Macungie Township <extra_id_1>']

In [5]:
# Parallel accumulators for the cleaned data: one list per output column,
# all starting out empty.
sentence_stems, correct, incorrect = [], [], []
fact_ids, relations = [], []
subjects, objects = [], []

In [6]:
# Flatten the nested frame: one record per (fact, sentence entry) pair is
# appended to the parallel accumulator lists.

# Start from empty lists so that re-running this cell cannot append the same
# records a second time.
for _column in (sentence_stems, correct, incorrect, fact_ids, relations,
                subjects, objects):
    _column.clear()

for _, row in starter_df.iterrows():
    # grab sub<->obj: 'triplet' is already a plain dict and is identical for
    # every sentence of this fact, so read the labels once per row instead of
    # normalising it into a throw-away DataFrame for each sentence.
    triplet = row['triplet']
    subject_label = triplet['sub_label']
    object_label = triplet['obj_label']

    for entry in row['sentences']:
        # entry = [template, correct fill-in, incorrect fill-in]

        # minor cleanup: mark the blank in a model-agnostic way and drop the
        # sentinel tokens wrapped around the two candidate answers
        cleaned_stem = entry[0].replace("<extra_id_0>", "[BLANK]").strip()
        cleaned_correct = entry[1].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
        cleaned_incorrect = entry[2].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()

        # commit: every list receives exactly one item per entry
        subjects.append(subject_label)
        objects.append(object_label)
        sentence_stems.append(cleaned_stem)
        correct.append(cleaned_correct)
        incorrect.append(cleaned_incorrect)
        fact_ids.append(row['fact_id'])
        relations.append(row['relation'])

In [7]:
# sanity check: all accumulators must have collected the same number of items
column_lengths = [
    len(column)
    for column in (sentence_stems, correct, incorrect, fact_ids,
                   relations, subjects, objects)
]
assert len(set(column_lengths)) == 1

In [8]:
# 'calibra id': a running integer 0..N-1 that gives each row a primary key
calibra_id = np.arange(len(sentence_stems))

In [9]:
# merge the parallel lists into one flat frame; the insertion order of the
# dict determines the column order
trex_columns = {
    'id': calibra_id,
    'fact_id': fact_ids,
    'relation': relations,
    'subject': subjects,
    'object': objects,
    'stem': sentence_stems,
    'true': correct,
    'false': incorrect,
}
trex_df = pd.DataFrame(trex_columns)

In [10]:
# Preview the flattened frame: one row per sentence entry, with the blank
# position marked as [BLANK] in 'stem'.
trex_df.head()

Unnamed: 0,id,fact_id,relation,subject,object,stem,true,false
0,0,1,P47,Norfolk,Suffolk,Norfolk shares border with [BLANK].,Suffolk,Upper Macungie Township
1,1,1,P47,Norfolk,Suffolk,Norfolk borders with [BLANK].,Suffolk,Vadena
2,2,1,P47,Norfolk,Suffolk,[BLANK] shares the border with Suffolk.,Norfolk,Northern Cape province
3,3,1,P47,Norfolk,Suffolk,[BLANK] shares its border with Suffolk.,Norfolk,Sunamganj District
4,4,1,P47,Norfolk,Suffolk,Norfolk shares a common border with [BLANK].,Suffolk,Anabar


In [11]:
# (rows, columns) of the flattened frame
trex_df.shape

(143000, 8)

In [12]:
# persist the full flattened frame as JSON Lines (one record per line)
trex_df.to_json(
    path_or_buf='../data/calinet_trex_full_data.json',
    orient='records',
    lines=True,
)

In [13]:
# Count the stems whose blank sits at the very end of the sentence.
# Result: 50451 rows, or about 1/3 of the data.
c = sum(1 for stem in trex_df['stem'] if stem.endswith("[BLANK]."))
print(c)

50451


In [14]:
def check_for_causal_compatibility(stem):
    """Tell whether the blank is the last thing in the sentence.

    Args:
        stem: sentence string in which the blank is written as "[BLANK]".

    Returns:
        True if `stem` ends with the exact text "[BLANK].", otherwise False.
    """
    blank_suffix = "[BLANK]."
    return stem.endswith(blank_suffix)

In [15]:
def trim_stem(stem):
    """Cut the trailing "[BLANK]." marker off a stem.

    Args:
        stem: sentence string in which the blank is written as "[BLANK]".

    Returns:
        `stem` without its final "[BLANK]." when it ends with that marker.
        Whitespace in front of the marker is kept as is. A stem that does not
        end with the marker is returned unchanged rather than as an implicit
        None, so a stray input cannot silently turn into a missing value.
    """
    blank_suffix = "[BLANK]."
    if stem.endswith(blank_suffix):
        # slice by the marker's own length instead of a hard-coded 8
        return stem[:len(stem) - len(blank_suffix)]
    return stem

In [16]:
# Keep only the rows whose stem ends in "[BLANK].". The predicate needs just
# the 'stem' column, so it is mapped over that column instead of building a
# Series for every row with apply(axis=1).
trex_causal_df = trex_df[trex_df['stem'].map(check_for_causal_compatibility)]

In [17]:
# take an independent (deep) copy so the later column assignment works on its
# own data rather than on a selection of trex_df
trex_causal_df = trex_causal_df.copy(deep=True)

In [18]:
# Strip the trailing blank marker from every stem. Only the 'stem' column is
# involved, so map over it directly instead of a row-wise apply(axis=1).
trimmed_stems = trex_causal_df['stem'].map(trim_stem)

In [19]:
# overwrite the stems positionally with their trimmed versions
trex_causal_df['stem'] = trimmed_stems.tolist()

In [20]:
# persist the cleaned/formatted causal subset as JSON Lines
trex_causal_df.to_json(
    path_or_buf='../data/calinet_trex_causal_data.json',
    orient='records',
    lines=True,
)

## Cache ROME counterfact data

In [4]:
# Load the original ROME CounterFact records. JSON is UTF-8 by specification,
# so the encoding is passed explicitly instead of relying on the platform
# default.
with open('../data/rome_counterfact_original/counterfact.json', 'r', encoding='utf-8') as f:
    data_rome = json.load(f)

In [5]:
# number of top-level records in the loaded CounterFact file
len(data_rome)

21919

In [6]:
# Reduce every CounterFact record to the fields used as model input, keyed by
# the record's position in the file (as a string).
data_rome_input_information = {}

for position, record in enumerate(data_rome):
    rewrite = record['requested_rewrite']
    data_rome_input_information[str(position)] = {
        # the prompt carries a '{}' placeholder that is filled with the subject
        "stem": rewrite['prompt'].replace('{}', rewrite['subject']),
        "true": rewrite['target_true']['str'],
        # stored as a one-element list
        "false": [rewrite['target_new']['str']],
        "case_id": record['case_id'],
    }

In [None]:
# Peek at the first few entries only; echoing the whole dictionary (one entry
# per CounterFact record) would flood the notebook output.
dict(list(data_rome_input_information.items())[:3])

In [8]:
# write the reduced CounterFact records to disk as a single JSON object
with open("../data/rome_counterfact_input_information.json", "w") as outfile:
    json.dump(data_rome_input_information, outfile)