In [None]:
from re import compile, search

from matplotlib import pyplot as plt
import networkx as nx

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from mnist import MNIST
from model import PrototypeGraph
from main import load_model

%matplotlib inline 

class Namespace:
    """Minimal attribute bag: every keyword argument becomes an instance attribute.

    Used in this notebook as a stand-in for parsed command-line arguments, so
    settings can be read and written as ``args.name``.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

        
args = Namespace()

# Static evaluation settings.
args.data = './data'
args.merge_targets = None
args.batch_size = 50
args.no_root = False
args.use_representation = False

# Checkpoint to analyse. Earlier runs are kept commented out for quick switching.
# args.resume = 'results/checkpoint/nodes-8_leaves-10_N-10_classifier-50_seed-7557_2021-03-12_223433/best_model.pth'
# args.resume = 'results/checkpoint/nodes-16_leaves-10_N-9_classifier-50_seed-2379_2021-03-12_230814/best_model.pth'
# args.resume = 'results/checkpoint/nodes-16_leaves-10_N-6_classifier-50_seed-2457_2021-03-12_225115/best_model.pth'
# args.resume = 'results_nodes-128_N-32_times-5/checkpoint/nodes-128_leaves-10_N-32_classifier-50_seed-1559_temp-0.25-linear_2021-03-18_202936/best_model.pth'
args.resume = 'results_nodes-128_N-32_times-5/checkpoint/nodes-128_leaves-10_N-32_classifier-784_seed-1238_temp-None-None_2021-03-19_022925/best_model.pth'

# The model hyper-parameters are recovered from the checkpoint path itself.
# A raw string is required here: '\d' and '\w' are not valid escapes in a
# normal string literal and trigger a warning on recent Python versions.
p = compile(r'nodes-(\d+)\w+leaves-(\d+)\w+N-(\d+)\w+classifier-(\d+)')
m = search(p, args.resume)
if m is None:
    # Fail here with a readable message instead of an AttributeError on None below.
    raise ValueError(f'cannot parse model hyper-parameters from checkpoint path: {args.resume!r}')

args.num_nodes = int(m.group(1))
args.num_leaves = int(m.group(2))
args.num_jumps = int(m.group(3))
args.layers_dim = [int(m.group(4))]
print(f'num_nodes = \033[0;1;31m{args.num_nodes}\033[0m')
print(f'num_leaves = \033[0;1;31m{args.num_leaves}\033[0m')
print(f'num_jumps = \033[0;1;31m{args.num_jumps}\033[0m')

device = torch.device('cpu')

# A 5-leaf model is evaluated with consecutive digit pairs sharing one target id.
if args.num_leaves == 5:
    args.merge_targets = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3, 8: 4, 9: 4}

test_loader = DataLoader(MNIST(args.data, train=False, transform=transforms.ToTensor(),
                               merge_targets=args.merge_targets),
                         batch_size=args.batch_size, shuffle=False)

model = PrototypeGraph(args.num_nodes, args.num_leaves, args.num_jumps, args.layers_dim,
                       use_representation=args.use_representation, no_root=args.no_root)
# load_model returns the model together with the epoch stored in the checkpoint.
model, epoch = load_model(model, args.resume, device)
print(f'\033[0;1;33m{epoch=}\033[0m')
model.to(device)

# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.shape)


In [None]:
# Show the classifier width captured (as text) from the checkpoint path.
m[4]

In [None]:
# Column-wise softmax of the two learned matrices, converted to NumPy for plotting.
m0 = torch.softmax(model.M0, dim=0).detach().numpy()
m1 = torch.softmax(model.M1, dim=0).detach().numpy()

# Colormap shared by the heat maps in this notebook ('binary' is an alternative).
cmap = 'hot'

# Show both matrices side by side as heat maps without axes.
fig = plt.figure(figsize=(8, 8))
for subplot_code, matrix in ((121, m0), (122, m1)):
    axis = fig.add_subplot(subplot_code)
    img = axis.imshow(matrix, interpolation='nearest')
    img.set_cmap(cmap)
    axis.axis('off')
plt.show()
plt.close()


In [None]:
from queue import Queue


# positions for tree
def hierarchy_pos(G, root, levels=None, width=1., height=1.):
    '''Lay out the nodes reachable from ``root`` level by level, like a tree.

    The walk is recursive and only skips the node it just came from, so a
    cycle that is reachable from ``root`` leads to infinite recursion.

    G: the graph (only ``G.neighbors(node)`` is used)
    root: the root node
    levels: optional dictionary
            key: level number (starting from 0)
            value: number of nodes in this level
            When omitted, the counts are gathered by walking the graph.
    width: horizontal space allocated for drawing
    height: vertical space allocated for drawing

    Returns a dictionary mapping every visited node to its ``(x, y)``
    position; the root is at y = 0 and each deeper level is one gap lower.
    '''
    TOTAL, CURRENT = "total", "current"

    def count_levels(counts, node, depth, came_from):
        """Tally how many nodes sit on each depth, starting at ``node``."""
        if depth not in counts:
            counts[depth] = {TOTAL: 0, CURRENT: 0}
        counts[depth][TOTAL] += 1
        for child in G.neighbors(node):
            if child == came_from:
                continue
            count_levels(counts, child, depth + 1, node)
        return counts

    if levels is None:
        levels = count_levels({}, root, 0, None)
    else:
        levels = {depth: {TOTAL: levels[depth], CURRENT: 0} for depth in levels}

    # One equal vertical slice per level.
    vert_gap = height / (max(levels) + 1)

    def place(positions, node, depth, came_from, y):
        """Put ``node`` in the next free slot of its level, then recurse depth-first."""
        slot = 1/levels[depth][TOTAL]
        half_slot = slot/2
        positions[node] = ((half_slot + slot*levels[depth][CURRENT])*width, y)
        levels[depth][CURRENT] += 1
        for child in G.neighbors(node):
            if child == came_from:
                continue
            place(positions, child, depth + 1, node, y - vert_gap)
        return positions

    return place({}, root, 0, None, 0)

def prepare_binary(m0, m1, plot=False, cmap='hot'):
    """Binarise two transition matrices along the paths reachable from node 0.

    Columns are source nodes and rows are targets. Starting from column 0,
    every visited column ``c`` gets a single 1 in ``m0_binary`` and in
    ``m1_binary``, at the row where ``m0[:, c]`` (respectively ``m1[:, c]``)
    is largest. No probability threshold is applied. A selected row whose
    index is smaller than ``m0.shape[1]`` is itself a source column and is
    visited in turn; larger row indices are not expanded. Columns that are
    never reached stay all-zero.

    Parameters:
        m0, m1: 2-D arrays indexed ``[target, source]``.
        plot: when true, show both binary matrices side by side.
        cmap: matplotlib colormap name for the optional plot. It is an
              explicit argument so the function does not depend on a global
              defined in another cell.

    Returns:
        ``(m0_binary, m1_binary)`` with the shape and dtype of ``m0`` and
        ``m1`` respectively, holding 1 at selected entries and 0 elsewhere.
    """
    from collections import deque

    m0_binary = np.zeros_like(m0)
    m1_binary = np.zeros_like(m1)
    num_sources = m0.shape[1]

    # Breadth-first walk. `seen` guarantees each column is queued only once,
    # so the walk terminates even when the argmax edges form a cycle.
    pending = deque([0])
    seen = {0}
    while pending:
        c = pending.popleft()
        for matrix, binary in ((m0, m0_binary), (m1, m1_binary)):
            best_row = int(matrix[:, c].argmax())
            binary[best_row, c] = 1
            if best_row < num_sources and best_row not in seen:
                seen.add(best_row)
                pending.append(best_row)

    if plot:
        fig = plt.figure(figsize=(8, 8))
        for subplot_code, binary in ((121, m0_binary), (122, m1_binary)):
            axis = fig.add_subplot(subplot_code)
            img = axis.imshow(binary, interpolation='nearest')
            img.set_cmap(cmap)
            axis.axis('off')
        plt.show()
        plt.close()

    return m0_binary, m1_binary

def _selected_edges(binary, weights, num_sources, num_targets, first_target=0, allowed=None):
    """Collect the ``(source, target)`` edges flagged in ``binary`` and their weights.

    Edges are returned in source-major order. When ``allowed`` is given, only
    edges whose both endpoints are in that set are kept. ``first_target``
    lets the caller skip the lowest target indices.
    """
    edges, colors = [], []
    for source in range(num_sources):
        for target in range(first_target, num_targets):
            if not binary[target, source]:
                continue
            if allowed is not None and (source not in allowed or target not in allowed):
                continue
            edges.append((source, target))
            colors.append(weights[target, source])
    return edges, colors


def draw_graph_as_tree(m0, m1):
    """Draw the binarised transition structure as a tree rooted at node 0.

    ``m0`` and ``m1`` are indexed ``[target, source]``. Indices below
    ``m0.shape[1]`` are inner nodes; the remaining rows are leaves, labelled
    ``D0``, ``D1``, ... Edges selected from ``m0`` are drawn in reds and
    edges selected from ``m1`` in greens, coloured by their value in the
    original matrix.

    Returns:
        ``(g, pos)``: the breadth-first tree reachable from node 0 and the
        node positions used for drawing, so callers can overlay more edges.
    """
    m0_binary, m1_binary = prepare_binary(m0, m1, plot=False)
    num_targets, num_nodes = m0.shape

    # Build the full directed graph: inner nodes, then leaves, then the m0
    # edges followed by the m1 edges. Insertion order is kept as-is because
    # it determines the traversal order below and therefore the layout.
    g = nx.DiGraph(directed=True)
    g.add_nodes_from(range(num_nodes))
    g.add_nodes_from(range(num_nodes, num_targets))
    for binary, weights in ((m0_binary, m0), (m1_binary, m1)):
        edges, _ = _selected_edges(binary, weights, num_nodes, num_targets)
        g.add_edges_from(edges)

    # Breadth-first tree from the root; nodes unreachable from it are dropped.
    g = nx.bfs_tree(g, 0)

    # Hierarchical positions of the tree.
    pos = hierarchy_pos(g, 0)

    # Restrict everything that gets drawn to the reachable part. Membership
    # is tested against a set built once instead of a list rebuilt per test.
    reachable = set(g.nodes())
    node_list = [node for node in range(num_nodes) if node in reachable]
    leaf_list = [leaf for leaf in range(num_nodes, num_targets) if leaf in reachable]
    # first_target=1: edges pointing back into the root are not drawn.
    edge_list_0, edge_color_0 = _selected_edges(m0_binary, m0, num_nodes, num_targets,
                                                first_target=1, allowed=reachable)
    edge_list_1, edge_color_1 = _selected_edges(m1_binary, m1, num_nodes, num_targets,
                                                first_target=1, allowed=reachable)

    # Leaves are named after the digit they stand for; inner nodes keep their index.
    leaf_list_names = {leaf: "D{}".format(leaf - num_nodes) for leaf in leaf_list}
    node_list_names = {node: node for node in node_list}

    # Draw with different node colours and per-matrix edge colour maps.
    plt.figure(figsize=(16, 16))
    nx.draw_networkx_nodes(g, pos, nodelist=node_list, node_color="tab:brown")
    nx.draw_networkx_nodes(g, pos, nodelist=leaf_list, node_color="tab:green")
    nx.draw_networkx_labels(g, pos, labels=leaf_list_names)
    nx.draw_networkx_labels(g, pos, labels=node_list_names)
    nx.draw_networkx_edges(g, pos, edgelist=edge_list_0, width=3,
                           edge_color=edge_color_0, edge_cmap=plt.cm.Reds, edge_vmin=0, edge_vmax=1,
                           connectionstyle='arc3, rad = -0.1')
    nx.draw_networkx_edges(g, pos, edgelist=edge_list_1, width=3,
                           edge_color=edge_color_1, edge_cmap=plt.cm.Greens, edge_vmin=0, edge_vmax=1,
                           connectionstyle='arc3, rad = 0.1')

    return g, pos


In [None]:
# compute test accuracy
correct, total = 0, 0
with torch.no_grad():
    for i, (data, label) in tqdm(enumerate(test_loader, 0), total=len(test_loader), desc='Evaluation model'):
        data = data.to(device)
        label = label.to(device)

        prob = model(data)

        _, predicted = torch.max(prob, 1)
        total += label.size(0)
        # .item() keeps the counter a plain Python int instead of turning it into a tensor.
        correct += (predicted == label).sum().item()

tst_acc = correct / total

print(f'\033[0;1;31mAccuracy: {tst_acc:.4f}\033[0m')

# draw sample paths
# Number of samples to draw a figure for, taken from the last evaluated batch.
# 0 disables drawing; set it to x.shape[0] to draw the whole batch.
NUM_SAMPLE_PATHS = 0

# `data` / `label` still hold the last batch of the evaluation loop above.
x = data
l = label

# graph_path[b, s] is the most probable node of sample b after s jumps (-1 = not filled yet).
graph_path = -np.ones([x.shape[0], model.num_jumps + 1], dtype=int)
graph_path[:, 0] = 0

with torch.no_grad():
    z = model.representation(x)
    P = model.transition_matrix(z)

    # All probability mass starts in node 0.
    prob = torch.zeros(z.shape[0], model.num_nodes + model.num_leaves, dtype=z.dtype, device=z.device)
    prob[:, 0] = 1

    for i in range(model.num_jumps):
        prob = torch.einsum('bnm,bm->bn', P, prob)
        max_prob, index = torch.max(prob, dim=1)
        # .cpu() makes the conversion work regardless of the selected device.
        graph_path[:, i + 1] = index.cpu().numpy()

for i in range(min(NUM_SAMPLE_PATHS, graph_path.shape[0])):
    g, pos = draw_graph_as_tree(m0, m1)

    edge_list = []
    edge_labels = {}
    for j in range(graph_path.shape[1] - 1):
        # Stop once the path no longer moves.
        if graph_path[i, j] == graph_path[i, j + 1]:
            break
        pair = (int(graph_path[i, j]), int(graph_path[i, j + 1]))
        edge_list.append(pair)
        edge_labels[pair] = j + 1

    nx.draw_networkx_edges(g, pos, edgelist=edge_list)
    nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_color='red')
    # Convert to int so the title shows "D4 (132)" rather than "Dtensor(4) (tensor(132))".
    true_label = int(l[i])
    plt.title("label: D{} ({})".format(true_label, true_label + m0.shape[1]))


In [None]:
# calculate all possible path per label
from ast import literal_eval

# Row indices in [first_leaf, num_targets) of m0 are leaves; a path ends once it reaches one.
first_leaf, num_targets = m0.shape[1], m0.shape[0]

# paths_per_class[label][str(path)] -> number of test samples following that path
paths_per_class = {digit: {} for digit in range(10)}

with torch.no_grad():
    for i, (data, label) in tqdm(enumerate(test_loader, 0), total=len(test_loader), desc='Evaluation model'):
        x = data.to(device)
        l = label.to(device)

        # graph_path[b, s] is the most probable node of sample b after s jumps.
        graph_path = -np.ones([x.shape[0], model.num_jumps + 1], dtype=int)
        graph_path[:, 0] = 0

        z = model.representation(x)
        P = model.transition_matrix(z)

        prob = torch.zeros(z.shape[0], model.num_nodes + model.num_leaves, dtype=z.dtype, device=z.device)
        prob[:, 0] = 1

        for j in range(model.num_jumps):
            prob = torch.einsum('bnm,bm->bn', P, prob)
            max_prob, index = torch.max(prob, dim=1)
            # .cpu() makes the conversion work regardless of the selected device.
            graph_path[:, j + 1] = index.cpu().numpy()

        for j in range(x.shape[0]):
            # Plain Python ints keep str(path) in the form "[0, 5, 130]". With
            # NumPy scalars the text becomes "[np.int64(0), ...]" on NumPy >= 2,
            # which literal_eval below cannot parse.
            path = [int(graph_path[j, 0])]
            for k in range(graph_path.shape[1] - 1):
                if first_leaf <= graph_path[j, k] < num_targets:
                    break
                path.append(int(graph_path[j, k + 1]))

            l_int = int(l[j])
            path_str = str(path)
            counts = paths_per_class[l_int]
            counts[path_str] = counts.get(path_str, 0) + 1

# Order the paths of every label from most to least frequent.
for i in range(10):
    paths_per_class[i] = dict(sorted(paths_per_class[i].items(), key=lambda item: item[1], reverse=True))

# draw most common paths
for label in [4]: # range(10):
    paths = paths_per_class[label]

    # At most the six most frequent paths are drawn, one figure each.
    for path, quantity in list(paths.items())[:6]:
        path = literal_eval(path)

        g, pos = draw_graph_as_tree(m0, m1)

        edge_list = []
        edge_labels = {}
        for j in range(len(path) - 1):
            pair = (path[j], path[j + 1])
            edge_list.append(pair)
            edge_labels[pair] = j + 1

        nx.draw_networkx_edges(g, pos, edgelist=edge_list)
        nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_color='red')
        plt.title("label: D{} ({})".format(label, quantity))
