In [40]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle
import seaborn as sns
import copy
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from spectral_explain.dataloader import HotpotQA, Drop
from spectral_explain.utils import fourier_to_mobius, mobius_to_fourier
plt.rcParams['text.usetex'] = True
plt.rcParams.update({
    "text.usetex": True,              # Use TeX for text rendering
    "font.family": "serif",
    "hatch.color": "white"
})
mpl.rcParams.update(mpl.rcParamsDefault)

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Plotting functions
color_dict = {'qsft_hard_0': '#896190',
               'qsft_soft_0': '#896190',
               'lime_1': '#86a76c',
               'linear_1': '#37537c',
               'linear_2': '#697fa0',
               'linear_3': '#9aa9c4',
               'linear_4': '#cad6e9',
               'faith_banzhaf_1': '#37537c',
               'faith_banzhaf_2': '#697fa0',
               'faith_banzhaf_3': '#9aa9c4',
               'faith_banzhaf_4': '#cad6e9',
               'faith_shapley_1': '#96474A',
               'faith_shapley_2': '#B2746B',
               'faith_shapley_3': '#CFA08C',
               'faith_shapley_4': '#EBCDAD',
               'shapley_taylor_2': '#B67B80',
               'shapley_taylor_3': '#D7B0B7',
               'shapley_taylor_4': '#F7E4ED',
               'FSII_1': '#96474A',
               'FSII_2': '#B2746B',
               'FSII_3': '#CFA08C',
               'FSII_4': '#EBCDAD',
               'SV_1': '#96474A',
               }
name_dict = {'qsft_hard_0': 'SpectralExplain',
               'qsft_soft_0': 'SpectralExplain',
               'lime_1': 'LIME',
               'linear_1': 'Banzhaf',
               'linear_2': 'Faith-Banzhaf\n2nd Order',
               'linear_3': 'Faith-Banzhaf\n3rd Order',
               'linear_4': 'Faith-Banzhaf\n4th Order',
               'faith_shapley_1': 'Shapley',
               'faith_shapley_2': 'Faith-Shap\n2nd Order',
               'faith_shapley_3': 'Faith-Shap\n3rd Order',
               'faith_shapley_4': 'Faith-Shap\n4th Order',
               'shapley_taylor_2': 'Shapley-Taylor\n2nd Order',
               'shapley_taylor_3': 'Shapley-Taylor\n3rd Order',
               'shapley_taylor_4': 'Shapley-Taylor\n4th Order',
               }


# Faithfulness results

In [3]:
range_to_str = {range(8,15): '8-15', range(16,31): '16-31', range(32,63): '32-63', range(64,127): '64-127', range(128,255): '128-255', range(256,511): '256-511', range(512,1023): '512-1023'}
def get_group_results(groups,r2_results, b = 8, methods = ['qsft_soft_0', 'lime_1', 'linear_1', 'linear_2', 'linear_3', 'linear_4','FSII_4']):
    group_dict = {group: {method: [] for method in methods} for group in groups}
    for explicand in r2_results:
        if explicand['explicand']['id'] == '5a89dfcc55429946c8d6e9f0':
            continue
        n = explicand['explicand']['n']
        for group in groups:
            if n in group:
                for method in methods:
                    try:
                        group_dict[group][method].append(explicand[(method,b)]['r2'])
                    except:
                        group_dict[group][method].append(np.nan)
    
    group_dict_str = {}
    for group in group_dict:
        group_str = range_to_str[group]
        group_dict_str[group_str] = {}
        for method in group_dict[group]:
            group_dict_str[group_str][method] = np.nanmedian(group_dict[group][method])
    return group_dict_str


In [5]:
# groups_str = ['8-15', '16-31', '32-63', '64-127'] 
# groups = [range(int(group.split('-')[0]), int(group.split('-')[1])) for group in groups_str]
# group_methods = {groups[0]: ['lime_1', 'linear_1', 'linear_2', 'linear_3', 'linear_4', 'qsft_soft_0','lasso_1','lasso_2','lasso_3','lasso_4','FSII_4','STII_4','SV_1','faith_banzhaf_1','faith_banzhaf_2','faith_banzhaf_3','faith_banzhaf_4'] }
# group_methods[groups[1]] = ['lime_1', 'linear_1', 'linear_2', 'linear_3',  'qsft_soft_0','lasso_1','lasso_2','lasso_3', 'FSII_3','STII_3','SV_1','faith_banzhaf_1','faith_banzhaf_2','faith_banzhaf_3']
# group_methods[groups[2]] = ['lime_1', 'linear_1', 'linear_2',  'qsft_soft_0','lasso_1','lasso_2','FSII_2','STII_2','SV_1','faith_banzhaf_1','faith_banzhaf_2']
# group_methods[groups[3]] = ['lime_1', 'linear_1',  'qsft_soft_0','lasso_1','SV_1','faith_banzhaf_1']
all_methods = ['qsft_soft_0', 'lime_1','faith_banzhaf_4','FSII_4','faith_banzhaf_3','FSII_3','faith_banzhaf_2','FSII_2','faith_banzhaf_1', 'SV_1','STII_2','STII_3','STII_4']
all_results = {}
for task in ['hotpotqa', 'drop']:
    if task == 'hotpotqa':
        groups_str = ['8-15', '16-31', '32-63', '64-127'] 
        groups = [range(int(group.split('-')[0]), int(group.split('-')[1])) for group in groups_str]
    elif task == 'drop':
        groups_str = ['32-63', '64-127', '128-255', '256-511', '512-1023']
        groups = [range(int(group.split('-')[0]), int(group.split('-')[1])) for group in groups_str]
    for b in [4,6,8]:
        r2_results = pickle.load(open(f'results/{task}/r2_results.pkl', 'rb'))
        group_results = get_group_results(groups = groups, r2_results = r2_results, b = b, methods = all_methods)
        print(f'{task} {b} {group_results}')
        with open(f'processed_results/{task}_faith_{b}.pkl', 'wb') as f:
            pickle.dump(group_results, f)


hotpotqa 4 {'8-15': {'qsft_soft_0': 0.6223544706892019, 'lime_1': 0.3799364302639089, 'faith_banzhaf_4': 0.5946795758495655, 'FSII_4': 0.2751489133951058, 'faith_banzhaf_3': 0.725054334073008, 'FSII_3': -3.0014516118657557, 'faith_banzhaf_2': 0.6117254783275514, 'FSII_2': -20.273824183193046, 'faith_banzhaf_1': 0.3792738754219779, 'SV_1': -1.095332726490242, 'STII_2': nan, 'STII_3': nan, 'STII_4': -517.5866848953244}, '16-31': {'qsft_soft_0': 0.4979418192172792, 'lime_1': 0.38110185653452716, 'faith_banzhaf_4': nan, 'FSII_4': nan, 'faith_banzhaf_3': -154.3701622692681, 'FSII_3': -13.702211235894799, 'faith_banzhaf_2': 0.5788500801055343, 'FSII_2': -31.66674003902791, 'faith_banzhaf_1': 0.4884031989691646, 'SV_1': 0.21204896745407786, 'STII_2': -148.97834571570903, 'STII_3': nan, 'STII_4': nan}, '32-63': {'qsft_soft_0': 0.35082144889566536, 'lime_1': 0.43338240845250675, 'faith_banzhaf_4': nan, 'FSII_4': nan, 'faith_banzhaf_3': nan, 'FSII_3': nan, 'faith_banzhaf_2': 0.27724693134019707,

  group_dict_str[group_str][method] = np.nanmedian(group_dict[group][method])


# Subtraction results

## DROP

In [11]:
subtract_drop = pickle.load(open(f'processed_results/drop_subtract_results.pkl', 'rb'))
explicand_list = pickle.load(open(f'processed_results/drop_explicand_list.pkl', 'rb'))
subtract_res = {}
range_to_str = {range(8,15): '8-15', range(16,31): '16-31', range(32,63): '32-63', range(64,127): '64-127', range(128,255): '128-255', range(256,511): '256-511', range(512,1023): '512-1023'}
subtract_indices = {range(32,63): [], range(64,127): [], range(128,255): [], range(256,511): []}
n_list = [explicand['n'] for explicand in explicand_list]
for i, n in enumerate(n_list):
    if n in range(32,63):
        subtract_indices[range(32,63)].append(i)
    elif n in range(64,127):
        subtract_indices[range(64,127)].append(i)
    elif n in range(128,255):
        subtract_indices[range(128,255)].append(i)
    elif n in range(256,511):
        subtract_indices[range(256,511)].append(i)

subtract_res = {group: {} for group in subtract_indices}
subtract_res[range(32,63)]['qsft_soft_0'] = np.nanmean(subtract_drop['qsft_soft_0'][subtract_indices[range(32,63)],3:],axis=0)
subtract_res[range(64,127)]['qsft_soft_0'] = np.nanmean(subtract_drop['qsft_soft_0'][subtract_indices[range(64,127)],3:],axis=0)
subtract_res[range(128,255)]['qsft_soft_0'] = np.nanmean(subtract_drop['qsft_soft_0'][subtract_indices[range(128,255)],3:],axis=0)
subtract_res[range(256,511)]['qsft_soft_0'] = np.nanmean(subtract_drop['qsft_soft_0'][subtract_indices[range(256,511)],3:],axis=0)

subtract_res[range(32,63)]['lime_1'] = np.nanmean(subtract_drop['lime_1'][subtract_indices[range(32,63)],3:],axis=0)
subtract_res[range(64,127)]['lime_1'] = np.nanmean(subtract_drop['lime_1'][subtract_indices[range(64,127)],3:],axis=0)
subtract_res[range(128,255)]['lime_1'] = np.nanmean(subtract_drop['lime_1'][subtract_indices[range(128,255)],3:],axis=0)
subtract_res[range(256,511)]['lime_1'] = np.nanmean(subtract_drop['lime_1'][subtract_indices[range(256,511)],3:],axis=0)

subtract_res[range(32,63)]['FSII_2'] = np.nanmean(subtract_drop['FSII_2'][subtract_indices[range(32,63)],3:],axis=0)
subtract_res[range(32,63)]['STII_2'] = np.nanmean(subtract_drop['STII_2'][subtract_indices[range(32,63)],3:],axis=0)

subtract_res[range(64,127)]['SV_1'] = np.nanmean(subtract_drop['SV_1'][subtract_indices[range(64,127)],3:],axis=0)
subtract_res[range(128,255)]['SV_1'] = np.nanmean(subtract_drop['SV_1'][subtract_indices[range(128,255)],3:],axis=0)
subtract_res[range(256,511)]['SV_1'] = np.nanmean(subtract_drop['SV_1'][subtract_indices[range(256,511)],3:],axis=0)

subtract_res[range(32,63)]['faith_banzhaf_2'] = np.nanmean(subtract_drop['faith_banzhaf_2'][subtract_indices[range(32,63)],3:],axis=0)
subtract_res[range(32,63)]['faith_banzhaf_1'] = np.nanmean(subtract_drop['faith_banzhaf_1'][subtract_indices[range(32,63)],3:],axis=0)
subtract_res[range(64,127)]['faith_banzhaf_1'] = np.nanmean(subtract_drop['faith_banzhaf_1'][subtract_indices[range(64,127)],3:],axis=0)
subtract_res[range(128,255)]['faith_banzhaf_1'] = np.nanmean(subtract_drop['faith_banzhaf_1'][subtract_indices[range(128,255)],3:],axis=0)
subtract_res[range(256,511)]['faith_banzhaf_1'] = np.nanmean(subtract_drop['faith_banzhaf_1'][subtract_indices[range(256,511)],3:],axis=0)

with open(f'processed_results/drop_subtract_range(32,63).pkl', 'wb') as f:
    pickle.dump(subtract_res[range(32,63)], f)
with open(f'processed_results/drop_subtract_range(64,127).pkl', 'wb') as f:
    pickle.dump(subtract_res[range(64,127)], f)
with open(f'processed_results/drop_subtract_range(128,255).pkl', 'wb') as f:
    pickle.dump(subtract_res[range(128,255)], f)



## Hotpotqa


In [48]:

subtract_hotpot = pickle.load(open(f'processed_results/hotpotqa_subtract_results.pkl', 'rb'))
explicand_list = pickle.load(open(f'processed_results/hotpotqa_explicand_list.pkl', 'rb'))
subtract_res = {}
range_to_str = {range(8,15): '8-15', range(16,31): '16-31', range(32,63): '32-63', range(64,127): '64-127', range(128,255): '128-255', range(256,511): '256-511'}
subtract_indices = {range(8,15): [], range(16,31): [], range(32,63): [], range(64,127): [], range(128,255): [], range(256,511): []}
n_list = [explicand['n'] for explicand in explicand_list]
for i, n in enumerate(n_list):
    if n in range(8,15):
        subtract_indices[range(8,15)].append(i)
    elif n in range(16,31):
        subtract_indices[range(16,31)].append(i)
    elif n in range(32,63):
        subtract_indices[range(32,63)].append(i)
    elif n in range(64,127):
        subtract_indices[range(64,127)].append(i)
    elif n in range(128,255):
        subtract_indices[range(128,255)].append(i)
    elif n in range(256,511):
        subtract_indices[range(256,511)].append(i)

subtract_res = {group: {} for group in subtract_indices}
subtract_res[range(8,15)]['qsft_soft_0'] = np.nanmean(subtract_hotpot['qsft_soft_0'][subtract_indices[range(8,15)],:],axis=0)
subtract_res[range(16,31)]['qsft_soft_0'] = np.nanmean(subtract_hotpot['qsft_soft_0'][subtract_indices[range(16,31)],:],axis=0)[[1,2,3,6,7]]
subtract_res[range(32,63)]['qsft_soft_0'] = np.nanmean(subtract_hotpot['qsft_soft_0'][subtract_indices[range(32,63)],:],axis=0)[[1,2,3,4,6]]
subtract_res[range(64,127)]['qsft_soft_0'] = np.nanmean(subtract_hotpot['qsft_soft_0'][subtract_indices[range(64,127)],:],axis=0)

subtract_res[range(8,15)]['lime_1'] = np.nanmean(subtract_hotpot['lime_1'][subtract_indices[range(8,15)],:],axis=0)
subtract_res[range(16,31)]['lime_1'] = np.nanmean(subtract_hotpot['lime_1'][subtract_indices[range(16,31)],:],axis=0)[[1,2,3,4,5]]
subtract_res[range(32,63)]['lime_1'] = np.nanmean(subtract_hotpot['lime_1'][subtract_indices[range(32,63)],:],axis=0)[1:6]
subtract_res[range(64,127)]['lime_1'] = np.nanmean(subtract_hotpot['lime_1'][subtract_indices[range(64,127)],:],axis=0)


subtract_res[range(8,15)]['FSII_2'] = np.nanmean(subtract_hotpot['FSII_2'][subtract_indices[range(8,15)],:],axis=0)
subtract_res[range(16,31)]['FSII_2'] = np.nanmean(subtract_hotpot['FSII_2'][subtract_indices[range(16,31)],:],axis=0)[[2,1,3,4,5]]
#subtract_res[range(32,63)]['FSII_2'] = np.nanmean(subtract_hotpot['FSII_2'][subtract_indices[range(32,63)],:],axis=0)

subtract_res[range(8,15)]['STII_2'] = np.nanmean(subtract_hotpot['STII_2'][subtract_indices[range(8,15)],:],axis=0)
#subtract_res[range(16,31)]['STII_2'] = np.nanmean(subtract_hotpot['STII_2'][subtract_indices[range(16,31)],:],axis=0)[3:]
#subtract_res[range(32,63)]['STII_2'] = np.nanmean(subtract_hotpot['STII_2'][subtract_indices[range(32,63)],:],axis=0)


subtract_res[range(8,15)]['faith_banzhaf_2'] = np.nanmean(subtract_hotpot['faith_banzhaf_2'][subtract_indices[range(8,15)],:],axis=0)
subtract_res[range(16,31)]['faith_banzhaf_2'] = np.nanmean(subtract_hotpot['faith_banzhaf_2'][subtract_indices[range(16,31)],:],axis=0)[[1,2,3,4,6]]
#subtract_res[range(32,63)]['faith_banzhaf_2'] = np.nanmean(subtract_hotpot['faith_banzhaf_2'][subtract_indices[range(32,63)],:],axis=0)



subtract_res[range(8,15)]['SV_1'] = np.nanmean(subtract_hotpot['SV_1'][subtract_indices[range(8,15)], :],axis=0)
subtract_res[range(16,31)]['SV_1'] = np.nanmean(subtract_hotpot['SV_1'][subtract_indices[range(16,31)], :],axis=0)[3:]
subtract_res[range(32,63)]['SV_1'] = np.nanmean(subtract_hotpot['SV_1'][subtract_indices[range(32,63)], :],axis=0)[1:6]
subtract_res[range(64,127)]['SV_1'] = np.nanmean(subtract_hotpot['SV_1'][subtract_indices[range(64,127)], :],axis=0)


subtract_res[range(8,15)]['faith_banzhaf_1'] = np.nanmean(subtract_hotpot['faith_banzhaf_1'][subtract_indices[range(8,15)],:],axis=0)
subtract_res[range(16,31)]['faith_banzhaf_1'] = np.nanmean(subtract_hotpot['faith_banzhaf_1'][subtract_indices[range(16,31)],:],axis=0)[3:]
subtract_res[range(32,63)]['faith_banzhaf_1'] = np.nanmean(subtract_hotpot['faith_banzhaf_1'][subtract_indices[range(32,63)],:],axis=0)[1:6]
subtract_res[range(64,127)]['faith_banzhaf_1'] = np.nanmean(subtract_hotpot['faith_banzhaf_1'][subtract_indices[range(64,127)],:],axis=0)

subtract_res[range(32,63)]


  subtract_res[range(8,15)]['qsft_soft_0'] = np.nanmean(subtract_hotpot['qsft_soft_0'][subtract_indices[range(8,15)],:],axis=0)
  subtract_res[range(8,15)]['FSII_2'] = np.nanmean(subtract_hotpot['FSII_2'][subtract_indices[range(8,15)],:],axis=0)
  subtract_res[range(8,15)]['STII_2'] = np.nanmean(subtract_hotpot['STII_2'][subtract_indices[range(8,15)],:],axis=0)
  subtract_res[range(8,15)]['faith_banzhaf_2'] = np.nanmean(subtract_hotpot['faith_banzhaf_2'][subtract_indices[range(8,15)],:],axis=0)


{'qsft_soft_0': array([0.00367363, 0.00401882, 0.00447019, 0.00472578, 0.00501272]),
 'lime_1': array([0.00347584, 0.00410595, 0.00379943, 0.00427768, 0.00431478]),
 'SV_1': array([0.00380442, 0.00404869, 0.00398078, 0.00417747, 0.00443736]),
 'faith_banzhaf_1': array([0.0034651 , 0.00413931, 0.004006  , 0.00438358, 0.00412377])}

In [38]:
subtract_res[range(16,31)]
with open(f'processed_results/hotpotqa_subtract_range(32,63).pkl', 'wb') as f:
    pickle.dump(subtract_res[range(16,31)], f)


{'qsft_soft_0': array([0.00330579, 0.00537793, 0.00546174, 0.006698  , 0.00689953]),
 'lime_1': array([0.00336139, 0.00456465, 0.00521775, 0.00583585, 0.00602888]),
 'FSII_2': array([0.00291487, 0.00361294, 0.00392911, 0.0041657 , 0.00430151]),
 'faith_banzhaf_2': array([0.00309403, 0.00445877, 0.00501915, 0.00619695, 0.00695866]),
 'SV_1': array([0.00531369, 0.00544075, 0.00569884, 0.00582972, 0.00654453]),
 'faith_banzhaf_1': array([0.00574133, 0.00545796, 0.00616437, 0.00604176, 0.00626375])}

In [49]:
with open(f'processed_results/hotpotqa_subtract_range(64,127).pkl', 'wb') as f:
    pickle.dump(subtract_res[range(32,63)], f)
subtract_res[range(32,63)]

{'qsft_soft_0': array([0.00367363, 0.00401882, 0.00447019, 0.00472578, 0.00501272]),
 'lime_1': array([0.00347584, 0.00410595, 0.00379943, 0.00427768, 0.00431478]),
 'SV_1': array([0.00380442, 0.00404869, 0.00398078, 0.00417747, 0.00443736]),
 'faith_banzhaf_1': array([0.0034651 , 0.00413931, 0.004006  , 0.00438358, 0.00412377])}

# Hotpotqa recall 



In [12]:
hotpot_all = pickle.load(open(f'processed_results/reconstructions/reconstruction_results_hotpotqa.pkl', 'rb'))
hotpot_explicands = pickle.load(open(f'processed_results/hotpotqa_explicand_list.pkl', 'rb'))
explicand_list = []
indices = []
hotpot_recall = []
for i, explicand in enumerate(hotpot_explicands):
    if explicand['id'] == '5a89dfcc55429946c8d6e9f0':
        continue
    model_answer = explicand['model_answer']
    answer = explicand['answer']
    if model_answer.strip().lower() == answer.strip().lower():
        supporting_facts = [explicand['supporting_facts'][i][2] for i in range(len(explicand['supporting_facts']))]
        explicand_list.append((explicand, i, supporting_facts))
for explicand in explicand_list:
    hotpot_recall.append([hotpot_all[explicand[1]], explicand[2], explicand[0]['n']])

#hotpot_recall


In [89]:
def get_recall_per_method(reconstruction, ground_truth, recall_limit = None):
    if recall_limit is None:
        recall_limit = len(ground_truth)
    
    reconstruction = fourier_to_mobius(reconstruction)
    sorted_reconstruction = sorted(reconstruction.items(), key=lambda item: abs(item[1]), reverse=True)
    recall_scores = 0.0
    for interaction in sorted_reconstruction[:recall_limit+1]:
        if all(val == 0 for val in interaction[0]):
            continue
        
        indices = [i for i, val in enumerate(interaction[0]) if val == 1]
        print(indices, ground_truth)
        overlap = len(set(indices) & set(ground_truth))
        recall_scores += overlap / len(set(indices))
    return recall_scores / recall_limit

def get_recall_per_explicand(hotpot_recall, recall_limit = None, n_range = range(16,31), b = 4):
    method_res = {method: [] for method in methods}
    count = 0
    for row in hotpot_recall:
        if row[2] in n_range:
            for method in methods:
                print(method)
                recall = get_recall_per_method(row[0][(method, b)], row[1], recall_limit)
                method_res[method].append(recall)
        count += 1
    return method_res
methods = ['qsft_soft_0',]#'FSII_3', 'faith_banzhaf_3']#, 'FSII_3', 'faith_banzhaf_3'] 'FSII_4','faith_banzhaf_4'
recall_results = {method: [] for method in methods}
Bs = [6]
for b in Bs:
    all_recall = get_recall_per_explicand(hotpot_recall, recall_limit=10, n_range=range(8 ,128), b = b)
    for m in all_recall:
        recall_results[m].append(np.mean(all_recall[m]))
recall_results
with open(f'processed_results/hotpotqa_recall_results.pkl', 'wb') as f:
    pickle.dump(recall_results, f)
recall_results

qsft_soft_0
[3] [15, 4]
[3, 4, 5] [15, 4]
[0, 1, 3] [15, 4]
[8] [15, 4]
[21, 22, 23] [15, 4]
[23, 24] [15, 4]
[13] [15, 4]
[21] [15, 4]
[15] [15, 4]
[5] [15, 4]
qsft_soft_0
[7] [29, 0]
[5, 6] [29, 0]
[16] [29, 0]
[6] [29, 0]
[7, 16] [29, 0]
[5, 7, 9] [29, 0]
[5] [29, 0]
[4] [29, 0]
[6, 7] [29, 0]
[5, 7] [29, 0]
qsft_soft_0
[12] [12, 6]
[12, 17] [12, 6]
[0, 12] [12, 6]
[6] [12, 6]
[16, 17] [12, 6]
[8, 12] [12, 6]
[17] [12, 6]
[9, 11, 12] [12, 6]
[3] [12, 6]
[2, 12] [12, 6]
qsft_soft_0
[1] [20, 4]
[30] [20, 4]
[0] [20, 4]
[29, 30] [20, 4]
[15] [20, 4]
[4] [20, 4]
[29] [20, 4]
[18] [20, 4]
[3] [20, 4]
[20, 30] [20, 4]
qsft_soft_0
[5] [9, 5]
[0] [9, 5]
[7] [9, 5]
[0, 1, 5] [9, 5]
[2] [9, 5]
[2, 7] [9, 5]
[2, 3, 6] [9, 5]
[1, 2] [9, 5]
[0, 2] [9, 5]
[9] [9, 5]
qsft_soft_0
[8] [22, 8]
[22] [22, 8]
[21, 22] [22, 8]
[5, 8] [22, 8]
[7] [22, 8]
[12] [22, 8]
[27] [22, 8]
[28] [22, 8]
[26] [22, 8]
[5] [22, 8]
qsft_soft_0
[4] [18, 0]
[0, 32] [18, 0]
[8] [18, 0]
[0, 4] [18, 0]
[0, 1] [18, 0]
[31, 32

{'qsft_soft_0': [0.23838383838383836]}

# Look at Particular Examples of Interactions

In [39]:
original_prompt = hotpot_recall[7][0]['explicand']['original']
supporting_facts = hotpot_recall[7][0]['explicand']['supporting_facts']
supporting_facts

[('George Rainsford (actor)',
  'George Rainsford (born 31 July 1982) is an English actor, best known for his portrayal of Jimmy Wilson in the medical drama "Call the Midwife" and Ethan Hardy in "Casualty", for which he has been nominated for a Best Actor award in the 2017 TV Choice Awards.',
  4),
 ('TV Choice',
  'TV Choice is a British weekly TV listings magazine published by H. Bauer Publishing, the UK subsidiary of family-run German company Bauer Media Group.',
  36),
 ('TV Choice',
  ' It features weekly TV broadcast programming listings, running from Saturday to Friday, and goes on sale every Tuesday.',
  37)]

{'answer': 'Tuesday',
 'original': 'Title: Shabbir Jan \n Context:Shabbir Jan is a Pakistani television actor who has appeared in many drama serials, such as Wafa, "Makan", "Andata", Survival of a Woman, "Zindagi Dhoop Tum Ghana Saya", "Umrao Jaan", "Jangloos" and "Shab e Gham" and individual play and serials. He won three times PTV best actor award. He has been a nominee once for the Best Actor award in the Lux Style Award, 2002. He has worked television for 33 years and still working.\nTitle: George Rainsford (actor) \n Context:George Rainsford (born 31 July 1982) is an English actor, best known for his portrayal of Jimmy Wilson in the medical drama "Call the Midwife" and Ethan Hardy in "Casualty", for which he has been nominated for a Best Actor award in the 2017 TV Choice Awards.\nTitle: List of awards and nominations received by Vikram \n Context:Vikram is an Indian Tamil film actor. After making his cinematic debut in the 1990 film "En Kadhal Kanmani", he acted in a series of sma

In [35]:
hotpotqa_data = HotpotQA()
hotpotqa_data.load()

In [38]:
from datasets import load_dataset
dataset = load_dataset('hotpot_qa','distractor', 'validation', trust_remote_code = True)['validation']
for sample in dataset:
    if sample['id'] == '5abb03325542992ccd8e7eae':
        print(sample)
        break


{'id': '5abb03325542992ccd8e7eae', 'question': 'The magazine that nominated George Rainsford for their Best Actor award in 2017 comes out every week on what day of the week?', 'answer': 'Tuesday', 'type': 'bridge', 'level': 'hard', 'supporting_facts': {'title': ['George Rainsford (actor)', 'TV Choice', 'TV Choice'], 'sent_id': [0, 0, 1]}, 'context': {'title': ['Shabbir Jan', 'George Rainsford (actor)', 'List of awards and nominations received by Vikram', 'Ethan Hardy', 'NFL regular season', 'TV Choice', 'Avtaar', 'Jean-Louis Trintignant', 'Peter Finch', 'Caleb Knight'], 'sentences': [['Shabbir Jan is a Pakistani television actor who has appeared in many drama serials, such as Wafa, "Makan", "Andata", Survival of a Woman, "Zindagi Dhoop Tum Ghana Saya", "Umrao Jaan", "Jangloos" and "Shab e Gham" and individual play and serials.', ' He won three times PTV best actor award.', ' He has been a nominee once for the Best Actor award in the Lux Style Award, 2002.', ' He has worked television f

# Sparsity 

In [92]:
hotpot_reconstructions = pickle.load(open(f'processed_results/reconstructions/reconstruction_results_hotpotqa.pkl', 'rb'))

In [96]:
sparsity_dict = {range(8,15): [], range(16,31): [], range(32,63): [], range(64,127): []}
for explicand in hotpot_reconstructions:
    n = explicand['explicand']['n']
    num_nonzero = len(explicand[('qsft_soft_0', 8)])
    sparsity_ratio = num_nonzero / 2**n
    for group in sparsity_dict:
        if n in group:
            sparsity_dict[group].append(sparsity_ratio)
sparsity_dict


{(0, 0, 0, 0, 0, 0, 0, 0): -11.475307761460499, (0, 0, 0, 0, 0, 0, 0, 1): -0.0018415212089166744, (0, 0, 0, 0, 0, 0, 1, 0): 0.0019088703065790469, (0, 0, 0, 0, 0, 0, 1, 1): 0.005708905051506008, (0, 0, 0, 0, 0, 1, 0, 0): 0.007602650706758141, (0, 0, 0, 0, 0, 1, 0, 1): 0.0022977842836553464, (0, 0, 0, 0, 0, 1, 1, 0): -0.0005969276080577401, (0, 0, 0, 0, 0, 1, 1, 1): -0.005159633534276509, (0, 0, 0, 0, 1, 0, 0, 0): 0.005100921976918471, (0, 0, 0, 0, 1, 0, 0, 1): 3.473973447398748e-05, (0, 0, 0, 0, 1, 0, 1, 0): -0.0003459180679783458, (0, 0, 0, 0, 1, 0, 1, 1): -0.0009769042135303607, (0, 0, 0, 0, 1, 1, 0, 0): -0.001817127382309991, (0, 0, 0, 0, 1, 1, 0, 1): -0.0018255334380228305, (0, 0, 0, 0, 1, 1, 1, 0): -0.003258978185840533, (0, 0, 0, 0, 1, 1, 1, 1): 0.00047506382725259755, (0, 0, 0, 1, 0, 0, 0, 0): 0.008793765324298874, (0, 0, 0, 1, 0, 0, 0, 1): 0.0059440816312417155, (0, 0, 0, 1, 0, 0, 1, 0): 0.0011550008111953503, (0, 0, 0, 1, 0, 0, 1, 1): 0.0016462051880807849, (0, 0, 0, 1, 0, 1, 