In [60]:
import json
import numpy as np
import pandas as pd

## Cache calinet probing data

In [61]:
# Read the raw CaliNet probing file and turn its 'data' records into a frame.
calinet_raw_path = '../data/calinet_probing_data_original/probing_data_trex_500each.json'
with open(calinet_raw_path, 'r') as f:
    data_calinet = json.load(f)

starter_df = pd.DataFrame(list(data_calinet['data']))

In [62]:
# starter df
starter_df.head()

Unnamed: 0,fact_id,relation,triplet,sentences
0,1,P47,"{'sub_label': 'Norfolk', 'obj_label': 'Suffolk'}","[[Norfolk shares border with <extra_id_0>., <e..."
1,2,P47,"{'sub_label': 'Jordan', 'obj_label': 'Israel'}","[[<extra_id_0> shares border with Israel., <ex..."
2,3,P47,"{'sub_label': 'Kenya', 'obj_label': 'Ethiopia'}","[[Kenya shares border with <extra_id_0>., <ext..."
3,4,P47,"{'sub_label': 'Egypt', 'obj_label': 'Israel'}","[[<extra_id_0> shares border with Israel., <ex..."
4,5,P47,"{'sub_label': 'Tanzania', 'obj_label': 'Uganda'}","[[Tanzania shares border with <extra_id_0>., <..."


In [63]:
# all of these have to do with fact id 1
# the sentences are formed in this format...
# the start of a factual sentence, involving the subject
# and then two possibilities: one true and one false?
# storing these, then, we should do something like
# sentence stem | correct | incorrect
# and we can strip out the <extra_id_x> parts
# to keep it model agnostic
starter_df['sentences'][0][0]

['Norfolk shares border with <extra_id_0>.',
 '<extra_id_0> Suffolk <extra_id_1>',
 '<extra_id_0> Upper Macungie Township <extra_id_1>']

In [64]:
# parallel lists that will hold one value per (fact, sentence) pair
sentence_stems, correct, incorrect = [], [], []
fact_ids, relations = [], []
subjects, objects = [], []

In [65]:
# Flatten every [stem, correct, incorrect] entry of every fact into the
# parallel container lists (one element per sentence).
# NOTE: this cell appends to the containers, so re-running it without
# re-running the container cell duplicates every row.
for index, row in starter_df.iterrows():
    # grab sub<->obj once per fact: the triplet is a plain dict, so there is
    # no need to build a DataFrame for it on every sentence
    triplet = row['triplet']
    sub_label = triplet['sub_label']
    obj_label = triplet['obj_label']

    for entry in row['sentences']:

        # minor cleanup: make the stem model agnostic and strip the
        # sentinel tokens that surround the two candidate answers
        cleaned_stem = entry[0].replace("<extra_id_0>", "[BLANK]").strip()
        cleaned_correct = entry[1].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
        cleaned_incorrect = entry[2].replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()

        # commit
        subjects.append(sub_label)
        objects.append(obj_label)
        sentence_stems.append(cleaned_stem)
        correct.append(cleaned_correct)
        incorrect.append(cleaned_incorrect)
        fact_ids.append(row['fact_id'])
        relations.append(row['relation'])

In [66]:
# sanity check: every container must have received exactly one value per sentence
container_lengths = {
    len(container)
    for container in (sentence_stems, correct, incorrect,
                      fact_ids, relations, subjects, objects)
}
assert len(container_lengths) == 1

In [67]:
# assemble the parallel lists into one long-format frame (one row per sentence)
trex_columns = {
    'fact_id': fact_ids,
    'relation': relations,
    'subject': subjects,
    'object': objects,
    'stem': sentence_stems,
    'true': correct,
    'false': incorrect,
}
trex_df = pd.DataFrame(trex_columns)

In [68]:
# full df
trex_df.head()

Unnamed: 0,fact_id,relation,subject,object,stem,true,false
0,1,P47,Norfolk,Suffolk,Norfolk shares border with [BLANK].,Suffolk,Upper Macungie Township
1,1,P47,Norfolk,Suffolk,Norfolk borders with [BLANK].,Suffolk,Vadena
2,1,P47,Norfolk,Suffolk,[BLANK] shares the border with Suffolk.,Norfolk,Northern Cape province
3,1,P47,Norfolk,Suffolk,[BLANK] shares its border with Suffolk.,Norfolk,Sunamganj District
4,1,P47,Norfolk,Suffolk,Norfolk shares a common border with [BLANK].,Suffolk,Anabar


In [69]:
trex_df.tail()

Unnamed: 0,fact_id,relation,subject,object,stem,true,false
142995,13000,P264,X&Y,Parlophone,[BLANK] label : Parlophone.,X&Y,Junior Hanson
142996,13000,P264,X&Y,Parlophone,"[BLANK], released by Parlophone.",X&Y,Doo-Wops & Hooligans
142997,13000,P264,X&Y,Parlophone,Parlophone recording artist [BLANK].,X&Y,Atlas Genius
142998,13000,P264,X&Y,Parlophone,Parlophone artists such as [BLANK].,X&Y,untitled 2008 album
142999,13000,P264,X&Y,Parlophone,[BLANK] artists including X&Y.,Parlophone,ATCO


In [70]:
trex_df.shape

(143000, 7)

In [71]:
# write out initial df
# orient='records' + lines=True produces newline-delimited JSON (one object
# per row), so this file has to be read back with
# pd.read_json(..., lines=True) rather than a single json.load call.
trex_df.to_json('../data/calinet_probing_data_original/calinet_trex_full_data.json', orient='records', lines=True)

In [72]:
# put false inputs into a list
# with open('../data/calinet_probing_data_original/calinet_trex_full_data.json', 'r') as f:
    # data_calinet = json.load(f)

In [73]:
# how many stems end in [BLANK]? -> 50451, or about 1/3.
c = sum(1 for stem in trex_df['stem'] if stem.endswith("[BLANK]."))
print(c)

50451


In [74]:
def check_for_causal_compatibility(stem):
    """Return True when the blank is the last thing in the sentence.

    Only such stems can be completed left-to-right by a causal language
    model, i.e. the stem must end with the literal text "[BLANK].".
    """
    blank_suffix = "[BLANK]."
    is_left_to_right = stem.endswith(blank_suffix)
    return is_left_to_right

In [75]:
def trim_stem(stem):
    """Remove a trailing "[BLANK]." (and the whitespace before it) from a stem.

    Parameters
    ----------
    stem : str
        Sentence stem, e.g. "Norfolk borders with [BLANK].".

    Returns
    -------
    str
        The stem without the trailing blank marker, e.g.
        "Norfolk borders with". A stem that does not end with "[BLANK]." is
        returned unchanged (previously this case fell through and produced
        None, which would silently put nulls in the stem column).
    """
    blank_suffix = "[BLANK]."
    if not stem.endswith(blank_suffix):
        return stem
    # Slice by the suffix length instead of a hard-coded 9 so that a stem
    # that is only "[BLANK]." yields "" rather than "[BLANK]"; rstrip()
    # removes the separating space(s) without eating a non-space character.
    return stem[:len(stem) - len(blank_suffix)].rstrip()

In [76]:
# keep only the rows whose stem can be completed left-to-right
is_causal_row = trex_df['stem'].map(check_for_causal_compatibility)
trex_causal_df = trex_df[is_causal_row]

In [77]:
trex_causal_df = trex_causal_df.copy()

In [78]:
# strip the trailing blank marker from every remaining stem
trimmed_stems = trex_causal_df['stem'].map(trim_stem)

In [79]:
trex_causal_df['stem'] = list(trimmed_stems)

In [80]:
# most rows are paraphrases of the same fact (calinet used paraphrases to
# calibrate), so count how many distinct facts are actually left
trex_causal_df['fact_id'].nunique(dropna=False)

11960

In [81]:
# before sampling, attach arbitrary counter ID, to then track who gets removed
# (sized from the frame itself: a hard-coded row count raises a length
# mismatch as soon as the upstream data or filter changes)
trex_causal_df['calibra_id'] = range(len(trex_causal_df))

In [82]:
# Keep exactly one randomly chosen paraphrase per fact_id. The sampler is
# re-seeded with 42 for every group, so the selection is reproducible;
# reset_index(drop=True) discards the group/row index left behind by apply.
# NOTE(review): the hard-coded record keys used later in this notebook
# presumably depend on this exact selection — confirm before changing the seed.
trex_causal_subset = trex_causal_df.groupby('fact_id').apply(lambda x: x.sample(1, random_state=42)).reset_index(drop=True)

In [83]:
# the subset must contain exactly one row per distinct fact
n_unique_facts = trex_causal_df['fact_id'].nunique(dropna=False)
assert trex_causal_subset.shape[0] == n_unique_facts


In [84]:
trex_causal_subset.head()


Unnamed: 0,fact_id,relation,subject,object,stem,true,false,calibra_id
0,1,P47,Norfolk,Suffolk,Norfolk borders with,Suffolk,Vadena,1
1,2,P47,Jordan,Israel,Jordan shares a common border with,Israel,Simbach,6
2,3,P47,Kenya,Ethiopia,Kenya shares border with,Ethiopia,Yixing,7
3,4,P47,Egypt,Israel,Egypt shares its border with,Israel,"Montréal, Quebec",14
4,5,P47,Tanzania,Uganda,Tanzania borders with,Uganda,La Pampa,19


In [85]:
trex_causal_subset.tail()


Unnamed: 0,fact_id,relation,subject,object,stem,true,false,calibra_id
11955,12996,P264,Cody Wise,Interscope Records,The music label that is representing Cody Wise is,Interscope Records,Heads Up International,50411
11956,12997,P264,Amy Ray,Daemon Records,Daemon Records artists such as,Amy Ray,Hello Rockview,50426
11957,12998,P264,Martin Sorrondeguy,Lengua Armada Discos,"Martin Sorrondeguy, which is represented by",Lengua Armada Discos,Modern Day Escape,50428
11958,12999,P264,Madlib,Stones Throw,Stones Throw artists such as,Madlib,Bix Beiderbecke,50441
11959,13000,P264,X&Y,Parlophone,X&Y's label is,Parlophone,Re-Constriction Records,50444


In [86]:
# Record which rows the sampling step dropped:
#   removed_ids           calibra_id (str) -> fact_id (int) of each dropped row
#   removed_counterfacts  fact_id (str)    -> list of the dropped rows' 'false' values
# The dropped rows are selected with one vectorised isin() mask instead of
# re-scanning the whole frame for every calibra_id (which was quadratic).
kept_calibra_ids = set(trex_causal_subset['calibra_id'])
removed_rows = trex_causal_df[~trex_causal_df['calibra_id'].isin(kept_calibra_ids)]

removed_ids = {}
removed_counterfacts = {}
for c_id, fact_id, counterfact in zip(removed_rows['calibra_id'],
                                      removed_rows['fact_id'],
                                      removed_rows['false']):
    removed_ids[str(c_id)] = int(fact_id)
    removed_counterfacts.setdefault(str(fact_id), []).append(counterfact)

# the number of removed rows must equal the number of causal rows minus the
# number of distinct facts (one row per fact survives the sampling)
assert(len(removed_ids) == trex_causal_df.shape[0] - len(trex_causal_df['fact_id'].unique()))

In [87]:
# the dropped counterfacts are still useful as extra false options to test
# against, so peek at the first 15 facts
c = 0
for c, (k, v) in enumerate(removed_counterfacts.items(), start=1):
    print(k, v)
    if c == 15:
        break

1 ['Upper Macungie Township', 'Anabar', 'Riau', 'Bologna']
2 ['Mpumalanga']
3 ['James City County, Virginia', 'Portneuf', 'Rockingham County, Virginia', 'Giridih', 'Canazei']
4 ['Sestriere', 'Nitra District', 'Acerra', 'Le Havre']
5 ['Ukrainians', 'First Czechoslovak Republic', 'Ziburu']
6 ['Oliver, British Columbia', 'Kapurthala', 'ASEAN']
7 ['Vinnytsia Oblast', 'Laveno-Mombello', 'Orbassano', 'Arnhem', 'Santa Cristina Gela']
8 ['North America', 'Mogilev Region', 'New Zealand/Aotearoa', 'Phasi Charoen']
9 ['Castile La Mancha', 'Chikballapur district', 'Brewster County']
10 ['Chaumont-Gistoux', 'Magadan Oblast']
11 ['Bulakan', 'East Flanders', 'Arenys de Munt']
12 ['First Czechoslovak Republic', 'South West Africa', 'Churchill, Manitoba']
13 ['Oak Park', 'Rabun County', 'Rio de Janeiro (RJ)', 'Lower Hutt']
14 ['Liberty Village', 'Civitacampomarano', 'Sorano']
15 ['Western region', 'Sheridan Hollow']


In [88]:
# drop extraneous calibra_id column
# (rebinding with errors='ignore' instead of an in-place drop keeps the cell
# safe to re-run: a second in-place drop raised KeyError)
trex_causal_subset = trex_causal_subset.drop(columns=['calibra_id'], errors='ignore')


In [89]:
# there are some fact_id's that only have 1 row
# since we did pull stuff out based on our left to right requirement
trex_causal_subset.shape

(11960, 7)

In [90]:
len(removed_counterfacts)

10563

In [91]:
# For every fact that lost rows in the sampling step, combine the dropped
# counterfacts with the counterfact of the row that survived.
# One fact_id -> 'false' lookup replaces a full-frame filter per key.
subset_false_by_fact = dict(zip(trex_causal_subset['fact_id'],
                                trex_causal_subset['false']))

full_falses = {}
for k, v in removed_counterfacts.items():
    # build a new list: appending to `v` would also grow the lists stored in
    # removed_counterfacts and add a duplicate on every re-run of this cell
    full_falses[k] = v + [subset_false_by_fact[int(k)]]

print(len(full_falses))

10563


In [92]:
for k, v in full_falses.items():
    print(k,v)
    break

1 ['Upper Macungie Township', 'Anabar', 'Riau', 'Bologna', 'Vadena']


In [93]:
def replace_false_column(fact_id, false_val, full_false_dict=None):
    """Return the list of counterfacts to store for one fact.

    Parameters
    ----------
    fact_id : int or str
        Fact identifier; it is looked up by its string form.
    false_val : object
        The row's current single counterfact, used when the fact has no
        entry in the lookup.
    full_false_dict : dict, optional
        Mapping of str(fact_id) -> list of counterfacts. When omitted, the
        notebook-level ``full_falses`` is read at call time, so a rebuilt
        ``full_falses`` is picked up without redefining this function (a
        default bound in the signature would keep pointing at the old dict).

    Returns
    -------
    list
        A new list with all known counterfacts for the fact, or
        ``[false_val]`` when the fact is not in the lookup.
    """
    if full_false_dict is None:
        full_false_dict = full_falses

    key = str(fact_id)
    if key in full_false_dict:
        # copy so that later edits to the frame cannot alter the lookup
        return list(full_false_dict[key])
    return [false_val]

In [94]:
# one list of counterfacts per row of the subset, in row order
replaced_falses = [
    replace_false_column(row_fact_id, row_false)
    for row_fact_id, row_false in zip(trex_causal_subset['fact_id'],
                                      trex_causal_subset['false'])
]


In [95]:
len(replaced_falses)


11960

In [96]:
replaced_falses[:6]


[['Upper Macungie Township', 'Anabar', 'Riau', 'Bologna', 'Vadena'],
 ['Mpumalanga', 'Simbach'],
 ['James City County, Virginia',
  'Portneuf',
  'Rockingham County, Virginia',
  'Giridih',
  'Canazei',
  'Yixing'],
 ['Sestriere', 'Nitra District', 'Acerra', 'Le Havre', 'Montréal, Quebec'],
 ['Ukrainians', 'First Czechoslovak Republic', 'Ziburu', 'La Pampa'],
 ['Oliver, British Columbia', 'Kapurthala', 'ASEAN', 'Kodanad']]

In [97]:
trex_causal_subset['false'] = replaced_falses


In [98]:
trex_causal_subset.head()

Unnamed: 0,fact_id,relation,subject,object,stem,true,false
0,1,P47,Norfolk,Suffolk,Norfolk borders with,Suffolk,"[Upper Macungie Township, Anabar, Riau, Bologn..."
1,2,P47,Jordan,Israel,Jordan shares a common border with,Israel,"[Mpumalanga, Simbach]"
2,3,P47,Kenya,Ethiopia,Kenya shares border with,Ethiopia,"[James City County, Virginia, Portneuf, Rockin..."
3,4,P47,Egypt,Israel,Egypt shares its border with,Israel,"[Sestriere, Nitra District, Acerra, Le Havre, ..."
4,5,P47,Tanzania,Uganda,Tanzania borders with,Uganda,"[Ukrainians, First Czechoslovak Republic, Zibu..."


In [99]:
trex_causal_subset.tail()

Unnamed: 0,fact_id,relation,subject,object,stem,true,false
11955,12996,P264,Cody Wise,Interscope Records,The music label that is representing Cody Wise is,Interscope Records,"[Holy Records, Heads Up International, Disney,..."
11956,12997,P264,Amy Ray,Daemon Records,Daemon Records artists such as,Amy Ray,"[So So Def Recordings, Universal Music Japan, ..."
11957,12998,P264,Martin Sorrondeguy,Lengua Armada Discos,"Martin Sorrondeguy, which is represented by",Lengua Armada Discos,"[Frontiers Records, Barnaby Records, Aggro Ber..."
11958,12999,P264,Madlib,Stones Throw,Stones Throw artists such as,Madlib,"[Green Linnet Records, Mute records, Vee Jay R..."
11959,13000,P264,X&Y,Parlophone,X&Y's label is,Parlophone,"[Angular Recording Corporation, SPV GmbH, Abac..."


In [100]:
# index the row records by their position: {0: record_0, 1: record_1, ...}
trex_list = trex_causal_subset.to_dict('records')
output_dict = dict(enumerate(trex_list))

In [101]:
# De-duplicate each record's counterfact list and count the resulting
# (fact, counterfact) pairs.
num_pairs = 0
for x, y in output_dict.items():
    # dict.fromkeys drops repeats while keeping first-seen order; going
    # through set() gave a different ordering on every kernel start (string
    # hashing is randomised), so the written file was not reproducible
    y['false'] = list(dict.fromkeys(y['false']))

    num_pairs += len(y['false'])

In [102]:
num_pairs

50386

In [103]:
# persist the cleaned, causal-friendly CaliNet records
calinet_output_path = "../data/calinet_input_information.json"
with open(calinet_output_path, "w") as outfile:
    json.dump(output_dict, outfile)

In [104]:
# out of curiosity, which relation templates persist in the cleaned, 'causal friendly' set...
trex_causal_df['relation'].value_counts()

P495     4172
P138     3729
P264     3533
P1376    3509
P101     3279
P740     3246
P36      3215
P449     2794
P47      2265
P20      2046
P19      1760
P159     1729
P27      1717
P530     1506
P106     1497
P407     1492
P364     1457
P176     1277
P39      1268
P37      1000
P937      995
P178      995
P136      967
P463      758
P413      245
Name: relation, dtype: int64

## Cache ROME counterfact data

In [105]:
# load the raw ROME CounterFact records (indexed by position in the next cells)
with open('../data/rome_counterfact_original/counterfact.json', 'r') as f:
    data_rome = json.load(f)

In [106]:
len(data_rome)

21919

In [107]:
# Reshape each ROME record into the stem / true / false layout used for the
# CaliNet records, keyed by the record's position as a string.
data_rome_input_information = {}

for i, record in enumerate(data_rome):
    rewrite = record['requested_rewrite']
    # the prompt is a template whose '{}' placeholder stands for the subject
    stem = rewrite['prompt'].replace('{}', rewrite['subject'])

    data_rome_input_information[str(i)] = {
        "stem": stem,
        "true": rewrite['target_true']['str'],
        "false": [rewrite['target_new']['str']],
        "case_id": record['case_id'],
    }

In [108]:
#data_rome_input_information

In [109]:
# persist the reformatted ROME records
rome_output_path = "../data/rome_counterfact_input_information.json"
with open(rome_output_path, "w") as outfile:
    json.dump(data_rome_input_information, outfile)

## Combine the two datasets

In [126]:
# Reload the two cleaned files written earlier in this notebook.
# NOTE: this rebinds data_calinet and data_rome, which until now held the raw
# downloads, to the cleaned record dicts. After the JSON round-trip the
# integer record keys of the CaliNet dict come back as strings.
with open('../data/calinet_input_information.json', 'r') as f:
    data_calinet = json.load(f)

with open('../data/rome_counterfact_input_information.json', 'r') as f:
    data_rome= json.load(f)


In [127]:
#data_calinet
#data_rome

# Concatenate both datasets into one dict with fresh consecutive string keys,
# tagging every record with the dataset it came from.
mixed_df = {}
mixed_itr = 0

labelled_sources = (
    (data_calinet, 'calinet_input_information'),
    (data_rome, 'rome_counterfact_input_information'),
)
for source_records, source_label in labelled_sources:
    for record in source_records.values():
        record['dataset_original'] = source_label
        mixed_df[str(mixed_itr)] = record
        mixed_itr += 1


In [128]:
# size of the combined dataset
itrs = len(mixed_df)

print(f'The number of items in mixed_df is {itrs}')

The number of items in mixed_df is 33879


In [129]:
# check for duplicate stem, fact, counterfact pairs across the dataset:
# every counterfact of every record contributes one "stem true false" string

pairs_list = [
    [record['stem'] + ' ' + record['true'] + ' ' + false_val]
    for record in mixed_df.values()
    for false_val in record['false']
]

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')
# num duplicates = total trios minus distinct trios
n_distinct_trios = len(np.unique(np.array(pairs_list)))
num_dup = len(pairs_list) - n_distinct_trios
print(f'The number of duplicated [stem + fact + counterfact] trios in mixed_df is {num_dup}')


The number of [stem + fact + counterfact] trios in mixed_df is 72305
The number of duplicated [stem + fact + counterfact] trios in mixed_df is 2


In [130]:
# locate the duplicated trios: print the record key, a running counter and
# the trio for every repeat (this cell only reports, it removes nothing)

pairs_list = []
seen_trios = set()  # O(1) membership test instead of scanning pairs_list
dup_itr = 1
for x, y in mixed_df.items():
    for false_val in y['false']:
        trio = y['stem'] + ' ' + y['true'] + ' ' + false_val
        pairs = [trio]
        if trio in seen_trios:
            print(x, dup_itr, pairs)
            dup_itr += 1
        seen_trios.add(trio)
        pairs_list.append(pairs)


10107 1 ['Brazil ties diplomatic relations with Venezuela VEN']
25020 2 ['The original language of Paul Clifford was English Tamil']


In [131]:
# hard code remove the duplicates reported above
# NOTE(review): the keys '10107' and '25020' are only valid for the exact
# records produced by this run; re-check the duplicate report if anything
# upstream changes. 'Bulgaria' is presumably the counterfact of that record
# that was not duplicated — confirm.
mixed_df['10107']['false'] = ['Bulgaria']
# pop() with a default keeps this cell safe to re-run (a second `del` raised
# KeyError once the key was gone)
mixed_df.pop('25020', None)


In [132]:
# check for duplicates again, now that the clean-up has been applied

pairs_list = []
for record in mixed_df.values():
    stem_and_fact = record['stem'] + ' ' + record['true'] + ' '
    pairs_list.extend([stem_and_fact + false_val] for false_val in record['false'])

print(f'The number of [stem + fact + counterfact] trios in mixed_df is {len(pairs_list)}')
# num duplicates = total trios minus distinct trios
distinct_trios = np.unique(np.array(pairs_list))
num_dup = len(pairs_list) - len(distinct_trios)
print(f'The number of duplicated [stem + fact + counterfact] trios in mixed_df is {num_dup}')

The number of [stem + fact + counterfact] trios in mixed_df is 72303
The number of duplicated [stem + fact + counterfact] trios in mixed_df is 0


In [133]:

# persist the combined, de-duplicated dataset
combined_output_path = "../data/calibragpt_full_input_information.json"
with open(combined_output_path, "w") as outfile:
    json.dump(mixed_df, outfile)