# Counterfactual Search

0. Networks Input
1. Black-box function
2. Counterfactual Search: DS - Dataset Search
3. Counterfactual Search: OFS - Oblivious Forward Search
3. Counterfactual Search: DFS - Data-driven Forward Search
4. Counterfactual Search: OBS - Oblivious Backward Search
4. Counterfactual Search: DBS - Data-driven Backward Search

In [None]:
import os
import numpy as np
import csv
import matplotlib
import matplotlib.pyplot as plt

## 0. Networks Input


In [None]:
#import graph
# Load every adjacency matrix: files in <path>/td/ get label 0 (TD), files in
# <path>/asd/ get label 1 (ASD). data[name] = (label, adjacency matrix), where
# name is the file name up to its first '.'.
data = {}
path = '../data/AUTISM'
path1 = path+'/asd/'
path2 = path+'/td/'
paths = [path2,path1]
for label, folder in enumerate(paths):
    for filename in os.listdir(folder):
        if 'DS_Store' in filename:
            continue  # skip macOS Finder metadata files
        with open(folder+filename, 'r') as f:
            if filename.endswith('csv'):
                # .csv: comma-separated values, possibly written as floats ("1.0")
                l = [[int(float(num)) for num in line.split(',')] for line in f if line.strip()]
            else:
                # other files (e.g. .txt): whitespace-separated integers
                l = [[int(num) for num in line.split()] for line in f if line.strip()]
        name = filename.split('.')[0]
        data[name] = (label,np.array(l))

## 1. Black-box function

In [None]:
def sub_graph(g,v_sub):
    '''Return the sub-graph of adjacency matrix 'g' induced by the nodes in 'v_sub'.

    Entry (r, c) of the result is g[v_sub[r]][v_sub[c]]. The result is a new
    array, so 'g' is never modified through it.
    '''
    nodes = list(v_sub)
    # Fancy indexing with np.ix_ always produces a fresh copy.
    return np.asarray(g)[np.ix_(nodes, nodes)]

def feature_extraction(g):
    ''' Compute the two features used by the classification rule for graph 'g'.

    Returns (a, b): half the sum of the adjacency sub-matrix induced by the
    node set 'td_asd' and by the node set 'asd_td' respectively (for a
    symmetric 0/1 matrix this is the number of edges in each sub-graph).
    '''
    # Node sets of the two sub-graphs
    td_asd = [65, 70, 99, 80, 69, 6, 7, 8, 9, 13, 77, 45, 16, 81, 78, 92, 56, 57, 60, 93, 63]
    asd_td = [0, 36, 37, 38, 81, 40, 41, 74, 75, 76, 70, 72, 114, 20, 21, 73, 90, 28, 29]

    a, b = (sum(sum(row) for row in sub_graph(g, nodes)) / 2
            for nodes in (td_asd, asd_td))
    return a,b

### 1.1 Visualize the results

In [None]:
# Store the oracle features next to each graph: data[name] -> (label, graph, (a, b))
for name, entry in data.items():
    label_i, adj = entry[0], entry[1]
    data[name] = (label_i, adj, feature_extraction(adj))

In [None]:
def oracle(g):
    '''Black-box classifier: return 1 ('ASD') or 0 ('TD') for graph 'g'.

    Applies a fixed linear rule to the features (a, b) returned by
    feature_extraction: the answer is 1 when bk + w_1*a + w_2*b > 0.
    '''
    a, b = feature_extraction(g)
    w_1, w_2 = -0.181742414867891, 0.04327200353999672
    bk = 3.2844839747590915
    score = bk + w_1*a + w_2*b
    return 1 if score > 0 else 0

In [None]:
# Classify every graph and record (prediction, true label, a, b) for each one.
info = []
results = {'pred':[],'true':[]}
for name, entry in data.items():
    y, g_o = entry[0], entry[1]   # true label, original graph
    a, b = entry[2]
    y_hat = oracle(g_o)
    results['pred'].append(y_hat)
    results['true'].append(y)
    info.append((y_hat,y,a,b))

In [None]:
# Oracle rule parameters, per-class feature coordinates, and the decision
# boundary b = (bk + w_1*a) / (-w_2) evaluated at every graph's a.
w_1 = -0.181742414867891
w_2 = 0.04327200353999672
bk = 3.2844839747590915
l_0_a_tot = [a for (_, y, a, _) in info if y == 0]
l_0_b_tot = [b for (_, y, _, b) in info if y == 0]
l_1_a_tot = [a for (_, y, a, _) in info if y == 1]
l_1_b_tot = [b for (_, y, _, b) in info if y == 1]
l_a_tot = [a for (_, _, a, _) in info]
l_b_tot = [(bk + (w_1*a)) / (-w_2) for (_, _, a, _) in info]

In [None]:
# Scatter plot of the two features coloured by true label, plus the oracle's
# decision boundary as a green line.
plt.figure(figsize=(15,9))
plt.xlim(0.0, 60.0)
plt.ylim(0.0, 80.0)

plt.plot(l_0_a_tot, l_0_b_tot, 'bo', label="TD")
plt.plot(l_1_a_tot, l_1_b_tot, 'ro', label="ASD")
plt.plot(l_a_tot, l_b_tot, 'g-')

plt.title('Scatter Plot')
plt.xlabel('TD_ASD')
plt.ylabel('ASD_TD')
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import confusion_matrix,accuracy_score
# Confusion matrix of the oracle's predictions against the true labels.
# labels=[0, 1] always yields a 2x2 matrix, so the 4-way unpacking below
# still works when only one class occurs in 'true' and 'pred'.
tn, fp, fn, tp = confusion_matrix(results['true'],results['pred'],labels=[0,1]).ravel()
print('Results:\n- {} TP;\n- {} TN;\n- {} FP;\n- {} FN.'.format(tp,tn,fp,fn))
accuracy = accuracy_score(results['true'],results['pred'])
print('Accuracy = {}'.format(accuracy))

## 2. Counterfactual Search: DS - Dataset Search

In [None]:
def tot_edges(g):
    '''Return the number of edges of an undirected graph given as an adjacency matrix.

    Each undirected edge appears twice in a symmetric matrix, hence the
    halving of the total sum.
    '''
    return sum(map(sum, g))/2

def edit_distance(g_1,g_2):
    '''Return the edit distance between two graphs over the same node set.

    This is tot_edges of the element-wise absolute difference of the two
    adjacency matrices: each off-diagonal disagreement counts 1, each
    diagonal disagreement 1/2.
    NOTE(review): assumes signed numeric matrices; unsigned or boolean dtypes
    would not subtract as expected — confirm with callers.
    '''
    disagreement = abs(g_1 - g_2)
    return tot_edges(disagreement)

In [None]:
def dataset_distance(data,g_name):
    ''' Rank the dataset graphs predicted in the counterfactual class by edit distance.

    The counterfactual class y_bar is the opposite of the oracle's prediction
    for data[g_name]. Returns (pairs, n_calls): pairs lists (name, edit
    distance to data[g_name]) for every other graph the oracle assigns to
    y_bar, sorted by increasing distance; n_calls counts oracle invocations.
    '''
    g = data[g_name][1] # original graph
    y_bar = abs(1 - oracle(g))
    n_calls = 1
    candidates = []
    for name, entry in data.items():
        n_calls += 1
        other = entry[1]
        # the oracle is queried for every graph, including g_name itself
        if oracle(other) == y_bar and name != g_name:
            candidates.append((name, edit_distance(g, other)))
    candidates.sort(key=lambda pair: pair[1])
    return candidates, n_calls

In [None]:
%%time
# Dataset Search: for every graph keep the closest dataset graph predicted in
# the opposite class. results_dataset[name] = (edit distance, prediction,
# oracle calls); dataset_d[name] = name of that counterfactual graph.
results_dataset = {}
dataset_d = {}
for k, (gname, v) in enumerate(data.items()):
    print(k, end=' - ')
    y_hat = oracle(v[1])
    d_bar,l = dataset_distance(data,gname)
    # NOTE(review): raises IndexError if no graph is predicted in the other class.
    name_c, d = d_bar[0]
    results_dataset[gname] = (d,y_hat,l)
    dataset_d[gname] = name_c


In [None]:
edit_list = [dist for dist, _, _ in results_dataset.values()]  # DS edit distances

In [None]:
# Five-number summary (min, quartiles, max) of the DS edit distances.
print(min(edit_list))
for q in (0.25, 0.50, 0.75):
    print(np.quantile(edit_list, q))
print(max(edit_list))

## 3. Counterfactual Search: OFS - Oblivious Forward Search

In [None]:
import random
def bernoulli(p):
    ''' Single Bernoulli trial: return True with probability p, otherwise False.

    The search routines use the outcome to choose between removing and
    adding an edge.
    '''
    return bool(random.random() < p)

        

def forward_greedy(g_o,y_bar,k=5,lambda_g=2000,p_0=0.5):
    '''Oblivious Forward Search (OFS) for a counterfactual of graph 'g_o'.

    Starting from a copy of g_o, each step flips k undirected node pairs
    taken from randomly shuffled queues: with probability p_0 an existing
    edge is removed, otherwise a missing one is added. When the chosen queue
    is empty, the other operation is used instead. Flipped pairs go to the
    back of the opposite queue. After each step the oracle is queried once.

    Returns (d, g_c, l): the edit distance between g_o and the counterfactual
    g_c and the number of oracle calls l. When no counterfactual is found
    within lambda_g calls (or the graph has no pair to flip), it returns
    (0, g_o, l).
    '''
    from collections import deque

    dim = len(g_o)
    g_c = np.copy(g_o)  # candidate counterfactual

    # Upper-triangle pairs (i < j) split by presence, each list in random order.
    absent, present = [], []
    for i in range(dim):
        for j in range(i + 1, dim):
            (present if g_c[i][j] > 0.5 else absent).append((i, j))
    random.shuffle(absent)
    random.shuffle(present)
    # deques give O(1) pops from the front
    g_add, g_rem = deque(absent), deque(present)

    l = 0
    while l < lambda_g:
        for _ in range(k):
            remove = bernoulli(p_0)
            if not (g_rem if remove else g_add):
                remove = not remove  # chosen queue exhausted: use the other one
            source, target, value = (g_rem, g_add, 0) if remove else (g_add, g_rem, 1)
            if not source:
                return 0,g_o,l  # no node pair can be flipped at all
            i, j = source.popleft()
            g_c[i][j] = value
            g_c[j][i] = value
            target.append((i, j))  # the pair may be flipped back later
        l += 1
        if oracle(g_c) == y_bar:
            return edit_distance(g_o,g_c),g_c,l
    return 0,g_o,l

In [None]:
%%time
# Run OFS 5 times on every graph; results_forward[q][name] = [d, g_c, oracle calls].
# NOTE: lambda_g is not passed on, so forward_greedy uses its own default.
lambda_g = 2000
results_forward = {}
for q in range(5):
    print(q)
    results_forward_i = {}
    for k, (gname, v) in enumerate(data.items()):
        print(k, end=' - ')
        g = v[1]  # original graph
        y_hat = oracle(g)
        d_final, g_c_final, lambda_final = forward_greedy(g, abs(1 - y_hat))
        results_forward_i[gname] = [d_final, g_c_final, lambda_final]
    results_forward[q] = results_forward_i

In [None]:
results_forward = {run: results_forward[run] for run in range(5)}  # keep runs 0..4 only

In [None]:
# Per-graph OFS statistics over qq runs; a run with distance 0 is a failure.
qq= 5
edit_list, edit_std = [], []
lambda_list, lambda_std = [], []
not_found = []
for gname in data:
    found = [results_forward[q][gname] for q in range(qq)
             if results_forward[q][gname][0] != 0]
    not_found.append(qq - len(found))
    if not found:
        continue
    edit_i = [run[0] for run in found]
    lambda_i = [run[2] for run in found]
    edit_list.append(sum(edit_i)/len(edit_i))
    lambda_list.append(sum(lambda_i)/len(lambda_i))
    lambda_std.append(round(np.std(lambda_i),1))
    edit_std.append(round(np.std(edit_i),1))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph average OFS edit distance.
print('Edit Distance: Average')
print('& {} & {} & {} & {} & {}'.format(
    min(edit_list),
    *(np.quantile(edit_list, q) for q in (0.25, 0.50, 0.75)),
    max(edit_list)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph average OFS oracle calls.
print('Lambda: Average')
print('& {} & {} & {} & {} & {}'.format(
    min(lambda_list),
    *(np.quantile(lambda_list, q) for q in (0.25, 0.50, 0.75)),
    max(lambda_list)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph OFS edit-distance std.
print('Edit Distance: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(edit_std),
    *(np.quantile(edit_std, q) for q in (0.25, 0.50, 0.75)),
    max(edit_std)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph OFS oracle-call std.
print('Lambda: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(lambda_std),
    *(np.quantile(lambda_std, q) for q in (0.25, 0.50, 0.75)),
    max(lambda_std)))

In [None]:
# Failed OFS runs: per graph, in total, and averaged over the qq repetitions.
print('Not Found = ', not_found)
print('Total Not Found = ', sum(not_found))
print('Avg Not Found = ', sum(not_found) / qq)

## 3. Counterfactual Search: DFS - Data-driven Forward Search

In [None]:
# Edge frequencies per predicted class: g_0[i][j] (resp. g_1[i][j]) counts the
# graphs predicted 0 (resp. 1) that contain edge (i, j); g_01 / g_10 are the
# two signed differences.
dim_g = 116
g_0 = np.zeros((dim_g,dim_g))
g_1 = np.zeros((dim_g,dim_g))
for k, v in data.items():
    g = v[1]
    y_hat = oracle(g)
    if y_hat == 0:
        g_0 = g_0 + g
    else:
        g_1 = g_1 + g
g_01 = g_0 - g_1
g_10 = g_1 - g_0

In [None]:
# Min-max rescale each difference matrix into [1, 2], then normalise it so it
# sums to 1. prob_initial[c] weights edges that are more frequent among graphs
# predicted c. uniform_initial_prop is the flat alternative.
min_01, max_01 = g_01.min(), g_01.max()
g01 = 1 + (g_01 - min_01) / (max_01 - min_01)
min_10, max_10 = g_10.min(), g_10.max()
g10 = 1 + (g_10 - min_10) / (max_10 - min_10)
prob_initial = {c: m / m.sum() for c, m in ((0, g01), (1, g10))}
g00 = np.ones((dim_g,dim_g))
uniform_initial_prop = {c: g00 / g00.sum() for c in (0, 1)}

In [None]:
def DFS_select(g,edges,y_bar,ki,edges_prob,p_0=0.5):
    '''Flip 'ki' data-driven random node pairs of 'g' in place (one DFS step).

    Candidates are all pairs (i, j) whose undirected edge is not yet listed
    in 'edges' (the pairs already modified; (i, j) and (j, i) count as the
    same edge). A missing edge is added with probability proportional to
    edges_prob[y_bar][i][j]; an existing edge is removed with probability
    proportional to edges_prob[1 - y_bar][i][j]. Each of the ki draws adds
    with probability p_0 and removes otherwise. When the chosen pool is
    empty, the other operation is used. Draws are with replacement. If no
    candidate exists at all, nothing is flipped.

    Returns (g, edges): the modified matrix (the same object as the input)
    and 'edges' extended with every flipped pair.
    '''
    # O(1) set lookups instead of scanning the list for each of the dim*dim pairs;
    # the mirrored pair is excluded too, otherwise it could undo an earlier flip.
    done = set(edges)
    done.update((j, i) for i, j in edges)

    cand_add, weight_add = [], []
    cand_rem, weight_rem = [], []
    dim = len(g)
    for i in range(dim):
        for j in range(dim):
            if (i, j) in done:
                continue
            if g[i][j] > 0:
                cand_rem.append((i, j))
                weight_rem.append(edges_prob[1-y_bar][i][j])
            else:
                cand_add.append((i, j))
                weight_add.append(edges_prob[y_bar][i][j])
    prob_add = np.asarray(weight_add, dtype=float)
    prob_rem = np.asarray(weight_rem, dtype=float)
    if cand_add:
        prob_add = prob_add / prob_add.sum()
    if cand_rem:
        prob_rem = prob_rem / prob_rem.sum()

    for _ in range(ki):
        add = random.random() < p_0  # Bernoulli(p_0) trial
        if add and not cand_add:
            add = False
        elif not add and not cand_rem:
            add = True
        pool, probs, value = (cand_add, prob_add, 1) if add else (cand_rem, prob_rem, 0)
        if not pool:
            break  # nothing left to flip; never record a stale pair
        n = np.random.choice(len(pool), size=1, p=probs)[0]
        i, j = pool[n]
        g[i][j] = value
        g[j][i] = value
        edges.append((i, j))
    return g,edges

In [None]:
def DFS(g,y_bar,edges_prob,k=10,l_max=2000):
    '''Data-driven Forward Search (DFS) for a counterfactual of graph 'g'.

    Repeatedly applies DFS_select (k flips per step) to a copy of 'g',
    carrying the list of already flipped pairs between steps. The oracle is
    queried after each step, for at most l_max calls.

    Returns (d, gc, li): the edit distance between g and the counterfactual
    gc and the number of oracle calls li. When no counterfactual is found,
    it returns (0, gc, li) with gc the last candidate examined.
    '''
    gc = np.copy(g)
    edges = []
    li = 0
    while li < l_max:
        gc, edges = DFS_select(gc, edges, y_bar, k, edges_prob)
        li += 1
        if oracle(gc) == y_bar:
            # distance from the input graph 'g' and the local call counter 'li'
            return edit_distance(g,gc),gc,li
    return 0,gc,li

In [None]:
%%time
# DFS with the min-max edge scores, 3 runs per graph;
# results_prob[q][name] = [d, counterfactual, oracle calls].
lambda_g = 2000
results_prob = {}
for q in range(3):
    print(q)
    results_prob_i = {}
    for k, (gname, v) in enumerate(data.items()):
        print(k, end=' ')
        g = v[1] # Original graph
        y_hat = oracle(g)
        y_bar = abs(1 - y_hat)
        d_final, g_c_final, lambda_final = DFS(g, y_bar, prob_initial)
        results_prob_i[gname] = [d_final, g_c_final, lambda_final]
        print('->({})'.format(d_final), end=' - ')
    results_prob[q] = results_prob_i

### DFS - other scores for edges

In [None]:
# Edge counts per predicted class with add-one smoothing (every entry starts at 1).
dim_g = 116
g_0 = np.ones((dim_g,dim_g))
g_1 = np.ones((dim_g,dim_g))
for k, v in data.items():
    g = v[1]
    if oracle(g) == 0:
        g_0 = g_0 + g
    else:
        g_1 = g_1 + g

In [None]:
prob_initial = {c: m / m.sum() for c, m in ((0, g_0), (1, g_1))}  # per-class edge distributions

In [None]:
## Edges probabilities
# Flattened (row-major) per-class edge distributions: entry i*dim_g + j is the
# normalised smoothed count of edge (i, j).
nodes_0_sum = g_0.sum()
edges_0 = g_0.ravel() / nodes_0_sum
nodes_1_sum = g_1.sum()
edges_1 = g_1.ravel() / nodes_1_sum
edges_prob = {0: edges_0, 1: edges_1}

In [None]:
# Identical redefinition of the Bernoulli helper from the OFS section.
def bernoulli(p):
    ''' Return True with probability p (one Bernoulli trial), False otherwise.
    '''
    draw = random.random()
    return bool(draw < p)


def forward_probabilistic(g_o,y_bar,edges_prob,lambda_g=2000,p_0=0.5):
    '''Data-driven forward search flipping one sampled node pair per oracle call.

    edges_prob[c] is a flattened (row-major) probability vector over the
    dim*dim node pairs of g_o, where dim = len(g_o). At each step, with
    probability p_0 a pair sampled from edges_prob[1 - y_bar] is set to 0;
    otherwise a pair sampled from edges_prob[y_bar] is set to 1. The oracle
    is then queried.

    Returns (d, g_c, l) with d the edit distance to g_o and l the number of
    oracle calls, or (0, g_o, l) if no counterfactual is found in lambda_g calls.
    '''
    dim = len(g_o)
    # Flat index n <-> pair (n // dim, n % dim). Uses the graph's own size,
    # not a global constant.
    n_pairs = dim * dim
    g_c = np.copy(g_o)  # candidate counterfactual

    l = 0
    while l < lambda_g:
        if bernoulli(p_0):
            n = np.random.choice(n_pairs, size=1, p=edges_prob[abs(y_bar-1)])[0]
            value = 0  # remove
        else:
            n = np.random.choice(n_pairs, size=1, p=edges_prob[y_bar])[0]
            value = 1  # add
        i, j = divmod(int(n), dim)
        g_c[i][j] = value
        g_c[j][i] = value
        l += 1
        if oracle(g_c) == y_bar:
            return edit_distance(g_o,g_c),g_c,l
    return 0,g_o,l

In [None]:
%%time
# forward_probabilistic with the smoothed-count distributions, 5 runs per graph;
# results_prob_0[q][name] = [d, counterfactual, oracle calls].
lambda_g = 2000
results_prob_0 = {}
for q in range(5):
    print(q)
    results_prob_i = {}
    for k, (gname, v) in enumerate(data.items()):
        print(k, end=' - ')
        g = v[1] # Original graph
        y_hat = oracle(g)
        y_bar = abs(1 - y_hat)
        d_final, g_c_final, lambda_final = forward_probabilistic(g, y_bar, edges_prob)
        results_prob_i[gname] = [d_final, g_c_final, lambda_final]
    results_prob_0[q] = results_prob_i

## 4. Counterfactual Search: OBS - Oblivious Backward Search

In [None]:
def get_change_list(g1,g2):
    '''List the node pairs (i, j) with i <= j where |g1 - g2| equals exactly 1.

    Pairs come in row-major order of the upper triangle (diagonal included),
    as tuples of Python ints.
    '''
    upper = np.triu(abs(g1 - g2))
    return [(int(i), int(j)) for i, j in np.argwhere(upper == 1)]

In [None]:
def bb(g,gc1,y_bar,k=5,l_max=2000):
    '''Oblivious Backward Search (OBS): move a counterfactual back towards 'g'.

    Starting from the counterfactual gc1, it reverts batches of up to k of
    the pairs where gc1 still differs from g, taken in random order. If the
    oracle still returns y_bar, the batch is kept and k grows by one.
    Otherwise the batch is discarded and k shrinks by one, with the pairs
    queued again; when k is already 1, the pair is dropped instead. The
    search stops after l_max oracle calls, when no pair is left, or once the
    edit distance is at most 1.

    Returns (gc, d, li, info): the reduced counterfactual, its edit distance
    to g, the number of oracle calls, and a per-call log of
    (prediction, distance, call index, batch size).
    '''
    info = []  # local log: each call returns its own list instead of a shared global
    gc = np.copy(gc1)
    edges = get_change_list(g,gc)
    d = edit_distance(g,gc)
    random.shuffle(edges)
    li = 0
    while li < l_max and edges and d > 1:
        ki = min(k,len(edges))
        batch, edges = edges[:ki], edges[ki:]
        gci = np.copy(gc)
        for i,j in batch:
            flipped = 0 if gci[i][j] > 0.5 else 1
            gci[i][j] = flipped
            gci[j][i] = flipped
        r = oracle(gci)
        li += 1
        if r == y_bar:
            gc = gci
            d = edit_distance(g,gc)
            k += 1
        elif k > 1:
            k -= 1
            edges = edges + batch  # retry these pairs later in smaller batches
        info.append((r,d,li,ki))
    return gc,edit_distance(g,gc),li,info

In [None]:
%%time
# Run OBS 5 times per graph, starting from its DS counterfactual (dataset_d).
# r_bb[q][name] = [final distance, oracle calls, prediction, initial distance,
#                  call log, reduced counterfactual]
# NOTE: k is not passed to bb, which therefore uses its default k=5.
lambda_g = 2000
k = 10
info_k_dist = {}
max_m = len(data)
r_bb = {}
for q in range(5):
    print(q)
    r_bb_i = {}
    for m, (oname, v) in enumerate(data.items(), start=1):
        print('{}/{}'.format(m,max_m), end=' - ')
        g = v[1]
        y_hat = oracle(g)
        gc = data[dataset_d[oname]][1]
        d_initial = edit_distance(g,gc)
        gc2, d, l, info = bb(g,gc,abs(1-y_hat))
        d_final = edit_distance(g,gc2)
        r_bb_i[oname] = [d_final,l,y_hat,d_initial,info,gc2]
    r_bb[q] = r_bb_i

In [None]:
# Per-graph spread over the OBS runs; each r_bb[q][name] is
# [d_final, l, y_hat, d_initial, info, gc2]. A final distance < 1 marks a failure.
qq = len(r_bb)
ed, la, not_found = [], [], []
for name in r_bb[0]:
    runs = [r_bb[q][name] for q in range(qq)]
    ed.append(np.array([run[0] for run in runs]))
    la.append(np.array([run[1] for run in runs]))
    not_found.append(np.array([name for run in runs if run[0] < 1]))
ed_avstd = [round(np.std(el),2) for el in ed]
la_avstd = [round(np.std(el),2) for el in la]

In [None]:
# Per-graph averages over the OBS runs of d_final (field 0) and l (field 1)
# from r_bb[q][name] = [d_final, l, y_hat, d_initial, info, gc2].
qq = len(r_bb)
ed_avg = [sum(r_bb[q][name][0] for q in range(qq))/qq for name in r_bb[0]]
la_avg = [sum(r_bb[q][name][1] for q in range(qq))/qq for name in r_bb[0]]

In [None]:
# LaTeX row: min, quartiles and max of the per-graph average OBS edit distance.
print('Edit Distance: Average')
print('& {} & {} & {} & {} & {}'.format(
    min(ed_avg),
    *(np.quantile(ed_avg, q) for q in (0.25, 0.50, 0.75)),
    max(ed_avg)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph average OBS oracle calls.
print('Lambda: Avg')
print('& {} & {} & {} & {} & {}'.format(
    min(la_avg),
    *(np.quantile(la_avg, q) for q in (0.25, 0.50, 0.75)),
    max(la_avg)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph OBS edit-distance std.
print('Edit Distance: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(ed_avstd),
    *(np.quantile(ed_avstd, q) for q in (0.25, 0.50, 0.75)),
    max(ed_avstd)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph OBS oracle-call std.
print('Lambda: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(la_avstd),
    *(np.quantile(la_avstd, q) for q in (0.25, 0.50, 0.75)),
    max(la_avstd)))

## 4. Counterfactual Search: DBS - Data-driven Backward Search

In [None]:
# Create the two matrices counting, for each edge (i, j), how many graphs
# predicted 0 (g_0) or 1 (g_1) contain it; g_01 / g_10 are the signed differences.
dim_g = 116
g_0 = np.zeros((dim_g,dim_g))
g_1 = np.zeros((dim_g,dim_g))
for k, v in data.items():
    g = v[1]
    y_hat = oracle(g)
    if y_hat == 0:
        g_0 = g_0 + g
    else:
        g_1 = g_1 + g
g_01 = g_0 - g_1
g_10 = g_1 - g_0

In [None]:
# Rescale g_01 into [1, 2]: 1 + min-max normalisation.
min_01, max_01 = g_01.min(), g_01.max()
g01 = 1 + (g_01 - min_01) / (max_01 - min_01)

In [None]:
# Rescale g_10 into [1, 2]: 1 + min-max normalisation.
min_10, max_10 = g_10.min(), g_10.max()
g10 = 1 + (g_10 - min_10) / (max_10 - min_10)

In [None]:
prob_initial = {c: m / m.sum() for c, m in ((0, g01), (1, g10))}  # normalised edge scores per class

In [None]:
tuple(prob_initial[c].min() for c in (0, 1))  # smallest edge probability per class

In [None]:
# Uniform edge scores (one separate array per class), an alternative to prob_initial.
g00 = np.ones((dim_g,dim_g))
uniform_initial_prop = {c: g00 / g00.sum() for c in (0, 1)}

In [None]:
def get_prob_edges(gc,edges,y_bar,ki,edges_prob,p_0=0.5):
    '''Flip 'ki' pairs sampled from the candidate pairs 'edges' on a copy of 'gc'.

    Pairs of 'edges' present in gc are removal candidates weighted by
    edges_prob[1 - y_bar][i][j]; absent ones are addition candidates weighted
    by edges_prob[y_bar][i][j]. Each of the ki draws (with replacement) adds
    with probability p_0 and removes otherwise. When the chosen pool is
    empty, the other operation is used.

    Returns (gci, new_edges): the modified copy, and the candidate pairs with
    the pairs actually flipped listed first (in order of first flip).
    '''
    gci = np.copy(gc)
    edges_add, weight_add = [], []
    edges_rem, weight_rem = [], []
    for i, j in edges:
        if gc[i][j] > 0:
            edges_rem.append((i,j))
            weight_rem.append(edges_prob[1-y_bar][i][j])
        else:
            edges_add.append((i,j))
            weight_add.append(edges_prob[y_bar][i][j])
    prob_add = np.asarray(weight_add, dtype=float)
    prob_rem = np.asarray(weight_rem, dtype=float)
    if edges_add:
        prob_add = prob_add / prob_add.sum()
    if edges_rem:
        prob_rem = prob_rem / prob_rem.sum()

    flipped = []
    for _ in range(ki):
        add = random.random() < p_0  # Bernoulli(p_0) trial
        if add and not edges_add:
            add = False
        elif not add and not edges_rem:
            add = True
        pool, probs, value = (edges_add, prob_add, 1) if add else (edges_rem, prob_rem, 0)
        if not pool:
            break  # no candidate at all; never record a stale pair
        n = np.random.choice(len(pool), size=1, p=probs)[0]
        i, j = pool[n]
        gci[i][j] = value
        gci[j][i] = value
        if (i, j) not in flipped:
            flipped.append((i, j))
    # Tried pairs first, so callers can identify them via new_edges[0].
    rest = [e for e in edges_add + edges_rem if e not in flipped]
    new_edges = flipped + rest
    return gci,new_edges

In [None]:
def bb_prob_2(g,gc1,y_bar,edges_prob,k=5,l_max=2000):
    '''Data-driven Backward Search (DBS): move a counterfactual back towards 'g'.

    Each step lets get_prob_edges flip k sampled pairs among those where the
    current counterfactual still differs from g. If the oracle still returns
    y_bar and something changed, the step is kept and k grows by one.
    Otherwise k shrinks by one; when k is already 1, the pair that was just
    tried is dropped from the candidates. The search stops after l_max
    oracle calls, when no pair is left, or once the edit distance is at
    most 1.

    Returns (gc, d, li, info): the reduced counterfactual, its edit distance
    to g, the number of oracle calls, and a per-call log of
    (prediction, distance, call index, batch size).
    '''
    info = []
    gc = np.copy(gc1)
    edges = get_change_list(g,gc)
    d = edit_distance(g,gc)
    li = 0
    while li < l_max and edges and d > 1:
        ki = min(k,len(edges))
        gci,new_edges = get_prob_edges(gc,edges,y_bar,k,edges_prob)
        r = oracle(gci)
        li += 1
        if r == y_bar and edit_distance(gci,gc) > 0:
            gc = np.copy(gci)
            d = edit_distance(g,gc)
            edges = get_change_list(g,gc)
            info.append((r,d,li,ki))
            k += 1
        else:
            info.append((r,d,li,ki))
            if k > 1:
                k -= 1
            else:
                # Drop the pair actually tried (the one where gci differs from gc),
                # not an arbitrary first candidate; fall back to new_edges[0] if
                # nothing was flipped, so the loop still makes progress.
                tried = [e for e in get_change_list(gc, gci) if e in edges]
                edges.remove(tried[0] if tried else new_edges[0])
    return gc,edit_distance(g,gc),li,info

In [None]:
%%time
# Run DBS 5 times per graph, starting from its DS counterfactual (dataset_d).
# r_bb_prop[q][name] = [final distance, oracle calls, prediction, initial
#                       distance, call log, reduced counterfactual]
# NOTE: k is not passed to bb_prob_2, which therefore uses its default k=5.
lambda_g = 2000
k = 2
info_k_dist = {}
max_m = len(data)
r_bb_prop = {}
for q in range(5):
    print(q)
    r_bb_i = {}
    for m, (oname, v) in enumerate(data.items(), start=1):
        print('{}/{}'.format(m,max_m), end=' - ')
        g = v[1]
        y_hat = oracle(g)
        gc = data[dataset_d[oname]][1]
        d_initial = edit_distance(g,gc)
        gc2, d, l, info = bb_prob_2(g,gc,abs(1-y_hat),prob_initial)
        d_final = edit_distance(g,gc2)
        r_bb_i[oname] = [d_final,l,y_hat,d_initial,info,gc2]
    r_bb_prop[q] = r_bb_i

In [None]:
# Per-graph spread over the DBS runs; each r_bb_prop[q][name] is
# [d_final, l, y_hat, d_initial, info, gc2]. A final distance < 1 marks a failure.
qq = len(r_bb_prop)
ed, la, not_found = [], [], []
for name in r_bb_prop[0]:
    runs = [r_bb_prop[q][name] for q in range(qq)]
    ed.append(np.array([run[0] for run in runs]))
    la.append(np.array([run[1] for run in runs]))
    not_found.append(np.array([name for run in runs if run[0] < 1]))
ed_avstd = [round(np.std(el),1) for el in ed]
la_avstd = [round(np.std(el),1) for el in la]

In [None]:
# Per-graph averages over the DBS runs of d_final (field 0) and l (field 1)
# from r_bb_prop[q][name] = [d_final, l, y_hat, d_initial, info, gc2].
qq = len(r_bb_prop)
ed_avg = [sum(r_bb_prop[q][name][0] for q in range(qq))/qq for name in r_bb_prop[0]]
la_avg = [sum(r_bb_prop[q][name][1] for q in range(qq))/qq for name in r_bb_prop[0]]

In [None]:
# LaTeX row: 10/25/50/75/90% quantiles of the per-graph average DBS edit
# distance, rounded to one decimal.
print('Edit Distance: Average')
print('& {} & {} & {} & {} & {}'.format(
    *(round(np.quantile(ed_avg, q), 1) for q in (0.10, 0.25, 0.50, 0.75, 0.90))))

In [None]:
# LaTeX row: 10/25/50/75/90% quantiles of the per-graph average DBS oracle
# calls, rounded to one decimal.
print('Lambda: Avg')
print('& {} & {} & {} & {} & {}'.format(
    *(round(np.quantile(la_avg, q), 1) for q in (0.10, 0.25, 0.50, 0.75, 0.90))))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph average DBS oracle calls.
print('Lambda: Avg')
print('& {} & {} & {} & {} & {}'.format(
    min(la_avg),
    *(np.quantile(la_avg, q) for q in (0.25, 0.50, 0.75)),
    max(la_avg)))


In [None]:
# LaTeX row: min, quartiles and max of the per-graph DBS edit-distance std.
print('Edit Distance: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(ed_avstd),
    *(np.quantile(ed_avstd, q) for q in (0.25, 0.50, 0.75)),
    max(ed_avstd)))

In [None]:
# LaTeX row: min, quartiles and max of the per-graph DBS oracle-call std.
print('Lambda: Standard Deviation')
print('& {} & {} & {} & {} & {}'.format(
    min(la_avstd),
    *(np.quantile(la_avstd, q) for q in (0.25, 0.50, 0.75)),
    max(la_avstd)))