# Generate Attribution Graphs

## Import packages

In [None]:
import json
import operator
import numpy as np
import networkx as nx
from time import time
import lucid.modelzoo.vision_models as models
import matplotlib.pyplot as plt

## Functions to get general information of inception_v1 model

In [None]:
def get_layers(graph_nodes):
    '''
    Collect the distinct layer names that own a weight node
    * input
        - graph_nodes: tensorflow graph nodes (each exposes a `name`)
    * output
        - layers: list of layer names, in order of first appearance
    '''
    layers = []
    for node in graph_nodes:
        # Only nodes whose name ends in '_w' identify a layer
        if not node.name.endswith('_w'):
            continue
        # The layer name is the part before the first underscore
        prefix = node.name.split('_')[0]
        if prefix in layers:
            continue
        layers.append(prefix)
    return layers

In [None]:
def get_channel_sizes(layer, weight_nodes):
    '''
    Get channel sizes
    * input
        - layer: the name of layer (matched as a substring of the node name)
        - weight_nodes: tensorflow nodes to scan
    * output
        - channel_sizes: first shape dimension of every matching node whose
          name ends in '_b' and does not contain 'bottleneck'
    '''
    channel_sizes = []
    for node in weight_nodes:
        name = node.name
        if layer not in name:
            continue
        if not name.endswith('_b'):
            continue
        # Bottleneck nodes are left out of the size list
        if 'bottleneck' in name:
            continue
        channel_sizes.append(get_shape_of_node(node)[0])
    return channel_sizes

In [None]:
def get_shape_of_node(n):
    '''
    Read the shape stored in a tensorflow node's 'value' attribute
    * input
        - n: tensorflow node
    * output
        - tensor_shape: list with the size of every dimension of n
    '''
    shape_proto = n.attr['value'].tensor.tensor_shape
    return [dim.size for dim in shape_proto.dim]

In [None]:
def get_prev_layer(layer, mixed_layers):
    '''
    Get the layer that comes right before the given layer
    * input
        - layer: the name of layer
        - mixed_layers: ordered list of layers (previous layers come first)
    * output
        - the element of mixed_layers located just before `layer`
    * raises
        - ValueError: if layer is not in mixed_layers, or if it is the first
          layer and therefore has no previous layer
    '''
    layer_idx = mixed_layers.index(layer)
    if layer_idx == 0:
        # Index -1 would silently wrap around and return the last layer
        raise ValueError('{} has no previous layer'.format(layer))
    return mixed_layers[layer_idx - 1]

## Functions to extract influences from I-matrices for a class

In [None]:
def extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=True):
    '''
    Extract influences for a class from I-matrices
    * input
        - Is: I-matrices for all class
        - all_layers: list of all layers
        - start_layer: start layer (towards output)
        - end_layer: end layer (towards input)
        - pred_class: predicted class
        - verbose: whether to print the layers as they are extracted
    * output
        - Is_class: I-matrices for a class. a dictionary, where
            - key: layer name (e.g. 'mixed4d', 'mixed4d_1')
            - val: influences for given layer(key), class(argument of the function)
    '''

    # Get layers starting from the given layer to the input layer.
    # Slice forward and then reverse: a negative-step slice with stop
    # `end_idx - 1` becomes stop -1 when end_idx is 0 and selects nothing.
    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)
    target_layers = all_layers[end_idx: start_idx + 1][::-1]

    Is_class = {}

    for layer in target_layers:
        if verbose:
            print('\n({}) loading {}'.format(pred_class, layer), end='')
        Is_class[layer] = Is[layer][pred_class]

        # Each layer also has two inner (branch) matrices: <layer>_1, <layer>_2
        for branch in [1, 2]:
            inner_layer = '{}_{}'.format(layer, branch)
            if verbose:
                print(',', inner_layer, end='')
            Is_class[inner_layer] = Is[inner_layer][pred_class]
    if verbose:
        print('\n')

    return Is_class

In [None]:
def load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True):
    '''
    Load I-matrices for all layers
    * input
        - all_layers: list of all layers
        - start_layer: start layer (towards output)
        - end_layer: end layer (towards input)
        - I_mat_dirpath: directory path of I-matrices
        - verbose: whether to print the layers as they are loaded
    * output
        - Is: I-matrices for all class, keyed by layer name and by the
          inner layer names '<layer>_1' and '<layer>_2'
    '''

    # Get layers starting from the given layer to the input layer.
    # Slice forward and then reverse: a negative-step slice with stop
    # `end_idx - 1` becomes stop -1 when end_idx is 0 and selects nothing.
    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)
    target_layers = all_layers[end_idx: start_idx + 1][::-1]

    # Load I matrices
    Is = {}
    for layer in target_layers:
        if verbose:
            print('\n(all) loading {}'.format(layer), end='')
        Is[layer] = load_inf_matrix(I_mat_dirpath, layer)
        for branch in [1, 2]:
            inner_layer = '{}_{}'.format(layer, branch)
            if verbose:
                print(',', inner_layer, end='')
            Is[inner_layer] = load_inf_matrix(I_mat_dirpath, inner_layer)
    if verbose:
        print('\n')
    return Is

In [None]:
def load_inf_matrix(I_mat_dirpath, layer):
    '''
    Load I matrix for a layer
    * input
        - I_mat_dirpath: directory path of I-matrices, with or without a
          trailing '/'; an empty string means the current directory
        - layer: layer name
    * output
        - I_mat: I-matrix of the given layer, parsed from 'I_<layer>.json'
    '''
    # endswith() is safe on an empty string, unlike indexing with [-1]
    if not I_mat_dirpath or I_mat_dirpath.endswith('/'):
        separator = ''
    else:
        separator = '/'
    filepath = I_mat_dirpath + separator + 'I_' + layer + '.json'

    with open(filepath, encoding='utf-8') as f:
        I_mat = json.load(f)

    return I_mat

## Functions to extract A matrix information

In [None]:
def read_A(A_mat_dirpath, layer, A_prob_threshold):
    '''
    Read the A matrix of a layer from a csv file
    * input
        - A_mat_dirpath: directory path of A-matrices, with or without a
          trailing '/'
        - layer: layer name
        - A_prob_threshold: threshold value that is part of the file name
    * output
        - A: integer numpy array read from
          'A-<A_prob_threshold>-<layer>.csv'
    '''
    # Accept directories without a trailing '/', like load_inf_matrix does
    if not A_mat_dirpath or A_mat_dirpath.endswith('/'):
        separator = ''
    else:
        separator = '/'
    filename = 'A-' + str(A_prob_threshold) + '-' + layer + '.csv'
    A = np.loadtxt(A_mat_dirpath + separator + filename, delimiter=',', dtype=int)
    return A

In [None]:
def read_As(A_mat_dirpath, mixed_layers, A_prob_threshold):
    '''
    Read the A matrices of several layers
    * input
        - A_mat_dirpath: directory path of A-matrices
        - mixed_layers: layers to read
        - A_prob_threshold: threshold value passed on to read_A
    * output
        - As: dictionary mapping each layer to its A matrix
    '''
    return {
        layer: read_A(A_mat_dirpath, layer, A_prob_threshold)
        for layer in mixed_layers
    }

## Functions to query the influence values

In [None]:
def get_branch(layer, channel, layer_channels):
    '''
    Find which branch (fragment) of a layer a channel falls into
    * input
        - layer: the name of layer (not used in the computation)
        - channel: channel in the layer
        - layer_channels: fragment sizes of the layer
    * output
        - branch: index of the fragment that contains the channel
    '''
    # Running totals give the exclusive upper channel bound of each fragment
    boundaries = []
    running_total = 0
    for size in layer_channels:
        running_total += size
        boundaries.append(running_total)

    # side='right' sends a channel equal to a bound into the next fragment
    return np.searchsorted(boundaries, channel, side='right')

In [None]:
def avg_num_of_prevs_for_a_channel(layer, Is_class):
    '''
    Average number of previous channels connected to a channel of a layer,
    counting only the channels that fall in branch 0 or 3
    * input
        - layer: layer
        - Is_class: I-matrices for a class
    * output
        - num_avg: the average number of connections, truncated to int
    * note
        - reads the notebook-level layer_fragment_sizes dictionary
    '''
    edge_counts = [
        len(prev_inf_dict)
        for channel, prev_inf_dict in enumerate(Is_class[layer])
        if get_branch(layer, channel, layer_fragment_sizes[layer]) in [0, 3]
    ]
    return int(np.average(edge_counts))

## Functions to generate the graph

In [None]:
def get_node_name(layer, channel):
    '''
    Build the graph node name '<layer>-<channel>'
    * input
        - layer: layer name (string)
        - channel: channel index or id
    * output
        - node name
    '''
    return '-'.join([layer, str(channel)])

In [None]:
def gen_full_graph(Is_class, G, mixed_layers, outlier_nodes):
    '''
    Add the influence edges of one class into G, then drop outlier nodes
    * input
        - Is_class: I-matrices for a class (layer and '<layer>_<branch>' keys)
        - G: graph to fill in place (uses add_edge and remove_nodes_from)
        - mixed_layers: list of mixed layers, previous layers first
        - outlier_nodes: node names to remove once all edges are added
    * output
        - None; G is modified in place
    * note
        - reads the notebook-level layer_fragment_sizes dictionary
    '''

    # Walk from the last layer backwards; the first layer has no previous
    # layer, so it never acts as an edge source
    reversed_layers = mixed_layers[::-1]

    # Add edges into G from Is_class
    for layer_idx, layer in enumerate(reversed_layers[:-1]):
        # Get previous layer
        prev_layer = reversed_layers[layer_idx + 1]

        # Get the average number of edges for a channel
        avg_num_edges = avg_num_of_prevs_for_a_channel(layer, Is_class)

        # For all channels in layer
        for channel, prev_inf_dict in enumerate(Is_class[layer]):
            # Get source node
            src = get_node_name(layer, channel)

            # Get branch
            branch = get_branch(layer, channel, layer_fragment_sizes[layer])

            # If the channel is connected to a branch, its influences point
            # at the inner layer, so hop through '<layer>_<branch>' to reach
            # the channels of the previous layer
            if branch in [1, 2]:
                branch_layer = '{}_{}'.format(layer, branch)

                # Best weight found so far for each prev_prev channel; the
                # weight of a two-hop path is the smaller of its two influences
                channel_edges = {}
                for prev_channel in prev_inf_dict:
                    prev_inf = prev_inf_dict[prev_channel]

                    # Extract influence information for prev_channel
                    prev_prev_inf_dict = Is_class[branch_layer][int(prev_channel)]

                    for prev_prev_channel in prev_prev_inf_dict:
                        weight = min(prev_inf, prev_prev_inf_dict[prev_prev_channel])
                        if (prev_prev_channel not in channel_edges
                                or weight > channel_edges[prev_prev_channel]):
                            channel_edges[prev_prev_channel] = weight

                # Get top (avg_num_edges) prev_prev_channels based on the edge weight
                top_prev_prevs_weights = sorted(channel_edges.items(), key=operator.itemgetter(1), reverse=True)
                top_prev_prevs_weights = top_prev_prevs_weights[:avg_num_edges]

                # Add edges from channel and top_prev_prev_channel
                for prev_prev_channel, weight in top_prev_prevs_weights:
                    tgt = get_node_name(prev_layer, prev_prev_channel)
                    G.add_edge(src, tgt, weight=weight)

            # If the channel is directly connected to the previous layer
            elif branch in [0, 3]:
                for prev_channel in prev_inf_dict:
                    # Add edge of (src, tgt-prev)
                    prev_inf = prev_inf_dict[prev_channel]
                    tgt = get_node_name(prev_layer, prev_channel)
                    G.add_edge(src, tgt, weight=prev_inf)

    # Remove outlier nodes and their edges from the full graph.
    # remove_nodes_from ignores nodes that never made it into the graph,
    # whereas remove_node would raise for them.
    G.remove_nodes_from(outlier_nodes)


In [None]:
def init_dag(mixed_layers):
    '''
    Create an empty dag dictionary
    * input
        - mixed_layers: list of layers
    * output
        - dag: dictionary with one empty list per layer, keys inserted in
          reverse order of mixed_layers
    '''
    return {layer: [] for layer in mixed_layers[::-1]}

## Functions for pagerank

In [None]:
def get_personalization_dict(G, As, mixed_layers, pred_class, outlier_nodes):
    '''
    Get personalization dictionary
    * input
        - G: graph
        - As: A matrices
        - mixed_layers: layers starting with 'mixed'
        - pred_class: class whose row of each A matrix is used
        - outlier_nodes: nodes ignored when looking for the layer maximum
    * output
        - personalization: dictionary mapping every node of G to a weight;
          nodes of the given layers get their A value divided by the largest
          non-outlier A value of the layer, the rest keep weight 1
    '''

    # Every node starts at 1
    personalization = dict.fromkeys(G.nodes, 1)

    for layer in mixed_layers[::-1]:
        A = As[layer][pred_class]

        # Largest value among non-outlier channels (-100 is the floor)
        candidates = [
            val
            for channel, val in enumerate(A)
            if get_node_name(layer, channel) not in outlier_nodes
        ]
        max_a = max([-100] + candidates)

        # Only nodes that exist in the graph receive a normalized weight
        for channel in range(A.shape[-1]):
            node = layer + '-' + str(channel)
            if node not in personalization:
                continue
            personalization[get_node_name(layer, channel)] = A[channel] / max_a

    return personalization

## Functions for thresholding nodes, edges

In [None]:
def get_prob_mass(prob_mass_threshold, reverse_sorted_vals):
    '''
    Find how many of the largest values are needed to reach a probability mass
    * input
        - prob_mass_threshold: fraction of the total sum to accumulate
        - reverse_sorted_vals: values sorted in descending order
    * output
        - threshold_cnt: number of leading values that were accumulated
        - threshold_val: the first value that was not accumulated, or -inf
          when every value was accumulated (so that a strict `>` comparison
          against it keeps all values)
    '''

    num_vals = len(reverse_sorted_vals)
    sum_val = np.sum(reverse_sorted_vals)

    prob_mass = 0
    threshold_cnt = 0

    # Also stop at the end of the list: a threshold at or above the total
    # mass (or float rounding just below it) must not index past the end
    while threshold_cnt < num_vals and prob_mass < prob_mass_threshold:
        prob_mass += reverse_sorted_vals[threshold_cnt] / sum_val
        threshold_cnt += 1

    if threshold_cnt == num_vals:
        # Nothing is left over to serve as the cut-off value
        return threshold_cnt, float('-inf')

    threshold_val = reverse_sorted_vals[threshold_cnt]
    return threshold_cnt, threshold_val

In [None]:
def get_threshold(mixed_layers, pagerank, prob_mass_thresholds, unified_threshold=True):
    '''
    Get threshold dictionary
    * input
        - mixed_layers: layers starting with mixed
        - pagerank: pagerank dictionary
        - prob_mass_thresholds: probability mass thresholds for each layer
        - unified_threshold: whether the same threshold is used for all layers
    * output
        - thresholds: threshold dictionary, whose
            - key: layer
            - val: threshold criteria value
    '''

    # The unified mode uses the first probability mass in the dictionary;
    # it is looked up in both modes
    first_prob_mass = list(prob_mass_thresholds.values())[0]

    if not unified_threshold:
        # One threshold per layer, computed from that layer's nodes only
        thresholds = {}
        for layer in mixed_layers[::-1]:
            layer_scores = {}
            for node in pagerank:
                if node.split('-')[0] == layer:
                    layer_scores[node] = pagerank[node]
            thresholds[layer] = get_threshold_val(
                mixed_layers,
                layer_scores,
                prob_mass_threshold=prob_mass_thresholds[layer]
            )
        return thresholds

    # One threshold computed over all nodes and shared by every layer
    shared_val = get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=first_prob_mass)
    return {layer: shared_val for layer in mixed_layers[::-1]}

In [None]:
def get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=0.12):
    '''
    Get the pagerank value used as threshold for a probability mass
    * input
        - mixed_layers: list of layers (not used in the computation)
        - pagerank: pagerank dictionary
        - prob_mass_threshold: probability mass handed to get_prob_mass
    * output
        - threshold_val: second item returned by get_prob_mass
    '''
    descending_scores = sorted(list(pagerank.values()), reverse=True)
    _, threshold_val = get_prob_mass(prob_mass_threshold, descending_scores)
    return threshold_val

In [None]:
def get_thresholded_nodes(pagerank, thresholds, outlier_nodes):
    '''
    Keep the nodes whose pagerank is strictly above the layer threshold
    * input
        - pagerank: pagerank dictionary keyed by '<layer>-<channel>'
        - thresholds: threshold value for each layer
        - outlier_nodes: nodes that are never kept
    * output
        - thresholded_nodes: dictionary of kept nodes and their pagerank
    '''
    kept = {}
    for node, score in pagerank.items():
        if node not in outlier_nodes:
            layer, _ = node.split('-')
            if score > thresholds[layer]:
                kept[node] = score
    return kept

In [None]:
def get_thresholded_edges(mixed_layers, G, thresholded_nodes):
    '''
    Get thresholded edges
    * input
        - mixed_layers: all mixed layers
        - G: graph
        - thresholded_nodes: nodes whose pagerank value is higher than threshold
    * output
        - thresholded_edges: edges whose two end nodes are both thresholded,
          as thresholded_edges[layer][channel] = list of
          {'prev_channel': ..., 'inf': edge weight}
    '''

    edges_by_layer = {layer: {} for layer in mixed_layers[::-1]}

    # Node pairs already recorded, so an edge reached from both of its
    # ends is stored only once
    seen_pairs = set()

    for node in thresholded_nodes:
        for edge in G.edges(node):
            node1, node2 = edge
            if not (node1 in thresholded_nodes and node2 in thresholded_nodes):
                continue

            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)
            channel_edges = edges_by_layer[layer].setdefault(channel, [])

            if (node1, node2) in seen_pairs or (node2, node1) in seen_pairs:
                continue
            seen_pairs.add((node1, node2))

            channel_edges.append({
                'prev_channel': int(prev_channel),
                'inf': G.get_edge_data(*edge)['weight']
            })

    return edges_by_layer

In [None]:
def get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer):
    '''
    Get the channels of the previous layer that may be linked to
    * input
        - thresholded_edges: thresholded edges for each layer
        - thresholded_nodes: thresholded nodes
        - prev_layer: name of the previous layer
    * output
        - list of channels: for 'mixed3a', the channel part of every
          thresholded node whose name contains 'mixed3a'; for other layers,
          the channels that have an entry in thresholded_edges[prev_layer]
    '''
    if prev_layer != 'mixed3a':
        return list(thresholded_edges[prev_layer].keys())

    channels = []
    for node in thresholded_nodes:
        if 'mixed3a' in node:
            channels.append(node.split('-')[1])
    return channels

## Functions to generate DAG

In [None]:
def parse_edge(edge, mixed_layers):
    '''
    Split an edge into its current-layer end and its previous-layer end
    * input
        - edge: pair of node names of the form '<layer>-<channel>'
        - mixed_layers: list of layers, previous layers first
    * output
        - layer, channel: the end located later in mixed_layers
        - prev_layer, prev_channel: the other end
    '''
    first_layer, first_channel = edge[0].split('-')
    second_layer, second_channel = edge[1].split('-')

    first_idx = mixed_layers.index(first_layer)
    second_idx = mixed_layers.index(second_layer)

    # If the first node is not strictly later, the second node is treated
    # as the current layer and the first node as the previous layer
    if not first_idx > second_idx:
        return second_layer, second_channel, first_layer, first_channel

    # Otherwise the first node is the current layer
    return first_layer, first_channel, second_layer, second_channel

In [None]:
def gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, As, Is, pred_class, graph=None):
    '''
    Generate the attribution graph (dag) of a class in a json-friendly form
    * input
        - mixed_layers: all mixed layers
        - thresholded_nodes: nodes whose pagerank value is higher than threshold
        - thresholded_edges: edges connected by both thresholded nodes
        - pagerank: pagerank dictionary
        - As: A matrices
        - Is: I-matrices for all class
        - pred_class: predicted class
        - graph: graph whose edges are scanned; when omitted, the
          notebook-level graph G is used (the previous implicit behavior)
    * output
        - dag: dictionary mapping each layer to a list of node dictionaries
          with keys 'channel', 'count', 'layer', 'pagerank',
          'prev_channels' and 'attr_channels'
        - layer_validity: for every layer except 'mixed3a', whether at least
          one of its added channels has an entry in thresholded_edges
    '''

    if graph is None:
        # Backward compatibility: existing callers rely on the global graph
        graph = G

    # Initialize dag, check_channel, layer_validity
    dag = {}
    check_channel = {}
    layer_validity = {}
    for layer in mixed_layers[::-1]:
        dag[layer] = []
        check_channel[layer] = set()
        if layer != 'mixed3a':
            layer_validity[layer] = False

    # Mixed3a: its nodes are added without previous or attributed channels
    layer = 'mixed3a'
    for channel, cnt in enumerate(As['mixed3a'][pred_class]):
        node_name = get_node_name(layer, channel)
        if node_name in thresholded_nodes:
            dag[layer].append({
                'channel': int(channel),
                'count': int(cnt),
                'layer': layer,
                'pagerank': pagerank[node_name],
                'prev_channels': [],
                'attr_channels': []
            })

    # Other layers
    for node in thresholded_nodes:
        for edge in graph.edges(node):

            # Parse the edge
            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)
            curr_node = get_node_name(layer, channel)

            # Ignore unnecessary cases: the current end is not thresholded,
            # or the channel has already been added to the dag
            if curr_node not in thresholded_nodes:
                continue
            if channel in check_channel[layer]:
                continue
            check_channel[layer].add(channel)

            # Read A matrices
            A = As[layer]

            # Get attributed previous channels (the three largest influences)
            top_attr = sorted(
                Is[layer][pred_class][int(channel)].items(),
                key=lambda x: x[1],
                reverse=True
            )[0:3]
            attr_channels = [{'prev_channel': prev[0], 'inf': prev[1]} for prev in top_attr]

            # Previous channels stay empty unless the channel is connected
            # to thresholded nodes in the previous layer
            prev_channels = []
            if channel in thresholded_edges[layer]:
                # Mark the layer is valid
                layer_validity[layer] = True

                # Keep only the edges that end in a valid previous channel
                valid_prev_nodes = get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer)
                prev_channels = [
                    prev_edge
                    for prev_edge in thresholded_edges[layer][channel]
                    if str(prev_edge['prev_channel']) in valid_prev_nodes
                ]

            # Add node into dag
            dag[layer].append({
                'channel': int(channel),
                'count': int(A[pred_class][int(channel)]),
                'layer': layer,
                'pagerank': pagerank[curr_node],
                'prev_channels': prev_channels,
                'attr_channels': attr_channels
            })

    return dag, layer_validity

## Get inception_v1 model information

In [None]:
# Root data directory; pick the line matching the machine in use
# data_dirpath = '/Users/haekyu/data/summit/'
# data_dirpath = '/home/fred/code/summit-notebooks/data/'
data_dirpath = '/Users/fredhohman/Github/summit-notebooks/data/'
# The imagenet json file is read straight from the root data directory
imgnet_dirpath = data_dirpath
# Input directories of the I-matrices and A-matrices
I_mat_dirpath = data_dirpath + 'I-matrices/'
A_mat_dirpath = data_dirpath + 'A-matrices/'
# Output directory of the generated attribution graphs
dag_dirpath = data_dirpath + 'ag/'

In [None]:
# Load the InceptionV1 graph definition and keep its nodes
googlenet = models.InceptionV1()
googlenet.load_graphdef()
nodes = googlenet.graph_def.node

In [None]:
# All layer names found in the graph nodes
all_layers = get_layers(nodes)
# Subset of layers whose name contains 'mixed'
mixed_layers = [layer for layer in all_layers if 'mixed' in layer]
# Fragment (branch) sizes of every mixed layer; read as a global by
# avg_num_of_prevs_for_a_channel and gen_full_graph
layer_fragment_sizes = {layer: get_channel_sizes(layer, nodes) for layer in mixed_layers}

In [None]:
# Threshold value that is part of the imagenet and A-matrix file names
A_prob_threshold = 0.03
# Load the imagenet json file that matches the threshold
with open(imgnet_dirpath + 'imagenet-' + str(A_prob_threshold) + '.json') as f:
    imgnet = json.load(f)

## Run for all classes

In [None]:
# Layer range to process: start is towards the output, end towards the input
start_layer = 'mixed5b'
end_layer = 'mixed3a'

In [None]:
# Load the I-matrices of every layer in the range once, for all classes
Is = load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True)

In [None]:
# Nodes removed from the full graph and skipped when thresholding
outlier_nodes = ['mixed3a-67', 'mixed3a-190', 'mixed3b-390', 'mixed3b-399', 'mixed3b-412']

In [None]:
# Number of classes to loop over, and run options
num_class = 1000
prob_mass_dict = {}
unified_threshold = True
plot_pagerank_values = False
# When True, only the few classes listed in the loop below are processed
test_class_for_debugging = True

# start at 7.5%, increase by 0.5% if no connections/channels in a layer
prob_mass_initial = 0.075
prob_mass_increase = 0.005

# The A matrices do not depend on the class, so read them once instead of
# reloading every csv file on each iteration
As = read_As(A_mat_dirpath, mixed_layers, A_prob_threshold)

for pred_class in range(num_class):
    prob_mass_thresholds = {layer: prob_mass_initial for layer in mixed_layers}

    if test_class_for_debugging:
        if pred_class not in [55,270,294]:
            continue

    tic = time()

    # Extract influence information for the pred_class
    Is_class = extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=False)

    # Initialize an undirected graph (gen_dag reads this global G)
    G = nx.Graph()

    # Generate full graph
    gen_full_graph(Is_class, G, mixed_layers, outlier_nodes)

    # Personalized pagerank to filter nodes
    personalization = get_personalization_dict(G, As, mixed_layers, pred_class, outlier_nodes)
    pagerank = nx.pagerank(G, personalization=personalization, weight='weight', alpha=0.85)

    if plot_pagerank_values:

        # Histogram of the pagerank values of each layer
        for layer in mixed_layers:
            print(layer)

            pagerank_values = []
            for node in pagerank:
                if node.split('-')[0] == layer:
                    pagerank_values.append(pagerank[node])

            print(len(pagerank_values))
            plt.figure(figsize=(12,3))
            plt.hist(pagerank_values, bins=100)
            plt.show()

    # Relax the probability mass until every layer has a connected channel
    need_to_relax = True

    while need_to_relax:

        # Thresholding
        thresholds = get_threshold(mixed_layers, pagerank, prob_mass_thresholds=prob_mass_thresholds, unified_threshold=unified_threshold)
        thresholded_nodes = get_thresholded_nodes(pagerank, thresholds, outlier_nodes=outlier_nodes)
        thresholded_edges = get_thresholded_edges(mixed_layers, G, thresholded_nodes)

        # Generate dag in json format
        dag, layer_validity = gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, As, Is, pred_class)

        need_to_relax = False in layer_validity.values()
        if need_to_relax:
            if unified_threshold:
                for layer in mixed_layers:
                    prob_mass_thresholds[layer] += prob_mass_increase
            else:
                # Relax both the invalid layer and the layer right before it
                for layer_idx, layer in enumerate(mixed_layers[1:]):
                    if not layer_validity[layer]:
                        prev_layer = mixed_layers[layer_idx]
                        prob_mass_thresholds[prev_layer] += prob_mass_increase
                        prob_mass_thresholds[layer] += prob_mass_increase

    # Save prob_mass_threshold
    prob_mass_dict[pred_class] = prob_mass_thresholds

    # Save the graph into a file
    filename = dag_dirpath + 'ag-{}.json'.format(pred_class)
    with open(filename, 'w') as f:
        json.dump(dag, f, indent=2)

    toc = time()

    print('class: %s, time: %.2lf sec' % (pred_class, toc - tic))

# Save the final probability mass thresholds of every processed class
filename = dag_dirpath + 'pagerank/' + 'prob-mass-threshold-{}.json'.format('unified' if unified_threshold else 'separate')
with open(filename, 'w') as f:
    json.dump(prob_mass_dict, f)

In [None]:
from collections import defaultdict

# Smallest and largest pagerank value per layer, for the pagerank left over
# from the last processed class; outlier nodes are skipped
pg_dict = defaultdict(lambda: {'min': 100, 'max': -100})
for node, val in pagerank.items():
    if node not in outlier_nodes:
        layer, channel = node.split('-')
        bounds = pg_dict[layer]
        pg_dict[layer] = {
            'min': min(bounds['min'], val),
            'max': max(bounds['max'], val)
        }

In [None]:
# Print the pagerank range of every layer
for layer, bounds in pg_dict.items():
    a = bounds['min']
    A = bounds['max']
    print('%s, min:%.4lf, max:%.4lf' % (layer, a, A))