# Generate DAG

## Import packages

In [2]:
import json
import operator
import numpy as np
import networkx as nx
from time import time
import lucid.modelzoo.vision_models as models

## Functions to get general information of inception_v1 model

In [3]:
def get_layers(graph_nodes):
    '''
    Get all layers
    * input
        - graph_nodes: tensorflow graph nodes
    * output
        - layers: list of all layers
    '''
    layers = []
    for n in graph_nodes:
        node_name = n.name
        if node_name[-2:] == '_w':
            layer = node_name.split('_')[0]
            if layer not in layers:
                layers.append(layer)
    return layers

In [4]:
def get_channel_sizes(layer, weight_nodes):
    '''
    Get channel sizes
    * input
        - layer: the name of layer
        - weight_nodes: tensorflow nodes for all filters
    * output
        - channel_sizes: list of channel size for all pre-concatenated blocks
    '''
    
    channel_sizes = [get_shape_of_node(n)[0] for n in weight_nodes if layer in n.name and '_b' == n.name[-2:] and 'bottleneck' not in n.name]
    return channel_sizes

In [5]:
def get_shape_of_node(n):
    '''
    Get the shape of the tensorflow node
    * input
        - n: tensorflow node
    * output
        - tensor_shape: shape of n
    '''
    dims = n.attr['value'].tensor.tensor_shape.dim
    tensor_shape = [d.size for d in dims]
    return tensor_shape

In [6]:
def get_prev_layer(layer, mixed_layers):
    layer_idx = mixed_layers.index(layer)
    return mixed_layers[layer_idx - 1]

## Functions to extract influences from I-matrices for a class

In [7]:
def extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=True):
    '''
    Extract influences for a class from I-matrices
    * input
        - Is: I-matrices for all class
        - all_layers: list of all layers
        - start_layer: start layer (towards output)
        - end_layer: end layer (towards input)
        - pred_class: predicted class
    * output
        - Is_class: I-matrices for a class. a dictionary, where
            - key: layer name (e.g. 'mixed4d', 'mixed4d_1')
            - val: influences for given layer(key), class(argument of the function)
    '''

    # Get layers starting from the given layer to the input layer
    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)
    target_layers = all_layers[start_idx: end_idx - 1: -1]
    
    Is_class = {}
    
    for layer in target_layers:
        if verbose:
            print('\n({}) loading {}'.format(pred_class, layer), end='')
        Is_class[layer] = Is[layer][pred_class]
        for branch in [1, 2]:
            inner_layer = '{}_{}'.format(layer, branch)
            if verbose:
                print(',', inner_layer, end='')
            Is_class[inner_layer] = Is[inner_layer][pred_class]
    if verbose:
        print('\n')
    
    return Is_class

In [8]:
def load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True):
    '''
    Load I-matrices for all layers
    * input
        - all_layers: list of all layers
        - start_layer: start layer (towards output)
        - end_layer: end layer (towards input)
        - I_mat_dirpath: directory path of I-matrices
    * output
        - Is: I-matrices for all class
    '''
    
    # Get layers starting from the given layer to the input layer
    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)
    target_layers = all_layers[start_idx: end_idx - 1: -1]

    # Load I matrices
    Is = {}
    for layer in target_layers:
        if verbose:
            print('\n(all) loading {}'.format(layer), end='')
        Is[layer] = load_inf_matrix(I_mat_dirpath, layer)
        for branch in [1, 2]:
            inner_layer = '{}_{}'.format(layer, branch)
            if verbose:
                print(',', inner_layer, end='')
            Is[inner_layer] = load_inf_matrix(I_mat_dirpath, inner_layer)
    if verbose:
        print('\n')    
    return Is

In [9]:
def load_inf_matrix(I_mat_dirpath, layer):
    '''
    Load I matrix for a layer
    * input
        - mat_dirpath: directory path of I-matrices
        - layer: layer name
    * output
        - I_mat: I-matrix of the given layer
    '''
    if I_mat_dirpath[-1] == '/':
        filepath = I_mat_dirpath + 'I_' + layer + '.json'
    else:
        filepath = I_mat_dirpath + '/I_' + layer + '.json'
        
    with open(filepath) as f:
        I_mat = json.load(f)
    
    return I_mat

## Functions to extract M-matrices information

In [10]:
def read_M(M_mat_dirpath, layer):
    M = np.loadtxt(M_mat_dirpath + 'M-' + layer + '.csv', delimiter=',', dtype=int)
    return M

In [11]:
def read_Ms(M_mat_dirpath, mixed_layers):
    Ms = {}
    for layer in mixed_layers:
        M = read_M(M_mat_dirpath, layer)
        Ms[layer] = M
    return Ms

## Functions to query the influence values

In [12]:
def get_branch(layer, channel, layer_channels):
    '''
    Get branch of the channel in the layer
    * input
        - layer: the name of layer
        - channel: channel in the layer
        - layer_channels: fragment sizes of the layer
    * output
        - branch: branch of the channel
    '''
    
    channels = layer_channels[:]
    for i in range(len(channels) - 1):
        channels[i + 1] += channels[i]
        
    branch = np.searchsorted(channels, channel, side='right')
    
    return branch

In [13]:
def avg_num_of_prevs_for_a_channel(layer, Is_class):
    '''
    Get the average number of previous channels connected to a channel in the given layer
    * input
        - layer: layer
        - Is_class: I-matrices for a class
    * output
        - num_avg: the average number of connections for a channel in the given layer
    '''
    
    num_of_channel_edges = []
    
    for channel, prev_inf_dict in enumerate(Is_class[layer]):
        # Get branch
        branch = get_branch(layer, channel, layer_fragment_sizes[layer])
        
        if branch in [0, 3]:
            num_of_channel_edges.append(len(prev_inf_dict))
            
    num_avg = int(np.average(num_of_channel_edges))
    
    return num_avg

## Functions to generate the graph

In [14]:
def get_node_name(layer, channel):
    return layer + '-' + str(channel)

In [15]:
def gen_full_graph(Is_class, G, mixed_layers):
    
    # Add edges into G from Is_class
    for layer_idx, layer in enumerate(mixed_layers[::-1][:-1]):
        # Get previous layer
        prev_layer = mixed_layers[::-1][layer_idx + 1]
        
        # Get the average number of edges for a channel
        avg_num_edges = avg_num_of_prevs_for_a_channel(layer, Is_class)
        
        # For all channels in layer
        for channel, prev_inf_dict in enumerate(Is_class[layer]):
            # Get source node
            src = get_node_name(layer, channel)
            
            # Get branch
            branch = get_branch(layer, channel, layer_fragment_sizes[layer])
            
            # If the channel is connected to a branch
            if branch in [1, 2]:
                # Get possible edge weights for the channel
                channel_edges = {}
                for prev_channel in prev_inf_dict:
                    prev_inf = prev_inf_dict[prev_channel]
                    
                    # Extract influence information for prev_channel
                    branch_layer = '{}_{}'.format(layer, branch)
                    prev_prev_inf_dict = Is_class[branch_layer][int(prev_channel)]
                    
                    for prev_prev_channel in prev_prev_inf_dict:
                        prev_prev_inf = prev_prev_inf_dict[prev_prev_channel]
                        if prev_prev_channel not in channel_edges:
                            channel_edges[prev_prev_channel] = []
                        channel_edges[prev_prev_channel].append(min(prev_inf, prev_prev_inf))
                        
                # Get only one weight for each channel and prev_prev channel
                for prev_prev_channel in channel_edges:
                    channel_edges[prev_prev_channel] = max(channel_edges[prev_prev_channel])
                
                # Get top (avg_num_edges) prev_prev_channels based on the edge weight
                top_prev_prevs_weights = sorted(channel_edges.items(), key=operator.itemgetter(1), reverse=True)
                top_prev_prevs_weights = top_prev_prevs_weights[:avg_num_edges]
                
                # Add edges from channel and top_prev_prev_channel
                for prev_prev_channel, weight in top_prev_prevs_weights:
                    tgt = get_node_name(prev_layer, prev_prev_channel)
                    G.add_edge(src, tgt, weight=weight)
            
            # If the channel is directly connected to the previous layer
            elif branch in [0, 3]:
                for prev_channel in prev_inf_dict:
                    # Add edge of (src, tgt-prev)
                    prev_inf = prev_inf_dict[prev_channel]
                    tgt = get_node_name(prev_layer, prev_channel)
                    G.add_edge(src, tgt, weight=prev_inf)

In [16]:
def init_dag(mixed_layers):
    dag = {}
    for layer in mixed_layers[::-1]:
        dag[layer] = []
    return dag

## Functions for pagerank

In [17]:
def get_personalization_dict(G, Ms, mixed_layers, pred_class):
    '''
    Get personalization dictionary
    * input
        - G: graph
        - Ms: M matrices
        - mixed_layers: layers starting with 'mixed'
    '''
    
    personalization = {node: 1 for node in list(G.nodes)}

    for layer in mixed_layers[::-1]:
        M = Ms[layer][pred_class]
        max_m = max(M)
        for channel in range(M.shape[-1]):
            node = layer + '-' + str(channel)
            if node in personalization:
                personalization[get_node_name(layer, channel)] = M[channel] / max_m
    
    return personalization

## Functions for thresholding nodes, edges

In [18]:
def get_prob_mass(prob_mass_threshold, reverse_sorted_vals):
    prob_mass = 0
    threshold_cnt = 0
    while prob_mass < prob_mass_threshold:
        prob_mass += reverse_sorted_vals[threshold_cnt]
        threshold_cnt += 1
    threshold_val = reverse_sorted_vals[threshold_cnt]
    return threshold_cnt, threshold_val

In [19]:
def get_threshold(mixed_layers, pagerank, prob_mass_threshold=0.12):
    # Get threshold value
    pagerank_values = list(pagerank.values())
    sorted_pagerank_vals = sorted(pagerank_values, reverse=True)
    threshold_cnt, threshold_val = get_prob_mass(prob_mass_threshold, sorted_pagerank_vals)
    
    # Get thresholds
    thresholds = {}
    pagerank_sorted = sorted(pagerank.items(), key=operator.itemgetter(1))
    for layer in mixed_layers[::-1]:
        pageranks_layer = list(filter(lambda x: layer in x[0],  pagerank_sorted))
        pagerank_values_layer = list(map(lambda x: x[1], pageranks_layer))
        threshold = len(pagerank_values_layer) - np.searchsorted(np.array(pagerank_values_layer), threshold_val)
        thresholds[layer] = max(min(threshold, len(pagerank_values_layer) - 1), 0)

    return thresholds

In [20]:
def get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=0.12):
    pagerank_values = list(pagerank.values())
    sorted_pagerank_vals = sorted(pagerank_values, reverse=True)
    threshold_cnt, threshold_val = get_prob_mass(prob_mass_threshold, sorted_pagerank_vals)
    
    return threshold_val

In [21]:
def get_thresholded_nodes(pagerank, threshold_val):
    thresholded_nodes = {}
    
    for node in pagerank:
        if pagerank[node] > threshold_val:
            thresholded_nodes[node] = pagerank[node]
        
    return thresholded_nodes

In [22]:
def get_thresholded_edges(mixed_layers, G, thresholded_nodes):
    '''
    Get thresholded edges
    * input
        - mixed_layers: all mixed layers
        - G: graph
        - thresholded_nodes: nodes whose pagerank value is higher than threshold
    * output
        - thresholded_edges: edges connected by both thresholded nodes
    '''
    
    # Initialize thresholded_edges
    thresholded_edges = {}
    edge_checker = set()
    for layer in mixed_layers[::-1]:
        thresholded_edges[layer] = {}

    for node in thresholded_nodes:
        for edge in G.edges(node):
            node1, node2 = edge
            if (node1 not in thresholded_nodes) or (node2 not in thresholded_nodes):
                continue

            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)
            if channel not in thresholded_edges[layer]:
                thresholded_edges[layer][channel] = []

            if (node1, node2) in edge_checker:
                continue
            elif (node2, node1) in edge_checker:
                continue
            edge_checker.add((node1, node2))

            thresholded_edges[layer][channel].append({
                'prev_channel': int(prev_channel),
                'inf': G.get_edge_data(*edge)['weight']
            })

    return thresholded_edges

In [37]:
def get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer):
    if prev_layer == 'mixed3a':
        mixed3a_nodes = [x.split('-')[1] for x in thresholded_nodes if 'mixed3a' in x]
        return mixed3a_nodes
    else:
        return list(thresholded_edges[prev_layer].keys())

## Functions to generate DAG

In [24]:
def parse_edge(edge, mixed_layers):
    n1_layer, n1_channel = edge[0].split('-')
    n2_layer, n2_channel = edge[1].split('-')
    
    n1_idx, n2_idx = mixed_layers.index(n1_layer), mixed_layers.index(n2_layer)
    
    # If n1 is current layer, n2 is previous layer
    if n1_idx > n2_idx:
        layer, channel = n1_layer, n1_channel
        prev_layer, prev_channel = n2_layer, n2_channel
        
    # If n1 is previous layer, n1 is current layer
    else:
        layer, channel = n2_layer, n2_channel
        prev_layer, prev_channel = n1_layer, n1_channel
    
    return layer, channel, prev_layer, prev_channel

In [43]:
def gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, Is, pred_class):

    # Initialize dag, check_channel, layer_validity
    dag = {}
    check_channel = {}
    layer_validity = {}
    for layer in mixed_layers[::-1]:
        dag[layer] = []
        check_channel[layer] = set()
        if layer != 'mixed3a':
            layer_validity[layer] = False
    
    # Mixed3a
    layer = 'mixed3a'
    for channel, cnt in enumerate(Ms['mixed3a'][pred_class]):
        node_name = get_node_name(layer, channel)
        if node_name in thresholded_nodes:
            dag[layer].append({
                'channel': int(channel),
                'count': int(cnt),
                'layer': layer,
                'pagerank': pagerank[node_name],
                'prev_channels': [],
                'attr_channels': []
            })
            
    # Other layers
    for node in thresholded_nodes:
        for edge in G.edges(node):
            
            # Parse the edge
            node1, node2 = edge
            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)
            
            curr_node = get_node_name(layer, channel)
            prev_node = get_node_name(prev_layer, prev_channel)
            
            # Ignore unnecessary cases
            if curr_node not in thresholded_nodes:
                continue
            if channel in check_channel[layer]:
                continue
            check_channel[layer].add(channel)
            
            # Read M matrices
            M = Ms[layer]

            # Get attributed previous channels
            attr_channels_dict = Is[layer][pred_class][int(channel)]
            attr_channels_dict = sorted(attr_channels_dict.items(), key=lambda x:x[1], reverse=True)[0:3]
            attr_channels = [{'prev_channel': prev[0], 'inf': prev[1]} for prev in attr_channels_dict]

            # If the channel is connected to thresholded nodes in previous layer
            if channel in thresholded_edges[layer]:
                # Mark the layer is valid
                layer_validity[layer] = True
            
                # Get previous channels    
                valid_prev_nodes = get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer)
                filtered_prev_channels = list(filter(lambda x: str(x['prev_channel']) in valid_prev_nodes, thresholded_edges[layer][channel]))

                # Add node into dag
                dag[layer].append({
                    'channel': int(channel),
                    'count': int(M[pred_class][int(channel)]),
                    'layer': layer,
                    'pagerank': pagerank[curr_node],
                    'prev_channels': filtered_prev_channels,
                    'attr_channels': attr_channels
                })
            
            # If the channel is not connected to thresholded nodes in previous layer
            else:
                # Add node into dag
                dag[layer].append({
                    'channel': int(channel),
                    'count': int(M[pred_class][int(channel)]),
                    'layer': layer,
                    'pagerank': pagerank[curr_node],
                    'prev_channels': [],
                    'attr_channels': attr_channels
                })

    if False in layer_validity.values():
        need_to_relax = True
        return None, need_to_relax
    else:
        need_to_relax = False
        return dag, need_to_relax

## Get inception_v1 model infromation

In [26]:
# data_dirpath = '/Users/haekyu/data/summit/'
data_dirpath = '/home/fred/code/summit-notebooks/data/'
imgnet_dirpath = data_dirpath
I_mat_dirpath = data_dirpath + 'I-matrices/'
M_mat_dirpath = data_dirpath + 'M-matrices/'
dag_dirpath = data_dirpath + 'dag/'

In [27]:
googlenet = models.InceptionV1()
googlenet.load_graphdef()
nodes = googlenet.graph_def.node

In [28]:
all_layers = get_layers(nodes)
mixed_layers = [layer for layer in all_layers if 'mixed' in layer]
layer_fragment_sizes = {layer: get_channel_sizes(layer, nodes) for layer in mixed_layers}

In [29]:
with open(imgnet_dirpath + 'imagenet.json') as f:
    imgnet = json.load(f)

## Run for all classes

In [30]:
start_layer = 'mixed5b'
end_layer = 'mixed3a'

In [31]:
Is = load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True)


(all) loading mixed5b, mixed5b_1, mixed5b_2
(all) loading mixed5a, mixed5a_1, mixed5a_2
(all) loading mixed4e, mixed4e_1, mixed4e_2
(all) loading mixed4d, mixed4d_1, mixed4d_2
(all) loading mixed4c, mixed4c_1, mixed4c_2
(all) loading mixed4b, mixed4b_1, mixed4b_2
(all) loading mixed4a, mixed4a_1, mixed4a_2
(all) loading mixed3b, mixed3b_1, mixed3b_2
(all) loading mixed3a, mixed3a_1, mixed3a_2



In [None]:
num_class = 1000
prob_mass_dict = {}

for pred_class in range(num_class):
    prob_mass_threshold = 0.1
    prob_mass_increase = 0.01
    
#     if pred_class not in [0,55,270]:
#         continue
        
    tic = time()
    
    # Extract influence information for the pred_class
    Is_class = extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=False)
    
    # Initialize an undirected graph
    G = nx.Graph()
    
    # Generate full graph
    gen_full_graph(Is_class, G, mixed_layers)
    
    # Read M-matrices
    Ms = read_Ms(M_mat_dirpath, mixed_layers)
    
    # Personalized pagerank to filter nodes
    personalization = get_personalization_dict(G, Ms, mixed_layers, pred_class)
    pagerank = nx.pagerank(G, personalization=personalization, weight='weight', alpha=0.85)
    
    need_to_relax = True

    while need_to_relax:
        
        # Thresolding
        threshold_val = get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=prob_mass_threshold)
        thresholds = get_threshold(mixed_layers, pagerank, prob_mass_threshold=prob_mass_threshold)
        thresholded_nodes = get_thresholded_nodes(pagerank, threshold_val)
        thresholded_edges = get_thresholded_edges(mixed_layers, G, thresholded_nodes)
        print('prob mass threshold:{}, threshold_val:{}'.format(prob_mass_threshold, threshold_val))

        # Generate dag in json format
        dag, need_to_relax = gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, Is, pred_class)
        
        if need_to_relax:
            prob_mass_threshold += prob_mass_increase
    
    # Save prob_mass_threshold
    prob_mass_dict[pred_class] = prob_mass_threshold
    
    # Save the graph into a file
    filename = dag_dirpath + 'pagerank/' + 'dag-{}.json'.format(pred_class)
    with open(filename, 'w') as f:
        json.dump(dag, f, indent=2)
        
    toc = time()
                    
    print('class: %s, time: %.2lf sec' % (pred_class, toc - tic))
    print(nx.info(G))
    
filename = dag_dirpath + 'pagerank/' + 'prob-mass-threshold.json'
with open(filename, 'w') as f:
    json.dump(prob_mass_dict, f, indent=2)

prob mass threshold:0.1, threshold_val:0.002362688496897321
prob mass threshold:0.11, threshold_val:0.0020794432847662924
prob mass threshold:0.12, threshold_val:0.0019657799161241765
prob mass threshold:0.13, threshold_val:0.0018018241405797911
prob mass threshold:0.14, threshold_val:0.0016184143138652995
prob mass threshold:0.15000000000000002, threshold_val:0.0015364145778685894
class: 0, time: 29.69 sec
Name: 
Type: Graph
Number of nodes: 5480
Number of edges: 356579
Average degree: 130.1383
prob mass threshold:0.1, threshold_val:0.0013534423904130919
class: 1, time: 30.89 sec
Name: 
Type: Graph
Number of nodes: 5481
Number of edges: 373696
Average degree: 136.3605
prob mass threshold:0.1, threshold_val:0.0014020358859817533
class: 2, time: 30.68 sec
Name: 
Type: Graph
Number of nodes: 5482
Number of edges: 391430
Average degree: 142.8055
prob mass threshold:0.1, threshold_val:0.0017848757843145517
class: 3, time: 28.99 sec
Name: 
Type: Graph
Number of nodes: 5482
Number of edges: 