In [2]:
import pandas as pd
import numpy as np

#load nodes df
nodes_df = pd.read_csv('prepped_models/cifar10_testing/node_ranks.csv')

#make wide version
nodes_wide_df = nodes_df.pivot(index = 'node_num',columns='class', values='rank_score')

def get_col(node_num, df = nodes_df, idx = 'node_num', col = 'layer'):
    return df.loc[(df[idx] == node_num) & (df['class'] == df['class'].unique()[0]), col].item()

nodes_wide_df.reset_index(inplace=True)
nodes_wide_df['layer'] = nodes_wide_df['node_num'].apply(get_col)
nodes_wide_df = nodes_wide_df.rename(columns = {'class':'index'})


#list of layer nodes
layer_nodes = {}
for row in nodes_df[nodes_df['class'] == 'overall'].itertuples(): 
    if row.layer not in layer_nodes:
        layer_nodes[row.layer] = []
    layer_nodes[row.layer].append(row.node_num)

num_layers = max(layer_nodes.keys()) + 1
num_nodes = len(nodes_wide_df.index)

#list of classes
classes = list(nodes_df['class'].unique())
classes.remove('overall')
classes.insert(0,'overall')


nodes_wide_df.head(10)



class,node_num,airplane,automobile,bird,cat,deer,dog,frog,horse,overall,ship,truck,layer
0,0,0.139876,0.193968,0.281722,0.372537,0.01882,0.267247,0.233566,0.247257,0.187421,0.102997,0.016217,0
1,1,0.089149,0.197428,0.043106,0.075388,0.018076,0.092912,0.107071,0.046974,0.088193,0.043686,0.168138,0
2,2,0.015867,0.171214,0.114838,0.220317,0.31264,0.10384,0.166537,0.129406,0.154473,0.171422,0.138648,0
3,3,0.087441,0.152154,0.076537,0.050961,0.23112,0.224538,0.013231,0.178238,0.119217,0.153142,0.024806,0
4,4,0.024439,0.037215,0.031089,0.049052,0.099022,0.073753,0.05895,0.071532,0.052207,0.044767,0.032254,0
5,5,0.201093,0.146416,0.135956,0.168017,0.077592,0.126242,0.301216,0.140233,0.171294,0.098087,0.318082,0
6,6,0.071586,0.031973,0.1855,0.242435,0.066476,0.171151,0.004216,0.050471,0.083844,0.006577,0.008058,0
7,7,0.007178,0.059125,0.052104,0.062809,0.028712,0.059987,0.026595,0.031636,0.042558,0.040607,0.056821,0
8,8,0.016677,0.077632,0.082517,0.096038,0.038893,0.038464,0.114851,0.014182,0.064695,0.092057,0.075643,0
9,9,0.181237,0.062975,0.070433,0.102134,0.045127,0.049708,0.054966,0.158716,0.079938,0.00361,0.070472,0


In [3]:
edges_df = pd.read_csv('prepped_models/cifar10_testing/edge_ranks.csv')   #load edges

#make edges wide format df
edges_wide_df = edges_df.pivot(index = 'edge_num',columns='class', values='rank_score')
edges_wide_df.reset_index(inplace=True)
edges_wide_df['layer'] = edges_wide_df['edge_num'].apply(get_col, df=edges_df,idx='edge_num')
edges_wide_df['in_channel'] = edges_wide_df['edge_num'].apply(get_col, df=edges_df,idx='edge_num',col='in_channel')
edges_wide_df['out_channel'] = edges_wide_df['edge_num'].apply(get_col, df=edges_df,idx='edge_num',col='out_channel')

print(edges_df.head(4605))
edges_wide_df.head(165)
#edges_df.loc[(edges_df['rank_score'] > .05) & (edges_df['class'] == 'frog')]

num_edges = len(edges_wide_df.index) #number of total edges


      edge_num  layer  out_channel  in_channel  rank_score  class
0            0      0            0           0    0.017843   frog
1            1      0            0           1    0.016695   frog
2            2      0            0           2    0.093749   frog
3            3      0            1           0    0.084601   frog
4            4      0            1           1    0.043617   frog
...        ...    ...          ...         ...         ...    ...
4600         0      0            0           0    0.037768  horse
4601         1      0            0           1    0.021887  horse
4602         2      0            0           2    0.052484  horse
4603         3      0            1           0    0.042027  horse
4604         4      0            1           1    0.050973  horse

[4605 rows x 6 columns]


In [4]:
#misc formatting functions

def nodeid_2_perlayerid(nodeid):    #takes in node unique id outputs tuple of layer and within layer id
    if isinstance(nodeid,str):
        if not nodeid.isnumeric():
            layer = 'img'
            within_layer_id = imgnode_names.index(nodeid)
            return layer,within_layer_id
    nodeid = int(nodeid)
    layer = nodes_df[nodes_df['class']=='overall'][nodes_df['node_num'] == nodeid]['layer'].item()
    within_layer_id = nodes_df[nodes_df['class']=='overall'][nodes_df['node_num'] == nodeid]['node_num_by_layer'].item()
    return layer,within_layer_id

def layernum2name(layer,offset=1,title = 'layer'):
    return title+' '+str(layer+offset)


In [5]:
#generate mds projections of nodes layerwise, as determined by their per class rank scores

import numpy as np
from sklearn import manifold
from sklearn.metrics import euclidean_distances

def add_norm_col(df,classes=classes[1:]):
    norms = []
    norm = 0
    for index, row in df.iterrows():
        for label in classes:
            norm += row[label]**2
        norm = np.sqrt(norm)
        norms.append(norm)
    norms = np.array(norms)
    df['class_norm'] = norms

add_norm_col(nodes_wide_df)   
    
layer_similarities = {}
for layer in layer_nodes:
    layer_df = nodes_wide_df[nodes_wide_df['layer'] == layer]
    for label in classes:
        layer_df[label] = layer_df.apply(lambda row : row[label]/row['class_norm'], axis = 1)   
    layer_similarities[layer] = euclidean_distances(layer_df.iloc[:,1:-2])



layer_mds = {}
for layer in layer_similarities:
	print('layer: %s'%str(layer))
	mds = manifold.MDS(n_components=2, max_iter=3000, eps=1e-9, 
      random_state=2, dissimilarity="precomputed", n_jobs=1)
	pos = mds.fit(layer_similarities[layer]).embedding_
	layer_mds[layer] = pos

#print(layer_mds)





A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user

layer: 0
layer: 1
layer: 2
layer: 3


In [6]:
#generate node colors based on target class (nodes that aren't important should be faded)

target_class = 'airplane'

#Node Opacity
layer_colors = ['rgba(31,119,180,', 
                'rgba(255,127,14,',
                'rgba(44,160,44,', 
                'rgba(214,39,40,',
                'rgba(39, 208, 214,', 
                'rgba(242, 250, 17,',
                'rgba(196, 94, 255,',
                'rgba(193, 245, 5,',
                'rgba(245, 85, 5,',
                'rgba(5, 165, 245,',
                'rgba(245, 5, 105,',
                'rgba(218, 232, 23,',
                'rgba(148, 23, 232,',
                'rgba(23, 232, 166,',]


def node_color_scaling(x):
    return -(x-1)**4+1

def gen_node_colors(target_class):

    node_colors = {}
    for layer in layer_nodes:
        node_colors[layer] = []
        for node in layer_nodes[layer]:
            alpha = node_color_scaling(nodes_df[nodes_df['class']==target_class].iloc[node].rank_score)
            node_colors[layer].append(layer_colors[layer%len(layer_colors)]+str(round(alpha,3))+')')
    return node_colors

node_colors = gen_node_colors(target_class)

#print(colors_dict)

In [7]:
#Node positions
#def gen_node_positions()
layer_distance = 1   # distance in X direction each layer is separated by
node_positions = {}
layer_offset = 0
for layer in layer_mds:
    node_positions[layer] = {}
    node_positions[layer]['X'] = [] 
    node_positions[layer]['Y'] = [] 
    node_positions[layer]['Z'] = []  
    for i in range(len(layer_mds[layer])): 
        node_positions[layer]['Y'].append(layer_mds[layer][i][0])
        node_positions[layer]['Z'].append(layer_mds[layer][i][1])
        node_positions[layer]['X'].append(layer_offset)
    layer_offset+=1*layer_distance

#print(node_positions[0])

In [8]:
#image nodes (one for each channel of input image)

num_img_chan = len(edges_df.loc[edges_df['layer'] == 0]['in_channel'].unique()) #number of channels in input image

def gen_imgnode_graphdata(num_chan = num_img_chan):     #returns positions, colors and names for imgnode graph points
    if num_chan == 1: #return a centered position, grey square, with 'gs' label
        return {'X':[-1],'Y':[0],'Z':[0]}, ['rgba(170,170,170,.9)'], ['gs']
    if num_chan == 3:
        colors = ['rgba(255,0,0,.9)','rgba(0,255,0,.9)','rgba(0,0,255,.9)']
        names = ['r','g','b']
    else:
        #colors
        other_colors = ['rgba(255,0,0,.9)','rgba(0,255,0,.9)','rgba(0,0,255,.9)',
                        'rgba(255,150,0,.9)','rgba(0,255,150,.9)','rgba(150,0,255,.9)',
                        'rgba(255,0,150,.9)','rgba(150,255,0,.9)','rgba(0,150,255,.9)']
        colors = []
        for i in num_chan:
            colors.append(i%len(other_colors)) 
        #names
        names = []
        for i in range(num_chan):
            names.append('img_'+str(i))   
            
    positions = {'X':[],'Y':[],'Z':[]}     #else return points evenly spaced around a unit circle
    a = 2*np.pi/num_chan          #angle to rotate each point
    for p in range(num_chan):
        positions['X'].append(-1)
        positions['Y'].append(np.sin(a*p))
        positions['Z'].append(np.cos(a*p)) 
    
    return positions, colors, names
    
# def gen_imgnode_positions(num_chan = num_img_chan):

#     if num_chan == 1: #return a centered position
#         return {'X':[-1],'Y':[0],'Z':[0]}
    
#     positions = {'X':[],'Y':[],'Z':[]}     #else return points evenly spaced around a unit circle
#     a = 2*np.pi/num_chan          #angle to rotate each point
#     for p in range(num_chan):
#         positions['X'].append(-1)
#         positions['Y'].append(np.sin(a*p))
#         positions['Z'].append(np.cos(a*p))
#     return positions

# def gen_imgnode_colors(num_chan = num_img_chan):
#     if num_chan == 1:
#         return ['rgba(170,170,170,.9)']   #grey
#     elif num_chan == 3:
#         return ['rgba(255,0,0,.9)','rgba(0,255,0,.9)','rgba(0,0,255,.9)']             #rgb

#     else:
#         other_colors = ['rgba(255,0,0,.9)','rgba(0,255,0,.9)','rgba(0,0,255,.9)',
#                         'rgba(255,150,0,.9)','rgba(0,255,150,.9)','rgba(150,0,255,.9)',
#                         'rgba(255,0,150,.9)','rgba(150,255,0,.9)','rgba(0,150,255,.9)']
#         colors = []
#         for i in num_chan:
#             colors.append(i%len(other_colors))
#         return colors
    
# def gen_imgnode_names(num_chan = num_img_chan):
#     if num_chan == 1:
#         return ['gs'] #grayscale
#     elif num_chan == 3:
#         return ['r','g','b']
#     else:
#         names = []
#         for i in range(num_chan):
#             names.append('img_'+str(i))
#         return names

imgnode_positions,imgnode_colors,imgnode_names = gen_imgnode_graphdata()

print(imgnode_positions)

{'X': [-1, -1, -1], 'Y': [0.0, 0.8660254037844388, -0.8660254037844384], 'Z': [1.0, -0.4999999999999998, -0.5000000000000004]}


In [48]:
# #Edge selection

# #edges_df.loc[(edges_df['rank_score'] > .05) & (edges_df['class'] == 'frog')]
# def get_thresholded_edges(threshold=.1,df=edges_df,target_class=target_class):          #just get those edges that pass the threshold criteria for the target class
#     return edges_df.loc[(edges_df['rank_score'] > threshold) & (edges_df['class'] == target_class)]



# def gen_edge_subset_and_weights(target_class,edge_threshold=.1):
#     edge_weights = {}
#     #edge_threshold = .1
#     Edges = {}
#     for layer in Edges_full:
#         Edges[layer] = []
#         edge_weights[layer] = []
#         for i in range(len(Edges_full[layer])):
#             edge_weight = nodes_df[nodes_df['class']==target_class].iloc[Edges_full[layer][i][0]].rank_score*nodes_df[nodes_df['class']==target_class].iloc[Edges_full[layer][i][1]].rank_score
#             if edge_weight > edge_threshold:
#                 Edges[layer].append(Edges_full[layer][i])
#                 edge_weights[layer].append(edge_weight)
#     return Edges, edge_weights

# Edges,edge_weights = gen_edge_subset_and_weights(target_class)
            
# #print(Edges)

# #Edge Positions
# def gen_edge_positions(Edges):
#     edge_positions = {}
#     for layer in Edges:
#         edge_positions[layer] = {}
#         edge_positions[layer]['X'] = []
#         edge_positions[layer]['Y'] = []
#         edge_positions[layer]['Z'] = []
#         for edge in Edges[layer]:
#             edge_positions[layer]['X']+=([node_positions[layer-1]['X'][edge[2]],node_positions[layer]['X'][edge[3]], None])# x-coordinates of edge ends
#             edge_positions[layer]['Y']+=([node_positions[layer-1]['Y'][edge[2]],node_positions[layer]['Y'][edge[3]], None])
#             edge_positions[layer]['Z']+=([node_positions[layer-1]['Z'][edge[2]],node_positions[layer]['Z'][edge[3]], None])    
#     return edge_positions

# edge_positions = gen_edge_positions(Edges)

# #print(edge_positions)

# print('Edges')
# print(Edges)
# print('edge_weights')
# print(edge_weights)
# print('edge_positions')
# print(edge_positions)

# #Edge Colors
# edge_colors_dict = {}
# for layer in Edges:
#     edge_colors_dict[layer] = []
#     for weight in edge_weights[layer]:
#         alpha = color_scaling(weight)
#         edge_colors_dict[layer].append(layer_colors[layer%len(layer_colors)]+str(round(alpha,3))+')')

Edges
{1: [(5, 48, 5, 3), (9, 48, 9, 3), (13, 48, 13, 3), (16, 46, 16, 1), (16, 48, 16, 3), (18, 46, 18, 1), (18, 48, 18, 3), (18, 49, 18, 4), (18, 52, 18, 7), (18, 56, 18, 11), (18, 62, 18, 17), (18, 64, 18, 19), (19, 46, 19, 1), (19, 48, 19, 3), (26, 46, 26, 1), (26, 48, 26, 3), (31, 46, 31, 1), (31, 48, 31, 3), (31, 49, 31, 4), (31, 56, 31, 11), (31, 62, 31, 17), (31, 64, 31, 19), (36, 48, 36, 3), (41, 48, 41, 3), (43, 48, 43, 3)], 2: [(46, 73, 1, 5), (46, 90, 1, 22), (46, 92, 1, 24), (46, 94, 1, 26), (46, 100, 1, 32), (46, 113, 1, 45), (48, 73, 3, 5), (48, 90, 3, 22), (48, 92, 3, 24), (48, 94, 3, 26), (48, 100, 3, 32), (48, 109, 3, 41), (48, 110, 3, 42), (48, 113, 3, 45), (49, 100, 4, 32), (52, 100, 7, 32), (56, 100, 11, 32), (62, 90, 17, 22), (62, 92, 17, 24), (62, 100, 17, 32), (62, 113, 17, 45), (64, 100, 19, 32)], 3: [(73, 143, 5, 26), (90, 143, 22, 26), (92, 143, 24, 26), (94, 143, 26, 26), (100, 121, 32, 4), (100, 123, 32, 6), (100, 140, 32, 23), (100, 143, 32, 26), (100, 156

In [9]:
#Edge selection

def edge_width_scaling(x):
    return max(.5,(x*5)**1.5)

def edge_color_scaling(x):
    return max(.4,-(x-1)**4+1)


def get_thresholded_edges(threshold=.1,df=edges_df,target_class=target_class):          #just get those edges that pass the threshold criteria for the target class
    return edges_df.loc[(edges_df['rank_score'] > threshold) & (edges_df['class'] == target_class)]

edges_select_df = get_thresholded_edges()


def gen_edge_graphdata(df = edges_select_df, node_positions = node_positions, num_hoverpoints=15,target_class=target_class):
    edge_positions = {}
    colors = {}
    widths = {}
    names = {}
    for row in edges_select_df.itertuples():
        if row.layer not in edge_positions:
            edge_positions[row.layer] = {'X':[],'Y':[],'Z':[]}
            colors[row.layer] = []
            widths[row.layer] = []
            names[row.layer] = []        
        #position
        for dim in ['X','Y','Z']:
            end_pos = node_positions[row.layer][dim][row.out_channel]
            if row.layer != 0:
                start_pos = node_positions[row.layer-1][dim][row.in_channel]
            else:
                start_pos = imgnode_positions[dim][row.in_channel]
            
            step = (end_pos-start_pos)/(num_hoverpoints+1)
            points = [start_pos]
            for i in range(1,num_hoverpoints+1):
                points.append(start_pos+i*step)
            points.append(end_pos)
            edge_positions[row.layer][dim].append(points)
        #color
        alpha = edge_color_scaling(row.rank_score)
        colors[row.layer].append(layer_colors[row.layer%len(layer_colors)]+str(round(alpha,3))+')')
        #width
        widths[row.layer].append(edge_width_scaling(row.rank_score))
        #names
        out_node = layer_nodes[row.layer][row.out_channel]
        if row.layer != 0:
            in_node = layer_nodes[row.layer-1][row.in_channel]
        else:
            in_node = imgnode_names[row.in_channel]
        names[row.layer].append(str(in_node)+'-'+str(out_node))
    return edge_positions, colors,widths,names


edge_positions, edge_colors, edge_widths, edge_names = gen_edge_graphdata()

print(edge_positions)
print('\n\n')
print(edge_colors)
print('\n\n')
print(edge_widths)
print('\n\n')
print(edge_names)

{0: {'X': [[-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625, 0], [-1, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.

In [11]:
#Format Node Feature Maps

# import pickle

# activations = pickle.load(open('activations/cifar_prunned_.816_activations.pkl','rb'))
  
    
# node_ids = []
# for layer in layer_nodes:
#     for i in range(len(layer_nodes[layer])):
#         node_ids.append(str(layer+1)+'_'+str(i))
    
# print(activations['airplane']['0001.png'][0].shape)

import torch
activations = torch.load('prepped_models/cifar10_testing/input_img_activations.pt')

print(activations['edges'][0].shape)
print(activations['nodes'][0].shape)

(100, 45, 3, 32, 32)
(100, 45, 32, 32)


In [13]:
#Format Edge Kernels

# kernels = pickle.load(open('kernels/cifar_prunned_.816_kernels.pkl','rb'))
# print(kernels[0].shape)

kernels = torch.load('prepped_models/cifar10_testing/kernels.pt')

print(kernels[0].shape)

#Function for taking a string of form 'node1-node2' and outputting edge info
def nodestring_2_edge_info(nodestring):
    from_node = nodestring.split('-')[0]
    to_node = nodestring.split('-')[1]
    from_layer,from_within_id = nodeid_2_perlayerid(from_node)
    to_layer,to_within_id = nodeid_2_perlayerid(to_node)
    kernel = kernels[to_layer][to_within_id][from_within_id]
    return np.flip(kernel,0)


print(nodestring_2_edge_info('b-0'))
print(np.flip(kernels[0][0][2],0))

(45, 3, 3, 3)
[[-0.01040168  0.07258165 -0.02730235]
 [-0.09874548  0.15988253  0.16769743]
 [-0.11188427  0.21347913 -0.24422328]]
[[-0.01040168  0.07258165 -0.02730235]
 [-0.09874548  0.15988253  0.16769743]
 [-0.11188427  0.21347913 -0.24422328]]




In [14]:
## adding images
import glob
import os

input_image_directory = 'input_images_testing/'
list_of_input_images = [os.path.basename(x) for x in glob.glob('{}*.png'.format(input_image_directory))]

static_input_image_route = '/static_input_images/'


# edge_image_directory = '/Users/chrishamblin/Desktop/graph_viz/edge_images/'
# list_of_edge_images = [os.path.basename(x) for x in glob.glob('{}*.png'.format(edge_image_directory))]

# edge_static_image_route = '/static_edge/'
#list_of_input_images[0]


In [15]:
edge_positions

{0: {'X': [[-1,
    -0.9375,
    -0.875,
    -0.8125,
    -0.75,
    -0.6875,
    -0.625,
    -0.5625,
    -0.5,
    -0.4375,
    -0.375,
    -0.3125,
    -0.25,
    -0.1875,
    -0.125,
    -0.0625,
    0],
   [-1,
    -0.9375,
    -0.875,
    -0.8125,
    -0.75,
    -0.6875,
    -0.625,
    -0.5625,
    -0.5,
    -0.4375,
    -0.375,
    -0.3125,
    -0.25,
    -0.1875,
    -0.125,
    -0.0625,
    0],
   [-1,
    -0.9375,
    -0.875,
    -0.8125,
    -0.75,
    -0.6875,
    -0.625,
    -0.5625,
    -0.5,
    -0.4375,
    -0.375,
    -0.3125,
    -0.25,
    -0.1875,
    -0.125,
    -0.0625,
    0],
   [-1,
    -0.9375,
    -0.875,
    -0.8125,
    -0.75,
    -0.6875,
    -0.625,
    -0.5625,
    -0.5,
    -0.4375,
    -0.375,
    -0.3125,
    -0.25,
    -0.1875,
    -0.125,
    -0.0625,
    0],
   [-1,
    -0.9375,
    -0.875,
    -0.8125,
    -0.75,
    -0.6875,
    -0.625,
    -0.5625,
    -0.5,
    -0.4375,
    -0.375,
    -0.3125,
    -0.25,
    -0.1875,
    -0.125,
    -0.0625,


In [45]:
#import chart_studio.plotly as py
import plotly.offline as py    #added
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)   #added



#add imgnodes
imgnode_trace=go.Scatter3d(x=imgnode_positions['X'],
           y=imgnode_positions['Y'],
           z=imgnode_positions['Z'],
           mode='markers',
           name=layernum2name(layer,title = 'nodes'),
           marker=dict(symbol='square',
                         size=8,
                         opacity=.99,
                         color=imgnode_colors,
                         #colorscale='Viridis',
                         line=dict(color='rgb(50,50,50)', width=.5)
                         ),
           text=imgnode_names,
           hoverinfo='text'
           )

imgnode_traces = [imgnode_trace]


node_traces = []
for layer in layer_nodes:
    #add nodes
    node_trace=go.Scatter3d(x=node_positions[layer]['X'],
               y=node_positions[layer]['Y'],
               z=node_positions[layer]['Z'],
               mode='markers',
               name=layernum2name(layer,title = 'nodes'),
               marker=dict(symbol='circle',
                             size=6,
                             opacity=.99,
                             color=node_colors[layer],
                             #colorscale='Viridis',
                             line=dict(color='rgb(50,50,50)', width=.5)
                             ),
               text=layer_nodes[layer],
               hoverinfo='text'
               )
        
    node_traces.append(node_trace)
    
edge_traces = []    
for layer in edge_positions:        
    #add edges      
    edge_trace=go.Scatter3d(x=edge_positions[layer]['X'],
                            y=edge_positions[layer]['Y'],
                            z=edge_positions[layer]['Z'],
                            name=layernum2name(layer ,title = 'edges'),
                            mode='lines',
                            #line=dict(color=edge_colors_dict[layer], width=1.5),
                            line=dict(color='rgb(100,100,100)', width=1.5),
                            text = list(range(len(Edges[layer])))
                            #hoverinfo='text'
                            )
    edge_traces.append(edge_trace)

 
combined_traces = imgnode_traces+node_traces+edge_traces


#layout
axis=dict(showbackground=False,
          showspikes=False,
          showline=False,
          zeroline=False,
          showgrid=False,
          showticklabels=False,
          #range=[0,0],
          title=''
          )

graph_layout = go.Layout(
         #title="%s through Prunned Cifar10 CNN"%target_class,
         #title = target_class,
         #width=1000,
         clickmode = 'event+select',
         transition = {'duration': 500},
         height=600,
         #showlegend=False,
         margin = dict(l=20, r=20, t=20, b=20),
         scene=dict(
             xaxis=dict(axis),
             yaxis=dict(axis),
             zaxis=dict(axis),
         ),
         uirevision =  True   
         #hovermode='closest',
   )


fig=go.Figure(data=combined_traces, layout=graph_layout)




In [46]:
import dash
import dash_core_components as dcc
import dash_html_components as html
#import utils.dash_reusable_components as drc
import flask
import os

from dash.dependencies import Input, Output, State


#external_stylesheets = ['https://codepen.io/amyoshino/pen/jzXypZ.css']
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(external_stylesheets = external_stylesheets)


styles = {
    'pre': {
        'border': 'thin lightgrey solid',
        'overflowX': 'scroll'
    }
}


app.layout = html.Div(
        [html.Div(         #Left side control panel
            children = [
             html.Label('Weighting Category'),
             dcc.Dropdown(
                id='weight-category',
                options=[{'label': i, 'value': i} for i in classes],
                value=target_class
                ),
             html.Br(),
             html.Label('Weighting Criterion'),
             dcc.Dropdown(
                id='weight-criterion',
                options=[
                    {'label': 'Activations*Grads', 'value': 'actgrads'},
                    {'label': 'Activations', 'value': 'acts'}
                ],
                value='actgrads'
                ),
             html.Br(),   
             html.Label('Layer Projection'),
             dcc.Dropdown(
                id = 'layer-projection',
                options=[
                    {'label': 'MDS', 'value': 'MDS'},
                    {'label': 'Grid', 'value': 'grid'},
                    #{'label': 'SOM', 'value': 'SOM'}
                ],
                value='MDS'
                ),

            html.Br(),
            html.Label('Lower Edge Threshold'),
                dcc.Slider(
                    id='lower-thresh-slider',
                    min=0,
                    max=1,
                    step=0.01,
                    marks={i/10: str(i/10) for i in range(0,12,2)},
                    value=.1,
                ),
                
            html.Br(),
            html.Label('Upper Edge Threshold'),
                dcc.Slider(
                    id='upper-thresh-slider',
                    min=0,
                    max=1,
                    step=0.01,
                    marks={i/10: str(i/10) for i in range(0,12,2)},
                    value=1,
                ),
                
            ], className="two columns"
        ),

        html.Div(
            children = [
                
            html.Div([
                dcc.Graph(
                    id='network-graph',
                    figure=fig
                )
            ], className= 'row'
            ),
                
            html.Div([
                html.Div([
                html.Label('Input Image'),
                dcc.Dropdown(
                    id='input-image-dropdown',
                    options=[{'label': i, 'value': i} for i in list_of_input_images],
                    value=list_of_input_images[6]
                ),
                html.Br(),
                html.Br(),
                html.Br(),
                html.Br(),
                html.Img(id='input-image')#,style={'height':'200%', 'width':'200%'}
                ], className = "three columns"),
                
                html.Div([
                html.Label('Node'),
                dcc.Dropdown(
                    id='node-actmap-dropdown',
                    options=[{'label': str(i), 'value': i} for i in range(num_nodes)],
                    value=0
                ),
                dcc.Graph(
                    id='node-actmap-graph',
                    figure=go.Figure(data=go.Heatmap(
                                        z = np.flip(activations[list_of_input_images[0].split('_')[0]][list_of_input_images[0].split('_')[1]][0][0],0)),
                                        layout=dict(
                                            height=500,
                                            width=500)
                                    ),
                    config={
                            'displayModeBar': False
                            }
                )
                ], className = "three columns"),
                
                html.Div([
                html.Label('Edge'),    
                dcc.Input(
                    id='edge-kernel-input',value='0-%s'%str(layer_nodes[1][0]), type='text'),
                html.Button(id='edge-kernel-button',n_clicks=0, children='Submit'),
                dcc.Graph(
                    id='edge-kernel-graph',
                    figure=go.Figure(data=go.Heatmap(
                                        z = nodestring_2_edge_info('0-%s'%str(layer_nodes[1][0]))
                                        ),
                                        layout=dict(
                                            height=500,
                                            width=500)
                            ),
                    config={
                            'displayModeBar': False
                            }
                )
                ], className = "three columns")
                
                
             ], className= 'row'
             ),
                
                
            html.Div([
                html.Div([
                    dcc.Markdown("""
                        **Hover Data**

                        Mouse over values in the graph.
                    """),
                    html.Pre(id='hover-data', style=styles['pre'])
                ], className='two columns'),

                html.Div([
                    dcc.Markdown("""
                        **Click Data**

                        Click on points in the graph.
                    """),
                    html.Pre(id='click-data', style=styles['pre']),
                ], className='two columns'),

                html.Div([
                    dcc.Markdown("""
                        **Selection Data**

                        Choose the lasso or rectangle tool in the graph's menu
                        bar and then select points in the graph.

                        Note that if `layout.clickmode = 'event+select'`, selection data also 
                        accumulates (or un-accumulates) selected data if you hold down the shift
                        button while clicking.
                    """),
                    html.Pre(id='selected-data', style=styles['pre']),
                ], className='two columns'),

#                 html.Div([
#                     dcc.Markdown("""
#                         **Zoom and Relayout Data**

#                         Click and drag on the graph to zoom or click on the zoom
#                         buttons in the graph's menu bar.
#                         Clicking on legend items will also fire
#                         this event.
#                     """),
#                     html.Pre(id='relayout-data', style=styles['pre']),
#                 ], className='two columns')
                
                html.Div([
                    dcc.Markdown("""
                        **Figure Data**

                        Figure json info.
                    """),
                    html.Pre(id='figure-data', style=styles['pre']),
                ], className='four columns')
                
            ], className= 'row'
            )
        ], className="ten columns"
        )
    ]
)




####Call Back Functions

# @app.callback(
#     Output('figure-data', 'children'),
#     [Input('network-graph', 'figure')])
# def display_figure_data(figure):
#     return json.dumps(figure, indent=2)



@app.callback(
    Output('network-graph', 'figure'),
    [Input('network-graph', 'clickData')])
def highlight_on_click(clickData):
    if clickData['points'][0]['curveNumber'] == None:
        raise Exception('no point clicked yet') 
    trace_num = int(clickData['points'][0]['curveNumber'])
    if trace_num < num_layers:   #highlight point
        for layer in node_colors:
            if layer == trace_num:
                new_colors = list(node_colors[trace_num])
                new_colors[clickData['points'][0]['pointNumber']] = 'rgba(0,0,0,1)'
                combined_traces[trace_num]['marker']['color'] = new_colors
            else:
                combined_traces[layer]['marker']['color'] = node_colors[layer]
    else: #highlight edge
        #raise Exception('lets skip edges for now') 
        for layer in Edges:
            new_colors = list(['rgb(125,125,125)' for i in range(len(Edges[layer]))])
            #new_colors = edge_colors_dict[layer]
            if layer == trace_num-num_layers+1:
                new_colors[clickData['points'][0]['text']] = 'rgba(150,0,0,1)'
            combined_traces[layer]['line']['color'] = new_colors
    
    layout = graph_layout
    layout['uirevision'] = True
    return {'data': combined_traces,
            'layout': layout}



@app.callback(
    Output('node-actmap-dropdown', 'value'),
    [Input('network-graph', 'clickData')])
def switch_node_actmap_click(clickData):
    if clickData['points'][0]['curveNumber'] == None:
        raise Exception('no point clicked yet') 
    if int(clickData['points'][0]['curveNumber']) >= num_layers:
        raise Exception('Do nothing, they clicked an edge')
    return int(clickData['points'][0]['text'])


         
#cant currently click edges
# @app.callback(
#     Output('edge-image-dropdown', 'value'),
#     [Input('network-graph', 'clickData')])
# def switch_edge_image_click(clickData):
#     if int(clickData['points'][0]['curveNumber']) < num_layers:
#         raise Exception('Do nothing, they clicked a node')
#     return list_of_edge_images[int(clickData['points'][0]['pointNumber'])]



#Node activation map
@app.callback(
    Output('node-actmap-graph', 'figure'),
    [Input('node-actmap-dropdown', 'value'),
     Input('input-image-dropdown', 'value')])
def update_node_actmap(nodeid,image_name):
    layer, within_id = nodeid_2_perlayerid(nodeid)
    if layer == 'img':
        pass
        #code for returning color channel as activation map
        
    return go.Figure(data=go.Heatmap( z = np.flip(activations[image_name.split('_')[0]][image_name.split('_')[1]][layer][within_id],0)),
                     layout=dict(height=500,
                                 width=500,
                                 uirevision=True)) 
#     return {'data':go.Heatmap(
#                               z = activations[image_name.split('_')[0]][image_name.split('_')[1]][layer][within_id]),
#             'layout':dict(height=500,width=500)}
 
    
      

@app.callback(
    Output('edge-kernel-graph', 'figure'),
    [Input('edge-kernel-button','n_clicks')],
    [State('edge-kernel-input', 'value')])
def update_edge_kernelmap(n_clicks,nodestring):
    return go.Figure(data=go.Heatmap(z = nodestring_2_edge_info(nodestring)),
                     layout=dict(height=500,
                                 width=500,
                                 uirevision=True)) 
                

#Input Images
@app.callback(
    Output('input-image', 'src'),
    [Input('input-image-dropdown', 'value')])
def update_input_image_src(value):
    return static_input_image_route + value

@app.server.route('{}<image_path>.png'.format(static_input_image_route))
def serve_input_image(image_path):
    image_name = '{}.png'.format(image_path)
    if image_name not in list_of_input_images:
        raise Exception('"{}" is excluded from the allowed static files'.format(image_path))
    return flask.send_from_directory(input_image_directory, image_name)




#JSON INFO

@app.callback(
    Output('hover-data', 'children'),
    [Input('network-graph', 'hoverData')])
def display_hover_data(hoverData):
    return json.dumps(hoverData, indent=2)




@app.callback(
    Output('click-data', 'children'),
    [Input('network-graph', 'clickData')])
def display_click_data(clickData):
    return json.dumps(clickData, indent=2)


@app.callback(
    Output('selected-data', 'children'),
    [Input('network-graph', 'selectedData')])
def display_selected_data(selectedData):
    return json.dumps(selectedData, indent=2)



# @app.callback(
#     Output('network-graph', 'figure'),
#     [Input('weight-category', 'value'),
#      Input('network-graph', 'clickData'),
#      Input('lower-thresh-slider','value')])
# def update_figure(target_class,clickData,edge_thresh):
#     node_colors = gen_node_colors(target_class)
#     Edges,edge_weights = gen_edge_subset_and_weights(target_class,edge_threshold=edge_thresh)
#     edge_positions = gen_edge_positions(Edges)
#     click_layer = int(clickData['points'][0]['curveNumber'])
#     for layer in node_colors:
#         if layer == click_layer:
#             new_colors = list(node_colors[click_layer])
#             new_colors[clickData['points'][0]['pointNumber']] = 'rgba(0,0,0,1)'
#             combined_traces[layer]['marker']['color'] = new_colors
#         else:
#             combined_traces[layer]['marker']['color'] = node_colors[layer]

#     for layer in edge_positions:
#         combined_traces[layer-1+num_layers] = go.Scatter3d(x=edge_positions[layer]['X'],
#                                 y=edge_positions[layer]['Y'],
#                                 z=edge_positions[layer]['Z'],
#                                 name=layernum2name(layer ,title = 'edges'),
#                                 mode='lines',
#                                 #line=dict(color=edge_colors_dict[layer], width=1.5),
#                                 line=dict(color='rgb(100,100,100)', width=1.5),
#                                 text = list(range(len(Edges[layer])))
#                                 #hoverinfo='text'
#                                 )
   
#     layout = graph_layout
#     layout['uirevision'] = True
#     return {'data': combined_traces,
#             'layout': layout}






# @app.callback(
#     Output('relayout-data', 'children'),
#     [Input('network-graph', 'relayoutData')])
# def display_relayout_data(relayoutData):
#     return json.dumps(relayoutData, indent=2)



# Add a static image route that serves images from desktop
# Be *very* careful here - you don't want to serve arbitrary files
# from your computer or server

# @app.callback(Output('my-div', 'children'),
#                      [Input('my-input', 'value')],
#                      [State('my-div', 'children')])
# def update_div(value, existing_state):
#     if some_condition:
#          return existing_state


#py.iplot(fig, filename='small_net')



#             html.Label('Multi-Select Dropdown'),
#             dcc.Dropdown(
#                 options=[
#                     {'label': 'New York City', 'value': 'NYC'},
#                     {'label': u'Montréal', 'value': 'MTL'},
#                     {'label': 'San Francisco', 'value': 'SF'}
#                 ],
#                 value=['MTL', 'SF'],
#                 multi=True
#                 ),

#             html.Label('Radio Items'),
#             dcc.RadioItems(
#                 options=[
#                     {'label': 'New York City', 'value': 'NYC'},
#                     {'label': u'Montréal', 'value': 'MTL'},
#                     {'label': 'San Francisco', 'value': 'SF'}
#                 ],
#                 value='MTL'
#                 ),

# @app.callback(Output('output-state', 'children'),
#               [Input('submit-button-state', 'n_clicks')],
#               [State('input-1-state', 'value'),
#                State('input-2-state', 'value')])
# def update_output(n_clicks, input1, input2):
#     return u'''
#         The Button has been pressed {} times,
#         Input 1 is "{}",
#         and Input 2 is "{}"
#     '''.format(n_clicks, input1, input2)
    
    
# @app.server.route('{}<image_path>.png'.format(node_static_image_route))
# def serve_node_image(image_path):
#     image_name = '{}.png'.format(image_path)
#     if image_name not in list_of_node_images:
#         raise Exception('"{}" is excluded from the allowed static files'.format(image_path))
#     return flask.send_from_directory(node_image_directory, image_name)



Boolean Series key will be reindexed to match DataFrame index.


Boolean Series key will be reindexed to match DataFrame index.



In [47]:
app.run_server()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_renderer/react@16.v1_2_2m1580842230.8.6.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_renderer/polyfill@7.v1_2_2m1580842230.7.0.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_renderer/prop-types@15.v1_2_2m1580842230.7.2.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_renderer/react-dom@16.v1_2_2m1580842230.8.6.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_html_components/dash_html_components.v1_0_2m1573845875.min.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:05] "[37mGET /_dash-component-suites/dash_core_components/dash_core_components.v1_8_1m1582848776

Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 2446, in wsgi_app
    response = self.full_dispatch_request()
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1951, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1820, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1949, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1935, in d

127.0.0.1 - - [08/Jun/2020 17:53:08] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -


Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 2446, in wsgi_app
    response = self.full_dispatch_request()
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1951, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1820, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1949, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Users/chrishamblin/miniconda3/envs/graph_viz/lib/python3.7/site-packages/flask/app.py", line 1935, in d

127.0.0.1 - - [08/Jun/2020 17:53:08] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -
127.0.0.1 - - [08/Jun/2020 17:53:09] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:23] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:23] "[37mGET /static_input_images/airplane_0002.png HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:23] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:26] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:26] "[37mGET /static_input_images/airplane_0003.png HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:26] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:40] "[37mPOST 

127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:54] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:53:57] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:02] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:02] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:02] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:02] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:02] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:12] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:21] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:46] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:46] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:46] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:46] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:46] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:51] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:54] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:54:59] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 17:55:04] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:04] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:05] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:05] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:05] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:05] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:05] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:11] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:12] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:12] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:20] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:23] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:23] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:23] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:23] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:29] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:30] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:32] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:32] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:32] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:32] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:33] "[37mPOST /_dash-update-component HTTP/1.1

127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:37] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2020 22:59:38] "[37mPOST /_dash-update-component HTTP/1.1

In [216]:
image_name = 'airplane_0001.png'
layer, within_id = nodeid_2_perlayerid(149)
#print(activations[image_name.split('_')[0]][image_name.split('_')[1]])

[array([[[ 1.2824656 ,  0.8565904 ,  0.8818083 , ...,  0.87015975,
          0.855741  ,  0.7828878 ],
        [ 1.2424941 ,  0.23603298,  0.2525151 , ...,  0.23459584,
          0.22942215,  0.64034116],
        [ 1.301531  ,  0.2654904 ,  0.2696002 , ...,  0.24644026,
          0.2499801 ,  0.66129947],
        ...,
        [ 1.521893  ,  0.34898174,  0.3139554 , ...,  0.30267176,
          0.29700178,  0.8708159 ],
        [ 1.4859303 ,  0.3290515 ,  0.44562238, ...,  0.2992803 ,
          0.2746499 ,  0.86019945],
        [ 1.7619598 ,  0.46666557,  0.54549456, ...,  0.6023208 ,
          0.5483697 ,  0.691553  ]],

       [[-0.03734329,  0.17756009,  0.1765095 , ...,  0.18297783,
          0.17204088,  0.16133627],
        [-0.37692112, -0.08926529, -0.08464667, ..., -0.09308489,
         -0.09154018,  0.12175873],
        [-0.4008486 , -0.11605299, -0.10956731, ..., -0.07384205,
         -0.07937238,  0.12824655],
        ...,
        [-0.29320723,  0.17471272,  0.42962372, ..., 


Boolean Series key will be reindexed to match DataFrame index.


Boolean Series key will be reindexed to match DataFrame index.



In [295]:
print(edge_weights)

{1: [0.13393096771897106, 0.17065674557004942, 0.22843438109956082, 0.21554536803804325, 0.21937029231531824, 0.21865732496510404, 0.1143273308230981, 0.11066470293216926, 0.3083848996134648, 0.11906178153320646, 0.12250828264472524, 0.2549085819787109, 0.10126440253318592], 2: [0.10364892214199276, 0.11814751072434504, 0.2888343040092387, 0.12981509666264301, 0.10236099485570092, 0.1261464246642916, 0.13022122981671022, 0.12924746744562124, 0.22363833715187909, 0.2201996633475538, 0.17473984297618372, 0.12294931867272751, 0.1355681624089815, 0.13662098730953784, 0.10326656722572025, 0.11225775284586524, 0.11832326072405053, 0.11564111721386361, 0.11151365338039465, 0.11474165757599675], 3: [0.12264937774880827, 0.1127526765768474, 0.24119884152364257, 0.10840558234491304, 0.10534195927919754, 0.10874473474096023, 0.1079315683249016, 0.18675519871622193, 0.18388364181848438, 0.14592119810163062, 0.10267213007023734, 0.11320983438133947, 0.11408902408566667]}


In [339]:
edge_positions

{1: {'X': [0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None,
   0,
   1,
   None],
  'Y': [-0.3580247738363428,
   -0.11742886811595352,
   None,
   -0.3580247738363428,
   0.5119787537211334,
   None,
   -0.3580247738363428,
   0.08629212182261396,
   None,
   -0.14316565219696478,
   -0.11742886811595352,
   None,
   -0.14316565219696478,
   0.5119787537211334,
   None,
   -0.22366249474517383,
   -0.11742886811595352,
   None,
   -0.22366249474517383,
   0.5119787537211334,
   None,
   -0.22366249474517383,
   0.08629212182261396,
   None,
   -0.223662494745173

In [None]:
#Hide zoom, pan thingy

dcc.Graph(
    id='my-graph',
    figure={'data': [{'x': [1, 2, 3]}]},
    config={
        'displayModeBar': False
    }
)

In [119]:
#image to numpy array
from PIL import Image
import numpy
im = Image.open("edge_images/0164.png").convert('LA')
np_im = numpy.array(im)
np_im = np_im[:,:,0]
np_im = np.flip(np_im,0)

#numpy array to plotly heatmap
fig = go.Figure(data=go.Heatmap(
                    z = activations['airplane']['0001.png'][3][0]),
                    #z=np_im),
                layout=dict(
                    height=300,
                    width=300))


#fig['layout'].update(scene=dict(aspectmode="data"))
fig.show()
#im.show()

In [253]:
#image to numpy array
from PIL import Image
import numpy
im = Image.open("input_images_testing/airplane_0002.png")
im = im.resize((320,320),resample=Image.NEAREST)
im.show()


# np_im = numpy.array(im)
# np_im.shape


# new_im = Image.fromarray(np_im)
# new_im.save("numpy_altered_sample2.png")


In [249]:
np_im[0].shape


(32, 3)

In [195]:
import plotly.graph_objects as go



show = [True,False,False,False,True,False,False,False]

widths = [.01,.1,.5,.8,1,2,3,10] 


layer_colors = ['rgba(31,119,180,', 
                'rgba(255,127,14,',
                'rgba(44,160,44,', 
                'rgba(214,39,40,',
                'rgba(39, 208, 214,', 
                'rgba(242, 250, 17,',
                'rgba(196, 94, 255,',
                'rgba(193, 245, 5,',
                'rgba(245, 85, 5,',
                'rgba(5, 165, 245,',
                'rgba(245, 5, 105,',
                'rgba(218, 232, 23,',
                'rgba(148, 23, 232,',
                'rgba(23, 232, 166,',]


edges_select_df = get_thresholded_edges(threshold = .1)


widths = []
for row in edges_select_df.itertuples():
    widths.append(row.rank_score)
    


fig2 = go.Figure()

# for i in range(4):
#     fig2.add_trace(go.Scatter(
#         x=[1+i, 2+i, 3+i],
#         y=[2+i, 2+i, 2+i],
#         legendgroup="group1",
#         name="group1",
#         mode="lines",
#         line=dict(color=layer_colors[i]+'1)',width = widths[i]),
#         text=[None, "Text B", None],
#         hoverinfo = 'text',
#         showlegend= show[i]
#     ))

# for i in range(4,8):
#     fig2.add_trace(go.Scatter(
#         x=[1+i, 2+i, 3+i],
#         y=[4+i, 9+i, 2+i],
#         legendgroup="group2",
#         name="group2",
#         mode="lines",
#         line=dict(color=layer_colors[i]+'1)',width = widths[i]),
#         showlegend = show[i]
#     ))


# fig3 = go.Figure()

# for i in range(len(layer_colors)):
#     fig3.add_trace(go.Scatter(
#         x=[1, 2, 3],
#         y=[i, i, i],
#         mode="lines",
#         line=dict(color=layer_colors[i]+'1)'),

#     ))    
  
fig4 = go.Figure()

def edge_width_scaling(x):
    return max(.5,(x*5)**1.5)

def edge_color_scaling(x):
    return max(.4,-(x-1)**4+1)

for i in range(len(widths)):
    fig4.add_trace(go.Scatter(
        x=[1, 2, 3],
        y=[.1*i, .1*i, .1*i],
        mode="lines",
        line=dict(color='rgba(255,0,0,%s)'%edge_color_scaling(widths[i]),width=edge_width_scaling(widths[i])),
        text = widths[i],
        hoverinfo = 'text'
        

    ))    

    
#fig2.show()
fig4.show()
edge_color_scaling(.4)

0.8704000000000001

In [86]:
# from PIL import Image
# im_names = os.listdir('input_images_testing')

# for name in im_names:
#     im = Image.open("input_images_testing/%s"%name)
#     im = im.resize((320,320),resample=Image.NEAREST)
#     im.save("input_images/%s"%name,"PNG")

for i in range(5,9):
    print(i)

5
6
7
8
