In [1]:
from pipnet.pipnet import PIPNet, get_network
from util.log import Log
import torch.nn as nn
from util.args import get_args, save_args, get_optimizer_nn
from util.data import get_dataloaders
from util.func import init_weights_xavier
from pipnet.train import train_pipnet, test_pipnet
# from pipnet.test import eval_pipnet, get_thresholds, eval_ood
from util.eval_cub_csv import eval_prototypes_cub_parts_csv, get_topk_cub, get_proto_patches_cub
import torch
from util.vis_pipnet import visualize, visualize_topk
from util.visualize_prediction import vis_pred, vis_pred_experiments
import sys, os
import random
import numpy as np
from shutil import copy
import matplotlib.pyplot as plt
from copy import deepcopy

from omegaconf import OmegaConf
from util.node import Node
import shutil
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
import pickle
from util.func import get_patch_size
import random
from util.data import ModifiedLabelLoader
from tqdm import tqdm

Heatmaps showing where a prototype is found will not be generated because OpenCV is not installed.


In [2]:
# run_path = '/home/harishbabu/projects/PIPNet/runs/010-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/031-CUB-18-imgnet_cnext26_img=224_nprotos=20_orth-on-rel'
# run_path = '/home/harishbabu/projects/PIPNet/runs/032-CUB-18-imgnet_cnext26_img=224_nprotos=20_orth-on-rel'
run_path = '/home/harishbabu/projects/PIPNet/runs/035-CUB-18-imgnet_OOD_cnext26_img=224_nprotos=20_orth-on-rel'
args_file = open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb')
args = pickle.load(args_file)

if args.phylo_config:
    phylo_config = OmegaConf.load(args.phylo_config)

if args.phylo_config:
    # construct the phylo tree
    if phylo_config.phyloDistances_string == 'None':
        if '031' in run_path: # this run uses a different phylogeny file that had an extra root node which is a mistake
            root = construct_phylo_tree('/home/harishbabu/data/phlyogenyCUB/18Species-with-extra-root-node/1_tree-consensus-Hacket-18Species-modified_cub-names_v1.phy')
        else:
            root = construct_phylo_tree(phylo_config.phylogeny_path)
        print('-'*25 + ' No discretization ' + '-'*25)
    else:
        root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
        print('-'*25 + ' Discretized ' + '-'*25)
else:
    # construct the tree (original hierarchy as described in the paper)
    root = Node("root")
    root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
    root.add_children_to('animal',['non_primate','primate'])
    root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
    root.add_children_to('primate',['capuchin','gibbon','orangutan'])
    root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
    root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
    root.add_children_to('weapon',['assault_rifle','rifle'])
    # flat root
    # root.add_children(['scuba_diver','African_elephant','giant_panda','lion','capuchin','gibbon','orangutan','ambulance','pickup','sports_car','laptop','sandal','wine_bottle','assault_rifle','rifle'])
root.assign_all_descendents()

------------------------- No discretization -------------------------


In [3]:
print(root)

root
	052+053
		cub_052_Pied_billed_Grebe
		053+050
			cub_053_Western_Grebe
			050+051
				cub_050_Eared_Grebe
				cub_051_Horned_Grebe
	004+086
		004+032
			cub_004_Groove_billed_Ani
			032+033
				cub_032_Mangrove_Cuckoo
				033+031
					cub_033_Yellow_billed_Cuckoo
					cub_031_Black_billed_Cuckoo
		086+045
			cub_086_Pacific_Loon
			045+101
				045+003
					cub_045_Northern_Fulmar
					003+002
						cub_003_Sooty_Albatross
						002+001
							cub_002_Laysan_Albatross
							cub_001_Black_footed_Albatross
				101+023
					101+100
						cub_101_White_Pelican
						cub_100_Brown_Pelican
					023+025
						cub_023_Brandt_Cormorant
						025+024
							cub_025_Pelagic_Cormorant
							cub_024_Red_faced_Cormorant



In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

args_file = open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb')
args = pickle.load(args_file)

ckpt_path = os.path.join(run_path, 'checkpoints', 'net_trained_last')
checkpoint = torch.load(ckpt_path, map_location=device)

# Obtain the dataset and dataloaders
trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device)
if len(classes)<=20:
    if args.validation_size == 0.:
        print("Classes: ", testloader.dataset.class_to_idx, flush=True)
    else:
        print("Classes: ", str(classes), flush=True)

# Create a convolutional network based on arguments and add 1x1 conv layer
feature_net, add_on_layers, pool_layer, classification_layers, num_prototypes = get_network(len(classes), args, root=root)
   
# Create a PIP-Net
net = PIPNet(num_classes=len(classes),
                    num_prototypes=num_prototypes,
                    feature_net = feature_net,
                    args = args,
                    add_on_layers = add_on_layers,
                    pool_layer = pool_layer,
                    classification_layers = classification_layers,
                    num_parent_nodes = len(root.nodes_with_children()),
                    root = root
                    )
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)    
net.load_state_dict(checkpoint['model_state_dict'],strict=True)
net.eval()
criterion = nn.NLLLoss(reduction='mean').to(device)

# Forward one batch through the backbone to get the latent output size
# with torch.no_grad():
#     xs1, _, _ = next(iter(trainloader))
#     xs1 = xs1.to(device)
#     proto_features, _, _ = net(xs1)
#     wshape = proto_features['root'].shape[-1]
#     args.wshape = wshape #needed for calculating image patch size
#     print("Output shape: ", proto_features['root'].shape, flush=True)
    
args.wshape = 26

Num classes (k) =  18 ['cub_001_Black_footed_Albatross', 'cub_002_Laysan_Albatross', 'cub_003_Sooty_Albatross', 'cub_004_Groove_billed_Ani', 'cub_023_Brandt_Cormorant'] etc.
Classes:  {'cub_001_Black_footed_Albatross': 0, 'cub_002_Laysan_Albatross': 1, 'cub_003_Sooty_Albatross': 2, 'cub_004_Groove_billed_Ani': 3, 'cub_023_Brandt_Cormorant': 4, 'cub_024_Red_faced_Cormorant': 5, 'cub_025_Pelagic_Cormorant': 6, 'cub_031_Black_billed_Cuckoo': 7, 'cub_032_Mangrove_Cuckoo': 8, 'cub_033_Yellow_billed_Cuckoo': 9, 'cub_045_Northern_Fulmar': 10, 'cub_050_Eared_Grebe': 11, 'cub_051_Horned_Grebe': 12, 'cub_052_Pied_billed_Grebe': 13, 'cub_053_Western_Grebe': 14, 'cub_086_Pacific_Loon': 15, 'cub_100_Brown_Pelican': 16, 'cub_101_White_Pelican': 17}
Number of prototypes:  20


In [20]:
# Proto activations on Grand children

from util.data import ModifiedLabelLoader
from collections import defaultdict
import pdb

for node in root.nodes_with_children():
    if node.name == 'root':
        continue
    non_leaf_children_names = [child.name for child in node.children if not child.is_leaf()]
    if len(non_leaf_children_names) == 0: # if all the children are leaf nodes then skip this node
        continue

    name2label = projectloader.dataset.class_to_idx
    label2name = {label:name for name, label in name2label.items()}
    modifiedLabelLoader = ModifiedLabelLoader(projectloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)

    

    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
    
    # maps proto_number -> grand_child_name (or descendant leaf name) -> (mean_activation, num_images)
    proto_mean_activations = defaultdict(lambda: defaultdict(lambda: [0, 0]))

    # maps class names to the prototypes that belong to that
    class_and_prototypes = defaultdict(set)

    for i, (xs, orig_y, ys) in img_iter:
        if coarse_label2name[ys.item()] not in non_leaf_children_names:
            continue

        xs, ys = xs.to(device), ys.to(device)

        with torch.no_grad():
            pfs, pooled, _ = net(xs, inference=True)
            pooled = pooled[node.name].squeeze(0) 
            pfs = pfs[node.name].squeeze(0)

            

            for p in range(pooled.shape[0]): # pooled.shape -> [768] (== num of prototypes)
                c_weight = torch.max(classification_weights[:,p]) # classification_weights[:,p].shape -> [200] (== num of classes)
                relevant_proto_classes = torch.nonzero(classification_weights[:, p] > 1e-3)
                relevant_proto_class_names = [node_label_to_children[class_idx.item()] for class_idx in relevant_proto_classes]

                if len(relevant_proto_class_names) == 0:
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (relevant_proto_class_names[0] not in non_leaf_children_names):
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    grand_child = child_node.closest_descendent_for(label2name[orig_y.item()])
                    proto_mean_activations[p][grand_child.name][0] = ((proto_mean_activations[p][grand_child.name][0] * \
                                                                      proto_mean_activations[p][grand_child.name][1]) + pooled[p]) / (proto_mean_activations[p][grand_child.name][1] + 1)
                    proto_mean_activations[p][grand_child.name][1] += 1

                if (len(relevant_proto_class_names) > 1) and (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    if child_node.is_leaf():
                        proto_mean_activations[p][child_node.name][0] = ((proto_mean_activations[p][child_node.name][0] * \
                                                                            proto_mean_activations[p][child_node.name][1]) + pooled[p]) / (proto_mean_activations[p][child_node.name][1] + 1)
                        proto_mean_activations[p][child_node.name][1] += 1
                    else:
                        grand_child = child_node.closest_descendent_for(label2name[orig_y.item()])
                        proto_mean_activations[p][grand_child.name][0] = ((proto_mean_activations[p][grand_child.name][0] * \
                                                                            proto_mean_activations[p][grand_child.name][1]) + pooled[p]) / (proto_mean_activations[p][grand_child.name][1] + 1)
                        proto_mean_activations[p][grand_child.name][1] += 1
                

                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)

            # class_and_prototypes = defaultdict(list)
            # for p in range(pooled.shape[0]):
            #     class_and_prototypes[', '.join(list(proto_mean_activations[p].keys()))].append(p)
    
    print('Node', node.name)
    for child_classname in class_and_prototypes:
        print('\t'*1, 'Child:', child_classname)
        for p in class_and_prototypes[child_classname]:
            logstr = '\t'*2 + f'Proto:{p} '
            for grand_child_name in proto_mean_activations[p]:
                mean_activation = round(proto_mean_activations[p][grand_child_name][0].item(), 2)
                num_images = proto_mean_activations[p][grand_child_name][1]
                logstr += f'{grand_child_name}:{mean_activation}/{num_images} '
            print(logstr)


Collecting topk: 540it [00:14, 37.33it/s]


052+004
	 052+053
		Proto:0 053+051:0.42/90 cub_052_Pied_billed_Grebe:0.0/30 
		Proto:36 053+051:0.08/90 cub_052_Pied_billed_Grebe:0.29/30 
		Proto:5 053+051:0.17/90 cub_052_Pied_billed_Grebe:0.16/30 
		Proto:43 053+051:0.32/90 cub_052_Pied_billed_Grebe:0.37/30 
		Proto:13 053+051:0.38/90 cub_052_Pied_billed_Grebe:0.49/30 
		Proto:46 053+051:0.47/90 cub_052_Pied_billed_Grebe:0.01/30 
		Proto:16 053+051:0.46/90 cub_052_Pied_billed_Grebe:0.0/30 
		Proto:17 053+051:0.34/90 cub_052_Pied_billed_Grebe:0.03/30 
		Proto:23 053+051:0.03/90 cub_052_Pied_billed_Grebe:0.46/30 
		Proto:31 053+051:0.2/90 cub_052_Pied_billed_Grebe:0.06/30 
	 004+086
		Proto:2 086+045:0.18/300 004+032:0.12/120 
		Proto:3 086+045:0.19/300 004+032:0.16/120 
		Proto:4 086+045:0.32/300 004+032:0.18/120 
		Proto:7 086+045:0.19/300 004+032:0.35/120 
		Proto:8 086+045:0.12/300 004+032:0.24/120 
		Proto:10 086+045:0.0/300 004+032:0.51/120 
		Proto:11 086+045:0.27/300 004+032:0.06/120 
		Proto:12 086+045:0.03/300 004+032:0.35/

Collecting topk: 120it [00:03, 36.07it/s]  


052+053
	 053+051
		Proto:2 051+050:0.49/60 cub_053_Western_Grebe:0.06/30 
		Proto:4 051+050:0.26/60 cub_053_Western_Grebe:0.39/30 
		Proto:5 051+050:0.11/60 cub_053_Western_Grebe:0.1/30 
		Proto:6 051+050:0.38/60 cub_053_Western_Grebe:0.72/30 
		Proto:7 051+050:0.17/60 cub_053_Western_Grebe:0.63/30 
		Proto:9 051+050:0.42/60 cub_053_Western_Grebe:0.83/30 
		Proto:10 051+050:0.14/60 cub_053_Western_Grebe:0.74/30 
		Proto:11 051+050:0.69/60 cub_053_Western_Grebe:0.67/30 
		Proto:15 051+050:0.4/60 cub_053_Western_Grebe:0.9/30 
		Proto:16 051+050:0.52/60 cub_053_Western_Grebe:0.43/30 
		Proto:17 051+050:0.74/60 cub_053_Western_Grebe:0.46/30 
		Proto:18 051+050:0.47/60 cub_053_Western_Grebe:0.24/30 
		Proto:20 051+050:0.18/60 cub_053_Western_Grebe:0.62/30 
		Proto:21 051+050:0.19/60 cub_053_Western_Grebe:0.26/30 
		Proto:22 051+050:0.35/60 cub_053_Western_Grebe:0.07/30 
		Proto:23 051+050:0.31/60 cub_053_Western_Grebe:0.01/30 
		Proto:24 051+050:0.41/60 cub_053_Western_Grebe:0.58/30 
		Pro

Collecting topk: 420it [00:11, 36.52it/s]


004+086
	 086+045
		Proto:1 045+100:0.18/270 cub_086_Pacific_Loon:0.48/30 
		Proto:4 045+100:0.4/270 cub_086_Pacific_Loon:0.33/30 
		Proto:5 045+100:0.13/270 cub_086_Pacific_Loon:0.0/30 
		Proto:10 045+100:0.46/270 cub_086_Pacific_Loon:0.53/30 
		Proto:11 045+100:0.31/270 cub_086_Pacific_Loon:0.01/30 
		Proto:13 045+100:0.31/270 cub_086_Pacific_Loon:0.0/30 
		Proto:14 045+100:0.16/270 cub_086_Pacific_Loon:0.27/30 
		Proto:17 045+100:0.45/270 cub_086_Pacific_Loon:0.08/30 
		Proto:19 045+100:0.24/270 cub_086_Pacific_Loon:0.0/30 
		Proto:21 045+100:0.12/270 cub_086_Pacific_Loon:0.25/30 
		Proto:28 045+100:0.23/270 cub_086_Pacific_Loon:0.35/30 
		Proto:29 045+100:0.15/270 cub_086_Pacific_Loon:0.28/30 
		Proto:32 045+100:0.32/270 cub_086_Pacific_Loon:0.6/30 
		Proto:38 045+100:0.21/270 cub_086_Pacific_Loon:0.01/30 
		Proto:39 045+100:0.29/270 cub_086_Pacific_Loon:0.03/30 
		Proto:40 045+100:0.16/270 cub_086_Pacific_Loon:0.2/30 
		Proto:41 045+100:0.15/270 cub_086_Pacific_Loon:0.28/30 
		Pro

Collecting topk: 90it [00:02, 35.97it/s]   


053+051
	 051+050
		Proto:3 cub_050_Eared_Grebe:0.21/30 cub_051_Horned_Grebe:0.4/30 
		Proto:5 cub_050_Eared_Grebe:0.05/30 cub_051_Horned_Grebe:0.2/30 
		Proto:8 cub_050_Eared_Grebe:0.11/30 cub_051_Horned_Grebe:0.23/30 
		Proto:9 cub_050_Eared_Grebe:0.57/30 cub_051_Horned_Grebe:0.6/30 
		Proto:10 cub_050_Eared_Grebe:0.46/30 cub_051_Horned_Grebe:0.62/30 
		Proto:11 cub_050_Eared_Grebe:0.41/30 cub_051_Horned_Grebe:0.5/30 
		Proto:12 cub_050_Eared_Grebe:0.74/30 cub_051_Horned_Grebe:0.73/30 
		Proto:13 cub_050_Eared_Grebe:0.58/30 cub_051_Horned_Grebe:0.58/30 
		Proto:16 cub_050_Eared_Grebe:0.55/30 cub_051_Horned_Grebe:0.39/30 
		Proto:17 cub_050_Eared_Grebe:0.5/30 cub_051_Horned_Grebe:0.63/30 
		Proto:18 cub_050_Eared_Grebe:0.22/30 cub_051_Horned_Grebe:0.45/30 
		Proto:20 cub_050_Eared_Grebe:0.16/30 cub_051_Horned_Grebe:0.3/30 
		Proto:25 cub_050_Eared_Grebe:0.27/30 cub_051_Horned_Grebe:0.47/30 
		Proto:32 cub_050_Eared_Grebe:0.33/30 cub_051_Horned_Grebe:0.34/30 
		Proto:36 cub_050_Eared_G

Collecting topk: 120it [00:03, 38.02it/s]  


004+032
	 032+031
		Proto:0 031+033:0.2/60 cub_032_Mangrove_Cuckoo:0.69/30 
		Proto:1 031+033:0.62/60 cub_032_Mangrove_Cuckoo:0.51/30 
		Proto:4 031+033:0.5/60 cub_032_Mangrove_Cuckoo:0.13/30 
		Proto:5 031+033:0.56/60 cub_032_Mangrove_Cuckoo:0.57/30 
		Proto:6 031+033:0.52/60 cub_032_Mangrove_Cuckoo:0.39/30 
		Proto:7 031+033:0.31/60 cub_032_Mangrove_Cuckoo:0.35/30 
		Proto:9 031+033:0.6/60 cub_032_Mangrove_Cuckoo:0.4/30 
		Proto:12 031+033:0.64/60 cub_032_Mangrove_Cuckoo:0.61/30 
		Proto:16 031+033:0.35/60 cub_032_Mangrove_Cuckoo:0.72/30 
		Proto:17 031+033:0.32/60 cub_032_Mangrove_Cuckoo:0.23/30 
		Proto:18 031+033:0.58/60 cub_032_Mangrove_Cuckoo:0.14/30 
		Proto:19 031+033:0.36/60 cub_032_Mangrove_Cuckoo:0.41/30 
		Proto:20 031+033:0.48/60 cub_032_Mangrove_Cuckoo:0.37/30 
		Proto:22 031+033:0.54/60 cub_032_Mangrove_Cuckoo:0.22/30 
		Proto:27 031+033:0.42/60 cub_032_Mangrove_Cuckoo:0.44/30 
		Proto:30 031+033:0.07/60 cub_032_Mangrove_Cuckoo:0.14/30 
		Proto:31 031+033:0.31/60 cub_03

Collecting topk: 300it [00:08, 37.36it/s]  


086+045
	 045+100
		Proto:2 045+003:0.02/120 100+023:0.37/150 
		Proto:3 045+003:0.07/120 100+023:0.27/150 
		Proto:5 045+003:0.24/120 100+023:0.12/150 
		Proto:6 045+003:0.14/120 100+023:0.25/150 
		Proto:7 045+003:0.05/120 100+023:0.16/150 
		Proto:9 045+003:0.15/120 100+023:0.14/150 
		Proto:10 045+003:0.14/120 100+023:0.18/150 
		Proto:11 045+003:0.22/120 100+023:0.13/150 
		Proto:12 045+003:0.1/120 100+023:0.24/150 
		Proto:14 045+003:0.22/120 100+023:0.1/150 
		Proto:17 045+003:0.15/120 100+023:0.09/150 
		Proto:18 045+003:0.22/120 100+023:0.17/150 
		Proto:19 045+003:0.19/120 100+023:0.22/150 
		Proto:20 045+003:0.47/120 100+023:0.01/150 
		Proto:21 045+003:0.47/120 100+023:0.02/150 
		Proto:23 045+003:0.05/120 100+023:0.14/150 
		Proto:24 045+003:0.23/120 100+023:0.15/150 
		Proto:25 045+003:0.1/120 100+023:0.16/150 
		Proto:26 045+003:0.05/120 100+023:0.26/150 
		Proto:27 045+003:0.16/120 100+023:0.14/150 
		Proto:29 045+003:0.14/120 100+023:0.34/150 
		Proto:30 045+003:0.07/1

Collecting topk: 90it [00:02, 36.15it/s]   


032+031
	 031+033
		Proto:1 cub_031_Black_billed_Cuckoo:0.28/30 cub_033_Yellow_billed_Cuckoo:0.44/30 
		Proto:33 cub_031_Black_billed_Cuckoo:0.39/30 cub_033_Yellow_billed_Cuckoo:0.7/30 
		Proto:35 cub_031_Black_billed_Cuckoo:0.67/30 cub_033_Yellow_billed_Cuckoo:0.51/30 
		Proto:4 cub_031_Black_billed_Cuckoo:0.85/30 cub_033_Yellow_billed_Cuckoo:0.3/30 
		Proto:36 cub_031_Black_billed_Cuckoo:0.67/30 cub_033_Yellow_billed_Cuckoo:0.48/30 
		Proto:37 cub_031_Black_billed_Cuckoo:0.39/30 cub_033_Yellow_billed_Cuckoo:0.35/30 
		Proto:43 cub_031_Black_billed_Cuckoo:0.64/30 cub_033_Yellow_billed_Cuckoo:0.39/30 
		Proto:13 cub_031_Black_billed_Cuckoo:0.55/30 cub_033_Yellow_billed_Cuckoo:0.44/30 
		Proto:14 cub_031_Black_billed_Cuckoo:0.8/30 cub_033_Yellow_billed_Cuckoo:0.54/30 
		Proto:15 cub_031_Black_billed_Cuckoo:0.59/30 cub_033_Yellow_billed_Cuckoo:0.46/30 
		Proto:46 cub_031_Black_billed_Cuckoo:0.61/30 cub_033_Yellow_billed_Cuckoo:0.43/30 
		Proto:47 cub_031_Black_billed_Cuckoo:0.36/30 cub_0

Collecting topk: 270it [00:07, 35.65it/s]


045+100
	 100+023
		Proto:32 023+024:0.1/90 100+101:0.6/60 
		Proto:2 023+024:0.56/90 100+101:0.03/60 
		Proto:35 023+024:0.03/90 100+101:0.81/60 
		Proto:6 023+024:0.22/90 100+101:0.4/60 
		Proto:7 023+024:0.68/90 100+101:0.01/60 
		Proto:8 023+024:0.36/90 100+101:0.01/60 
		Proto:40 023+024:0.49/90 100+101:0.21/60 
		Proto:41 023+024:0.21/90 100+101:0.3/60 
		Proto:11 023+024:0.29/90 100+101:0.14/60 
		Proto:44 023+024:0.49/90 100+101:0.0/60 
		Proto:14 023+024:0.25/90 100+101:0.05/60 
		Proto:46 023+024:0.3/90 100+101:0.18/60 
		Proto:19 023+024:0.26/90 100+101:0.19/60 
		Proto:22 023+024:0.4/90 100+101:0.57/60 
		Proto:25 023+024:0.42/90 100+101:0.03/60 
		Proto:27 023+024:0.38/90 100+101:0.68/60 
	 045+003
		Proto:33 003+001:0.25/90 cub_045_Northern_Fulmar:0.46/30 
		Proto:34 003+001:0.13/90 cub_045_Northern_Fulmar:0.27/30 
		Proto:3 003+001:0.23/90 cub_045_Northern_Fulmar:0.47/30 
		Proto:36 003+001:0.5/90 cub_045_Northern_Fulmar:0.53/30 
		Proto:5 003+001:0.39/90 cub_045_Norther

Collecting topk: 120it [00:03, 36.57it/s]  


045+003
	 003+001
		Proto:0 001+002:0.57/60 cub_003_Sooty_Albatross:0.4/30 
		Proto:1 001+002:0.7/60 cub_003_Sooty_Albatross:0.24/30 
		Proto:6 001+002:0.25/60 cub_003_Sooty_Albatross:0.73/30 
		Proto:8 001+002:0.31/60 cub_003_Sooty_Albatross:0.46/30 
		Proto:10 001+002:0.62/60 cub_003_Sooty_Albatross:0.38/30 
		Proto:13 001+002:0.37/60 cub_003_Sooty_Albatross:0.51/30 
		Proto:14 001+002:0.34/60 cub_003_Sooty_Albatross:0.49/30 
		Proto:17 001+002:0.58/60 cub_003_Sooty_Albatross:0.35/30 
		Proto:21 001+002:0.32/60 cub_003_Sooty_Albatross:0.67/30 
		Proto:24 001+002:0.31/60 cub_003_Sooty_Albatross:0.51/30 
		Proto:27 001+002:0.52/60 cub_003_Sooty_Albatross:0.34/30 
		Proto:28 001+002:0.65/60 cub_003_Sooty_Albatross:0.3/30 
		Proto:29 001+002:0.4/60 cub_003_Sooty_Albatross:0.17/30 
		Proto:30 001+002:0.71/60 cub_003_Sooty_Albatross:0.41/30 
		Proto:31 001+002:0.49/60 cub_003_Sooty_Albatross:0.28/30 
		Proto:37 001+002:0.43/60 cub_003_Sooty_Albatross:0.32/30 
		Proto:40 001+002:0.48/60 cub

Collecting topk: 150it [00:04, 32.36it/s]


100+023
	 023+024
		Proto:0 cub_023_Brandt_Cormorant:0.11/30 024+025:0.62/60 
		Proto:2 cub_023_Brandt_Cormorant:0.56/30 024+025:0.55/60 
		Proto:3 cub_023_Brandt_Cormorant:0.51/30 024+025:0.62/60 
		Proto:4 cub_023_Brandt_Cormorant:0.5/30 024+025:0.58/60 
		Proto:35 cub_023_Brandt_Cormorant:0.27/30 024+025:0.66/60 
		Proto:7 cub_023_Brandt_Cormorant:0.56/30 024+025:0.71/60 
		Proto:41 cub_023_Brandt_Cormorant:0.47/30 024+025:0.57/60 
		Proto:42 cub_023_Brandt_Cormorant:0.13/30 024+025:0.54/60 
		Proto:43 cub_023_Brandt_Cormorant:0.47/30 024+025:0.52/60 
		Proto:12 cub_023_Brandt_Cormorant:0.32/30 024+025:0.32/60 
		Proto:46 cub_023_Brandt_Cormorant:0.39/30 024+025:0.54/60 
		Proto:16 cub_023_Brandt_Cormorant:0.51/30 024+025:0.53/60 
		Proto:48 cub_023_Brandt_Cormorant:0.4/30 024+025:0.46/60 
		Proto:18 cub_023_Brandt_Cormorant:0.27/30 024+025:0.31/60 
		Proto:20 cub_023_Brandt_Cormorant:0.56/30 024+025:0.49/60 
		Proto:24 cub_023_Brandt_Cormorant:0.67/30 024+025:0.4/60 
		Proto:28 cub

Collecting topk: 90it [00:02, 37.36it/s]   


003+001
	 001+002
		Proto:0 cub_001_Black_footed_Albatross:0.71/30 cub_002_Laysan_Albatross:0.68/30 
		Proto:33 cub_001_Black_footed_Albatross:0.55/30 cub_002_Laysan_Albatross:0.42/30 
		Proto:3 cub_001_Black_footed_Albatross:0.46/30 cub_002_Laysan_Albatross:0.54/30 
		Proto:36 cub_001_Black_footed_Albatross:0.66/30 cub_002_Laysan_Albatross:0.43/30 
		Proto:6 cub_001_Black_footed_Albatross:0.47/30 cub_002_Laysan_Albatross:0.52/30 
		Proto:7 cub_001_Black_footed_Albatross:0.46/30 cub_002_Laysan_Albatross:0.63/30 
		Proto:40 cub_001_Black_footed_Albatross:0.37/30 cub_002_Laysan_Albatross:0.63/30 
		Proto:9 cub_001_Black_footed_Albatross:0.57/30 cub_002_Laysan_Albatross:0.65/30 
		Proto:42 cub_001_Black_footed_Albatross:0.68/30 cub_002_Laysan_Albatross:0.46/30 
		Proto:12 cub_001_Black_footed_Albatross:0.76/30 cub_002_Laysan_Albatross:0.78/30 
		Proto:17 cub_001_Black_footed_Albatross:0.8/30 cub_002_Laysan_Albatross:0.57/30 
		Proto:24 cub_001_Black_footed_Albatross:0.64/30 cub_002_Laysan

Collecting topk: 90it [00:02, 36.45it/s]   

023+024
	 024+025
		Proto:33 cub_024_Red_faced_Cormorant:0.64/30 cub_025_Pelagic_Cormorant:0.31/30 
		Proto:35 cub_024_Red_faced_Cormorant:0.32/30 cub_025_Pelagic_Cormorant:0.48/30 
		Proto:37 cub_024_Red_faced_Cormorant:0.67/30 cub_025_Pelagic_Cormorant:0.52/30 
		Proto:38 cub_024_Red_faced_Cormorant:0.69/30 cub_025_Pelagic_Cormorant:0.41/30 
		Proto:7 cub_024_Red_faced_Cormorant:0.84/30 cub_025_Pelagic_Cormorant:0.32/30 
		Proto:9 cub_024_Red_faced_Cormorant:0.85/30 cub_025_Pelagic_Cormorant:0.15/30 
		Proto:10 cub_024_Red_faced_Cormorant:0.48/30 cub_025_Pelagic_Cormorant:0.44/30 
		Proto:12 cub_024_Red_faced_Cormorant:0.84/30 cub_025_Pelagic_Cormorant:0.5/30 
		Proto:13 cub_024_Red_faced_Cormorant:0.48/30 cub_025_Pelagic_Cormorant:0.26/30 
		Proto:14 cub_024_Red_faced_Cormorant:0.45/30 cub_025_Pelagic_Cormorant:0.55/30 
		Proto:47 cub_024_Red_faced_Cormorant:0.5/30 cub_025_Pelagic_Cormorant:0.32/30 
		Proto:18 cub_024_Red_faced_Cormorant:0.76/30 cub_025_Pelagic_Cormorant:0.26/30 
		




In [24]:
# Proto activations on leaf descendents - all image-activations

from util.data import ModifiedLabelLoader
from collections import defaultdict
import heapq
import pdb


for node in root.nodes_with_children():
    if node.name == 'root':
        continue
    non_leaf_children_names = [child.name for child in node.children if not child.is_leaf()]
    if len(non_leaf_children_names) == 0: # if all the children are leaf nodes then skip this node
        continue

    name2label = projectloader.dataset.class_to_idx
    label2name = {label:name for name, label in name2label.items()}
    modifiedLabelLoader = ModifiedLabelLoader(projectloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)

    

    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
    
    # maps proto_number -> grand_child_name (or descendant leaf name) -> (mean_activation, num_images)
    proto_mean_activations = defaultdict(lambda: defaultdict(lambda: [0, 0]))

    # maps class names to the prototypes that belong to that
    class_and_prototypes = defaultdict(set)

    for i, (xs, orig_y, ys) in img_iter:
        if coarse_label2name[ys.item()] not in non_leaf_children_names:
            continue

        xs, ys = xs.to(device), ys.to(device)

        with torch.no_grad():
            pfs, pooled, _ = net(xs, inference=True)
            pooled = pooled[node.name].squeeze(0) 
            pfs = pfs[node.name].squeeze(0)

            for p in range(pooled.shape[0]): # pooled.shape -> [768] (== num of prototypes)
                c_weight = torch.max(classification_weights[:,p]) # classification_weights[:,p].shape -> [200] (== num of classes)
                relevant_proto_classes = torch.nonzero(classification_weights[:, p] > 1e-3)
                relevant_proto_class_names = [node_label_to_children[class_idx.item()] for class_idx in relevant_proto_classes]

                if len(relevant_proto_class_names) == 0:
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (relevant_proto_class_names[0] not in non_leaf_children_names):
                    continue
                
                if (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    leaf_descendent = label2name[orig_y.item()][4:7]
                    proto_mean_activations[p][leaf_descendent][0] = ((proto_mean_activations[p][leaf_descendent][0] * \
                                                                      proto_mean_activations[p][leaf_descendent][1]) + pooled[p]) / (proto_mean_activations[p][leaf_descendent][1] + 1)
                    proto_mean_activations[p][leaf_descendent][1] += 1                

                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)

    
    print('Node', node.name)
    for child_classname in class_and_prototypes:
        print('\t'*1, 'Child:', child_classname)
        for p in class_and_prototypes[child_classname]:
            logstr = '\t'*2 + f'Proto:{p} '
            for leaf_descendent in proto_mean_activations[p]:
                mean_activation = round(proto_mean_activations[p][leaf_descendent][0].item(), 2)
#                 num_images = proto_mean_activations[p][leaf_descendent][1]
#                 logstr += f'{leaf_descendent}:{mean_activation}/{num_images} '
                num_images = proto_mean_activations[p][leaf_descendent][1]
                logstr += f'{leaf_descendent}:({mean_activation}) '
            print(logstr)


Collecting topk: 540it [00:14, 37.54it/s]


Node 052+004
	 Child: 052+053
		Proto:0 050:(0.23) 051:(0.57) 052:(0.0) 053:(0.45) 
		Proto:36 050:(0.16) 051:(0.15) 052:(0.31) 053:(0.0) 
		Proto:5 050:(0.12) 051:(0.4) 052:(0.14) 053:(0.0) 
		Proto:43 050:(0.2) 051:(0.52) 052:(0.35) 053:(0.18) 
		Proto:13 050:(0.36) 051:(0.46) 052:(0.54) 053:(0.3) 
		Proto:46 050:(0.43) 051:(0.46) 052:(0.02) 053:(0.49) 
		Proto:16 050:(0.17) 051:(0.52) 052:(0.0) 053:(0.64) 
		Proto:17 050:(0.12) 051:(0.44) 052:(0.03) 053:(0.48) 
		Proto:23 050:(0.01) 051:(0.04) 052:(0.42) 053:(0.06) 
		Proto:31 050:(0.16) 051:(0.38) 052:(0.06) 053:(0.0) 
	 Child: 004+086
		Proto:2 001:(0.15) 002:(0.27) 003:(0.23) 004:(0.23) 023:(0.16) 024:(0.18) 025:(0.34) 031:(0.13) 032:(0.07) 033:(0.06) 045:(0.11) 086:(0.12) 100:(0.11) 101:(0.11) 
		Proto:3 001:(0.2) 002:(0.47) 003:(0.64) 004:(0.03) 023:(0.0) 024:(0.01) 025:(0.03) 031:(0.34) 032:(0.07) 033:(0.28) 045:(0.43) 086:(0.01) 100:(0.03) 101:(0.09) 
		Proto:4 001:(0.66) 002:(0.03) 003:(0.31) 004:(0.03) 023:(0.56) 024:(0.13)

Collecting topk: 120it [00:03, 35.92it/s]  


Node 052+053
	 Child: 053+051
		Proto:2 050:(0.5) 051:(0.54) 053:(0.04) 
		Proto:4 050:(0.17) 051:(0.33) 053:(0.36) 
		Proto:5 050:(0.08) 051:(0.12) 053:(0.1) 
		Proto:6 050:(0.35) 051:(0.34) 053:(0.7) 
		Proto:7 050:(0.01) 051:(0.31) 053:(0.66) 
		Proto:9 050:(0.38) 051:(0.44) 053:(0.81) 
		Proto:10 050:(0.12) 051:(0.21) 053:(0.68) 
		Proto:11 050:(0.68) 051:(0.71) 053:(0.62) 
		Proto:15 050:(0.49) 051:(0.3) 053:(0.87) 
		Proto:16 050:(0.52) 051:(0.54) 053:(0.36) 
		Proto:17 050:(0.72) 051:(0.8) 053:(0.38) 
		Proto:18 050:(0.39) 051:(0.47) 053:(0.22) 
		Proto:20 050:(0.01) 051:(0.35) 053:(0.57) 
		Proto:21 050:(0.23) 051:(0.16) 053:(0.22) 
		Proto:22 050:(0.36) 051:(0.4) 053:(0.08) 
		Proto:23 050:(0.18) 051:(0.41) 053:(0.01) 
		Proto:24 050:(0.4) 051:(0.44) 053:(0.55) 
		Proto:26 050:(0.31) 051:(0.39) 053:(0.45) 
		Proto:27 050:(0.35) 051:(0.46) 053:(0.5) 
		Proto:34 050:(0.43) 051:(0.47) 053:(0.18) 
		Proto:36 050:(0.5) 051:(0.44) 053:(0.77) 
		Proto:38 050:(0.01) 051:(0.24) 053:(0.

Collecting topk: 420it [00:11, 37.97it/s]


Node 004+086
	 Child: 086+045
		Proto:1 001:(0.0) 002:(0.0) 003:(0.05) 023:(0.59) 024:(0.12) 025:(0.47) 045:(0.0) 086:(0.51) 100:(0.26) 101:(0.02) 
		Proto:4 001:(0.49) 002:(0.65) 003:(0.23) 023:(0.15) 024:(0.74) 025:(0.23) 045:(0.04) 086:(0.31) 100:(0.49) 101:(0.71) 
		Proto:5 001:(0.04) 002:(0.0) 003:(0.14) 023:(0.11) 024:(0.59) 025:(0.18) 045:(0.0) 086:(0.0) 100:(0.05) 101:(0.0) 
		Proto:10 001:(0.49) 002:(0.68) 003:(0.25) 023:(0.39) 024:(0.56) 025:(0.32) 045:(0.01) 086:(0.41) 100:(0.64) 101:(0.6) 
		Proto:11 001:(0.3) 002:(0.66) 003:(0.3) 023:(0.08) 024:(0.12) 025:(0.15) 045:(0.36) 086:(0.01) 100:(0.07) 101:(0.59) 
		Proto:13 001:(0.06) 002:(0.55) 003:(0.06) 023:(0.02) 024:(0.38) 025:(0.0) 045:(0.28) 086:(0.0) 100:(0.66) 101:(0.75) 
		Proto:14 001:(0.24) 002:(0.13) 003:(0.33) 023:(0.3) 024:(0.15) 025:(0.19) 045:(0.03) 086:(0.32) 100:(0.0) 101:(0.0) 
		Proto:17 001:(0.17) 002:(0.69) 003:(0.11) 023:(0.17) 024:(0.84) 025:(0.21) 045:(0.48) 086:(0.08) 100:(0.81) 101:(0.66) 
		Proto:19 0

Collecting topk: 90it [00:02, 37.27it/s]   


Node 053+051
	 Child: 051+050
		Proto:3 050:(0.19) 051:(0.38) 
		Proto:5 050:(0.06) 051:(0.22) 
		Proto:8 050:(0.09) 051:(0.25) 
		Proto:9 050:(0.55) 051:(0.62) 
		Proto:10 050:(0.46) 051:(0.59) 
		Proto:11 050:(0.38) 051:(0.49) 
		Proto:12 050:(0.73) 051:(0.73) 
		Proto:13 050:(0.59) 051:(0.58) 
		Proto:16 050:(0.56) 051:(0.35) 
		Proto:17 050:(0.46) 051:(0.61) 
		Proto:18 050:(0.28) 051:(0.45) 
		Proto:20 050:(0.17) 051:(0.28) 
		Proto:25 050:(0.26) 051:(0.46) 
		Proto:32 050:(0.32) 051:(0.32) 
		Proto:36 050:(0.39) 051:(0.48) 
		Proto:38 050:(0.37) 051:(0.47) 
		Proto:42 050:(0.39) 051:(0.35) 
		Proto:43 050:(0.36) 051:(0.41) 
		Proto:45 050:(0.2) 051:(0.45) 
		Proto:47 050:(0.51) 051:(0.62) 
		Proto:48 050:(0.29) 051:(0.34) 
		Proto:49 050:(0.57) 051:(0.64) 


Collecting topk: 120it [00:03, 35.69it/s]  


Node 004+032
	 Child: 032+031
		Proto:0 031:(0.28) 032:(0.65) 033:(0.22) 
		Proto:1 031:(0.77) 032:(0.49) 033:(0.52) 
		Proto:4 031:(0.48) 032:(0.13) 033:(0.54) 
		Proto:5 031:(0.36) 032:(0.49) 033:(0.75) 
		Proto:6 031:(0.52) 032:(0.38) 033:(0.5) 
		Proto:7 031:(0.0) 032:(0.36) 033:(0.67) 
		Proto:9 031:(0.61) 032:(0.4) 033:(0.62) 
		Proto:12 031:(0.34) 032:(0.59) 033:(0.82) 
		Proto:16 031:(0.21) 032:(0.72) 033:(0.54) 
		Proto:17 031:(0.38) 032:(0.23) 033:(0.24) 
		Proto:18 031:(0.69) 032:(0.16) 033:(0.52) 
		Proto:19 031:(0.43) 032:(0.4) 033:(0.3) 
		Proto:20 031:(0.44) 032:(0.36) 033:(0.46) 
		Proto:22 031:(0.65) 032:(0.23) 033:(0.46) 
		Proto:27 031:(0.39) 032:(0.41) 033:(0.4) 
		Proto:30 031:(0.13) 032:(0.12) 033:(0.09) 
		Proto:31 031:(0.2) 032:(0.34) 033:(0.4) 
		Proto:32 031:(0.73) 032:(0.39) 033:(0.63) 
		Proto:36 031:(0.46) 032:(0.08) 033:(0.35) 
		Proto:37 031:(0.12) 032:(0.27) 033:(0.63) 
		Proto:38 031:(0.42) 032:(0.51) 033:(0.46) 
		Proto:40 031:(0.88) 032:(0.43) 033:(0.

Collecting topk: 300it [00:07, 37.71it/s]  


Node 086+045
	 Child: 045+100
		Proto:2 001:(0.02) 002:(0.03) 003:(0.02) 023:(0.71) 024:(0.73) 025:(0.37) 045:(0.0) 100:(0.02) 101:(0.02) 
		Proto:3 001:(0.12) 002:(0.08) 003:(0.08) 023:(0.11) 024:(0.84) 025:(0.27) 045:(0.04) 100:(0.01) 101:(0.05) 
		Proto:5 001:(0.17) 002:(0.24) 003:(0.09) 023:(0.04) 024:(0.22) 025:(0.21) 045:(0.43) 100:(0.02) 101:(0.12) 
		Proto:6 001:(0.18) 002:(0.32) 003:(0.1) 023:(0.0) 024:(0.01) 025:(0.01) 045:(0.04) 100:(0.87) 101:(0.4) 
		Proto:7 001:(0.04) 002:(0.02) 003:(0.03) 023:(0.1) 024:(0.28) 025:(0.3) 045:(0.09) 100:(0.12) 101:(0.06) 
		Proto:9 001:(0.22) 002:(0.26) 003:(0.17) 023:(0.11) 024:(0.06) 025:(0.09) 045:(0.02) 100:(0.25) 101:(0.28) 
		Proto:10 001:(0.22) 002:(0.16) 003:(0.18) 023:(0.09) 024:(0.33) 025:(0.42) 045:(0.04) 100:(0.01) 101:(0.03) 
		Proto:11 001:(0.02) 002:(0.4) 003:(0.0) 023:(0.02) 024:(0.22) 025:(0.0) 045:(0.42) 100:(0.01) 101:(0.48) 
		Proto:12 001:(0.06) 002:(0.21) 003:(0.12) 023:(0.0) 024:(0.08) 025:(0.03) 045:(0.02) 100:(0.55)

Collecting topk: 90it [00:02, 36.41it/s]   


Node 032+031
	 Child: 031+033
		Proto:1 031:(0.33) 033:(0.48) 
		Proto:33 031:(0.47) 033:(0.71) 
		Proto:35 031:(0.67) 033:(0.54) 
		Proto:4 031:(0.85) 033:(0.32) 
		Proto:36 031:(0.75) 033:(0.53) 
		Proto:37 031:(0.37) 033:(0.37) 
		Proto:43 031:(0.69) 033:(0.4) 
		Proto:13 031:(0.61) 033:(0.43) 
		Proto:14 031:(0.85) 033:(0.57) 
		Proto:15 031:(0.63) 033:(0.42) 
		Proto:46 031:(0.62) 033:(0.48) 
		Proto:47 031:(0.39) 033:(0.25) 
		Proto:24 031:(0.79) 033:(0.31) 
		Proto:28 031:(0.74) 033:(0.64) 


Collecting topk: 270it [00:07, 35.88it/s]


Node 045+100
	 Child: 100+023
		Proto:32 023:(0.01) 024:(0.16) 025:(0.08) 100:(0.63) 101:(0.6) 
		Proto:2 023:(0.4) 024:(0.83) 025:(0.43) 100:(0.02) 101:(0.04) 
		Proto:35 023:(0.03) 024:(0.0) 025:(0.05) 100:(0.85) 101:(0.81) 
		Proto:6 023:(0.0) 024:(0.69) 025:(0.02) 100:(0.44) 101:(0.4) 
		Proto:7 023:(0.67) 024:(0.74) 025:(0.64) 100:(0.01) 101:(0.02) 
		Proto:8 023:(0.31) 024:(0.41) 025:(0.35) 100:(0.06) 101:(0.0) 
		Proto:40 023:(0.24) 024:(0.77) 025:(0.5) 100:(0.2) 101:(0.17) 
		Proto:41 023:(0.0) 024:(0.57) 025:(0.08) 100:(0.09) 101:(0.5) 
		Proto:11 023:(0.04) 024:(0.81) 025:(0.0) 100:(0.08) 101:(0.17) 
		Proto:44 023:(0.61) 024:(0.63) 025:(0.2) 100:(0.02) 101:(0.0) 
		Proto:14 023:(0.45) 024:(0.18) 025:(0.11) 100:(0.12) 101:(0.0) 
		Proto:46 023:(0.45) 024:(0.04) 025:(0.46) 100:(0.4) 101:(0.03) 
		Proto:19 023:(0.01) 024:(0.73) 025:(0.04) 100:(0.07) 101:(0.31) 
		Proto:22 023:(0.47) 024:(0.24) 025:(0.52) 100:(0.42) 101:(0.72) 
		Proto:25 023:(0.18) 024:(0.38) 025:(0.7) 100:(0.0

Collecting topk: 120it [00:03, 37.70it/s]  


Node 045+003
	 Child: 003+001
		Proto:0 001:(0.64) 002:(0.51) 003:(0.45) 
		Proto:1 001:(0.73) 002:(0.6) 003:(0.29) 
		Proto:6 001:(0.18) 002:(0.23) 003:(0.81) 
		Proto:8 001:(0.39) 002:(0.29) 003:(0.44) 
		Proto:10 001:(0.53) 002:(0.6) 003:(0.34) 
		Proto:13 001:(0.35) 002:(0.32) 003:(0.55) 
		Proto:14 001:(0.45) 002:(0.18) 003:(0.44) 
		Proto:17 001:(0.62) 002:(0.55) 003:(0.33) 
		Proto:21 001:(0.12) 002:(0.59) 003:(0.67) 
		Proto:24 001:(0.25) 002:(0.34) 003:(0.46) 
		Proto:27 001:(0.57) 002:(0.49) 003:(0.33) 
		Proto:28 001:(0.46) 002:(0.75) 003:(0.34) 
		Proto:29 001:(0.36) 002:(0.36) 003:(0.2) 
		Proto:30 001:(0.77) 002:(0.62) 003:(0.34) 
		Proto:31 001:(0.42) 002:(0.52) 003:(0.27) 
		Proto:37 001:(0.38) 002:(0.45) 003:(0.34) 
		Proto:40 001:(0.45) 002:(0.55) 003:(0.33) 
		Proto:41 001:(0.3) 002:(0.83) 003:(0.22) 
		Proto:46 001:(0.43) 002:(0.54) 003:(0.13) 
		Proto:47 001:(0.71) 002:(0.16) 003:(0.13) 


Collecting topk: 150it [00:04, 33.28it/s]


Node 100+023
	 Child: 023+024
		Proto:0 023:(0.09) 024:(0.8) 025:(0.35) 
		Proto:2 023:(0.54) 024:(0.54) 025:(0.56) 
		Proto:3 023:(0.49) 024:(0.67) 025:(0.6) 
		Proto:4 023:(0.47) 024:(0.52) 025:(0.6) 
		Proto:35 023:(0.29) 024:(0.8) 025:(0.49) 
		Proto:7 023:(0.58) 024:(0.9) 025:(0.56) 
		Proto:41 023:(0.44) 024:(0.62) 025:(0.49) 
		Proto:42 023:(0.13) 024:(0.51) 025:(0.55) 
		Proto:43 023:(0.46) 024:(0.47) 025:(0.53) 
		Proto:12 023:(0.32) 024:(0.58) 025:(0.07) 
		Proto:46 023:(0.37) 024:(0.6) 025:(0.5) 
		Proto:16 023:(0.48) 024:(0.83) 025:(0.27) 
		Proto:48 023:(0.43) 024:(0.36) 025:(0.58) 
		Proto:18 023:(0.24) 024:(0.43) 025:(0.21) 
		Proto:20 023:(0.56) 024:(0.37) 025:(0.61) 
		Proto:24 023:(0.64) 024:(0.37) 025:(0.47) 
		Proto:28 023:(0.3) 024:(0.65) 025:(0.32) 
		Proto:31 023:(0.67) 024:(0.73) 025:(0.3) 
	 Child: 100+101
		Proto:1 100:(0.57) 101:(0.49) 
		Proto:33 100:(0.47) 101:(0.51) 
		Proto:36 100:(0.78) 101:(0.6) 
		Proto:6 100:(0.56) 101:(0.54) 
		Proto:8 100:(0.81) 101

Collecting topk: 90it [00:02, 38.94it/s]   


Node 003+001
	 Child: 001+002
		Proto:0 001:(0.71) 002:(0.65) 
		Proto:33 001:(0.56) 002:(0.48) 
		Proto:3 001:(0.46) 002:(0.59) 
		Proto:36 001:(0.66) 002:(0.42) 
		Proto:6 001:(0.5) 002:(0.5) 
		Proto:7 001:(0.43) 002:(0.66) 
		Proto:40 001:(0.35) 002:(0.66) 
		Proto:9 001:(0.54) 002:(0.65) 
		Proto:42 001:(0.6) 002:(0.45) 
		Proto:12 001:(0.77) 002:(0.81) 
		Proto:17 001:(0.78) 002:(0.58) 
		Proto:24 001:(0.6) 002:(0.61) 
		Proto:25 001:(0.47) 002:(0.35) 
		Proto:30 001:(0.55) 002:(0.56) 
		Proto:31 001:(0.67) 002:(0.65) 


Collecting topk: 90it [00:02, 36.78it/s]   

Node 023+024
	 Child: 024+025
		Proto:33 024:(0.61) 025:(0.29) 
		Proto:35 024:(0.28) 025:(0.47) 
		Proto:37 024:(0.68) 025:(0.47) 
		Proto:38 024:(0.63) 025:(0.44) 
		Proto:7 024:(0.78) 025:(0.28) 
		Proto:9 024:(0.81) 025:(0.17) 
		Proto:10 024:(0.51) 025:(0.41) 
		Proto:12 024:(0.83) 025:(0.49) 
		Proto:13 024:(0.45) 025:(0.3) 
		Proto:14 024:(0.46) 025:(0.57) 
		Proto:47 024:(0.48) 025:(0.32) 
		Proto:18 024:(0.77) 025:(0.25) 
		Proto:26 024:(0.44) 025:(0.52) 
		Proto:27 024:(0.35) 025:(0.27) 
		Proto:31 024:(0.51) 025:(0.19) 





In [24]:
# Proto activations on leaf descendents - topk image-activations

from util.data import ModifiedLabelLoader
from collections import defaultdict
import heapq
import pdb

topk = 10

def get_heap():
    list_ = []
    heapq.heapify(list_)
    return list_


for node in root.nodes_with_children():
    if node.name == 'root':
        continue
    non_leaf_children_names = [child.name for child in node.children if not child.is_leaf()]
    if len(non_leaf_children_names) == 0: # if all the children are leaf nodes then skip this node
        continue

    name2label = projectloader.dataset.class_to_idx
    label2name = {label:name for name, label in name2label.items()}
    modifiedLabelLoader = ModifiedLabelLoader(projectloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)

    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
    
    # maps proto_number -> grand_child_name (or descendant leaf name) -> list of top-k activations
    proto_mean_activations = defaultdict(lambda: defaultdict(get_heap))

    # maps class names to the prototypes that belong to that
    class_and_prototypes = defaultdict(set)

    for i, (xs, orig_y, ys) in img_iter:
        if coarse_label2name[ys.item()] not in non_leaf_children_names:
            continue

        xs, ys = xs.to(device), ys.to(device)

        with torch.no_grad():
            pfs, pooled, _ = net(xs, inference=False)
            pooled = pooled[node.name].squeeze(0) 
            pfs = pfs[node.name].squeeze(0)

            for p in range(pooled.shape[0]): # pooled.shape -> [768] (== num of prototypes)
                c_weight = torch.max(classification_weights[:,p]) # classification_weights[:,p].shape -> [200] (== num of classes)
                relevant_proto_classes = torch.nonzero(classification_weights[:, p] > 1e-3)
                relevant_proto_class_names = [node_label_to_children[class_idx.item()] for class_idx in relevant_proto_classes]

                if len(relevant_proto_class_names) == 0:
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (relevant_proto_class_names[0] not in non_leaf_children_names):
                    continue
                
                if (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    leaf_descendent = label2name[orig_y.item()][4:7]
                    if topk and (len(proto_mean_activations[p][leaf_descendent]) > topk):
                        heapq.heappushpop(proto_mean_activations[p][leaf_descendent], pooled[p].item())
                    else:
                        heapq.heappush(proto_mean_activations[p][leaf_descendent], pooled[p].item())

                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)

    
    print('Node', node.name)
    for child_classname in class_and_prototypes:
        print('\t'*1, 'Child:', child_classname)
        for p in class_and_prototypes[child_classname]:
            logstr = '\t'*2 + f'Proto:{p} '
            for leaf_descendent in proto_mean_activations[p]:
                mean_activation = round(np.mean(proto_mean_activations[p][leaf_descendent]), 4)
                num_images = len(proto_mean_activations[p][leaf_descendent])
                logstr += f'{leaf_descendent}:({mean_activation}) '
            print(logstr)


Collecting topk: 120it [00:04, 27.80it/s]  


Node 052+053
	 Child: 053+050
		Proto:2 050:(0.9812) 051:(0.9933) 053:(0.6588) 
		Proto:3 050:(0.9053) 051:(0.9843) 053:(0.9648) 
		Proto:4 050:(0.9931) 051:(0.9978) 053:(0.9978) 
		Proto:6 050:(0.8988) 051:(0.9403) 053:(0.4896) 
		Proto:7 050:(0.8993) 051:(0.9874) 053:(0.9988) 
		Proto:9 050:(0.8468) 051:(0.9909) 053:(0.0168) 
		Proto:10 050:(0.9056) 051:(0.9806) 053:(0.8222) 
		Proto:12 050:(0.7841) 051:(0.8298) 053:(0.9349) 
		Proto:13 050:(0.9223) 051:(0.9959) 053:(0.8779) 
		Proto:14 050:(0.8708) 051:(0.9791) 053:(0.6037) 
		Proto:16 050:(0.9934) 051:(0.9977) 053:(0.9877) 
		Proto:17 050:(0.0105) 051:(0.8373) 053:(0.9993) 
		Proto:19 050:(0.9081) 051:(0.9834) 053:(0.9472) 


Collecting topk: 420it [00:09, 45.18it/s]


Node 004+086
	 Child: 086+045
		Proto:0 001:(0.2513) 002:(0.872) 003:(0.9815) 023:(0.0043) 024:(0.0563) 025:(0.0958) 045:(0.9137) 086:(0.0388) 100:(0.002) 101:(0.0415) 
		Proto:1 001:(0.997) 002:(0.9983) 003:(0.9886) 023:(0.9976) 024:(0.9997) 025:(0.9988) 045:(0.9967) 086:(0.7949) 100:(0.9998) 101:(0.9998) 
		Proto:2 001:(0.6493) 002:(0.3787) 003:(0.6034) 023:(0.604) 024:(0.2477) 025:(0.6089) 045:(0.427) 086:(0.0314) 100:(0.6685) 101:(0.5068) 
		Proto:7 001:(0.1635) 002:(0.1849) 003:(0.0658) 023:(0.959) 024:(0.9885) 025:(0.9941) 045:(0.0094) 086:(0.1931) 100:(0.0586) 101:(0.0265) 
		Proto:10 001:(0.148) 002:(0.2155) 003:(0.1924) 023:(0.9361) 024:(0.7214) 025:(0.9685) 045:(0.0453) 086:(0.0569) 100:(0.0128) 101:(0.0163) 
		Proto:11 001:(0.9187) 002:(0.8502) 003:(0.8027) 023:(0.9992) 024:(0.9982) 025:(0.9962) 045:(0.3905) 086:(0.9992) 100:(0.0569) 101:(0.1555) 
		Proto:15 001:(0.8459) 002:(0.9892) 003:(0.933) 023:(0.3663) 024:(0.2743) 025:(0.7468) 045:(0.8844) 086:(0.5302) 100:(0.4635) 10

Collecting topk: 90it [00:03, 25.69it/s]   


Node 053+050
	 Child: 050+051
		Proto:0 050:(0.521) 051:(0.683) 
		Proto:1 050:(0.9178) 051:(0.9853) 
		Proto:3 050:(0.9916) 051:(0.9992) 
		Proto:6 050:(0.9391) 051:(0.9883) 
		Proto:12 050:(0.9874) 051:(0.9838) 
		Proto:13 050:(0.9458) 051:(0.975) 
		Proto:16 050:(0.6865) 051:(0.9197) 
		Proto:17 050:(0.7028) 051:(0.9495) 
		Proto:18 050:(0.1909) 051:(0.9628) 
		Proto:19 050:(0.9987) 051:(0.9997) 


Collecting topk: 120it [00:04, 29.96it/s]  


Node 004+032
	 Child: 032+033
		Proto:0 031:(0.64) 032:(0.9861) 033:(0.9767) 
		Proto:1 031:(0.955) 032:(0.9973) 033:(0.9998) 
		Proto:2 031:(0.9066) 032:(0.6783) 033:(0.6539) 
		Proto:4 031:(0.9999) 032:(0.9999) 033:(0.9947) 
		Proto:6 031:(0.9708) 032:(0.9904) 033:(0.997) 
		Proto:7 031:(0.9678) 032:(0.9882) 033:(0.9882) 
		Proto:8 031:(0.8236) 032:(0.9993) 033:(0.994) 
		Proto:10 031:(0.6821) 032:(0.429) 033:(0.5492) 
		Proto:12 031:(0.1021) 032:(0.9885) 033:(0.9915) 
		Proto:17 031:(0.9986) 032:(0.4277) 033:(0.7704) 
		Proto:18 031:(0.9996) 032:(0.884) 033:(0.9976) 
		Proto:19 031:(0.9998) 032:(0.3546) 033:(0.9992) 


Collecting topk: 300it [00:06, 43.00it/s]  


Node 086+045
	 Child: 045+101
		Proto:1 001:(0.9327) 002:(0.941) 003:(0.6165) 023:(0.1006) 024:(0.0545) 025:(0.1164) 045:(0.0277) 100:(0.9982) 101:(0.9943) 
		Proto:3 001:(0.9871) 002:(0.9961) 003:(0.9793) 023:(0.006) 024:(0.0688) 025:(0.0848) 045:(0.9348) 100:(0.1786) 101:(0.5683) 
		Proto:4 001:(0.2581) 002:(0.1579) 003:(0.0572) 023:(0.9996) 024:(1.0) 025:(0.9996) 045:(0.0904) 100:(0.8984) 101:(0.9998) 
		Proto:5 001:(0.9754) 002:(0.5064) 003:(0.0269) 023:(0.8366) 024:(0.9513) 025:(0.2006) 045:(0.1196) 100:(0.6982) 101:(0.9986) 
		Proto:6 001:(0.8895) 002:(0.8011) 003:(0.6659) 023:(0.9614) 024:(0.9995) 025:(0.9979) 045:(0.446) 100:(0.0988) 101:(0.1436) 
		Proto:7 001:(0.6331) 002:(0.1676) 003:(0.1906) 023:(0.8305) 024:(0.9333) 025:(0.9915) 045:(0.2363) 100:(0.0119) 101:(0.0267) 
		Proto:8 001:(0.44) 002:(0.6423) 003:(0.9982) 023:(0.08) 024:(0.0763) 025:(0.2964) 045:(0.9822) 100:(0.0187) 101:(0.1119) 
		Proto:10 001:(0.984) 002:(0.7038) 003:(0.827) 023:(0.1629) 024:(0.0555) 025:(0.137

Collecting topk: 90it [00:03, 25.70it/s]   


Node 032+033
	 Child: 033+031
		Proto:0 031:(0.9866) 033:(0.9539) 
		Proto:1 031:(0.9551) 033:(0.9908) 
		Proto:3 031:(0.9999) 033:(0.9987) 
		Proto:6 031:(0.9949) 033:(0.9947) 
		Proto:15 031:(0.9619) 033:(0.9306) 
		Proto:17 031:(0.9817) 033:(0.9623) 
		Proto:18 031:(0.8603) 033:(0.6436) 
		Proto:19 031:(0.9826) 033:(0.9972) 


Collecting topk: 270it [00:06, 38.99it/s]


Node 045+101
	 Child: 101+023
		Proto:0 023:(0.0884) 024:(0.1109) 025:(0.1002) 100:(0.9923) 101:(1.0) 
		Proto:1 023:(0.9936) 024:(0.9795) 025:(0.9858) 100:(0.7331) 101:(0.0788) 
		Proto:2 023:(0.5907) 024:(0.0124) 025:(0.2712) 100:(0.9996) 101:(0.9977) 
		Proto:3 023:(0.9828) 024:(0.9984) 025:(0.9839) 100:(0.0434) 101:(0.1284) 
		Proto:4 023:(0.7218) 024:(0.936) 025:(0.944) 100:(0.0027) 101:(0.0117) 
		Proto:6 023:(0.5329) 024:(0.9976) 025:(0.1815) 100:(0.9976) 101:(0.9987) 
		Proto:7 023:(0.9738) 024:(0.9996) 025:(0.9996) 100:(0.0839) 101:(0.0534) 
		Proto:8 023:(0.9964) 024:(0.9385) 025:(0.977) 100:(0.9199) 101:(0.9903) 
		Proto:17 023:(0.9531) 024:(0.9808) 025:(0.9845) 100:(0.4185) 101:(0.0421) 
		Proto:19 023:(0.9837) 024:(0.7099) 025:(0.9314) 100:(0.1956) 101:(0.6811) 
	 Child: 045+003
		Proto:5 001:(0.9969) 002:(0.9802) 003:(0.9992) 045:(0.9575) 
		Proto:9 001:(0.9813) 002:(0.973) 003:(0.5236) 045:(0.7514) 
		Proto:10 001:(0.5372) 002:(0.9549) 003:(0.9987) 045:(0.9757) 
		Proto:

Collecting topk: 120it [00:03, 30.15it/s]  


Node 045+003
	 Child: 003+002
		Proto:1 001:(0.9996) 002:(0.9999) 003:(0.9901) 
		Proto:2 001:(0.9988) 002:(0.3342) 003:(0.9992) 
		Proto:6 001:(0.9782) 002:(0.9997) 003:(0.5232) 
		Proto:8 001:(0.9862) 002:(0.9897) 003:(0.7589) 
		Proto:9 001:(0.9983) 002:(0.9504) 003:(0.9744) 
		Proto:10 001:(0.8371) 002:(0.769) 003:(0.8623) 
		Proto:12 001:(0.4241) 002:(0.9991) 003:(0.9808) 
		Proto:13 001:(0.9967) 002:(0.9814) 003:(0.8451) 
		Proto:17 001:(0.951) 002:(0.8812) 003:(0.989) 
		Proto:18 001:(0.9991) 002:(0.9703) 003:(0.9935) 
		Proto:19 001:(0.5967) 002:(0.9663) 003:(0.8627) 


Collecting topk: 150it [00:04, 30.04it/s]


Node 101+023
	 Child: 023+025
		Proto:2 023:(0.9925) 024:(0.9999) 025:(0.994) 
		Proto:4 023:(0.6459) 024:(0.9943) 025:(0.9398) 
		Proto:9 023:(0.9866) 024:(0.9997) 025:(0.9869) 
		Proto:10 023:(0.9657) 024:(0.8642) 025:(0.999) 
		Proto:11 023:(0.9872) 024:(0.9962) 025:(0.9848) 
		Proto:13 023:(0.9906) 024:(0.998) 025:(0.9948) 
		Proto:16 023:(0.9859) 024:(0.9008) 025:(0.9961) 
		Proto:17 023:(0.9868) 024:(0.9455) 025:(0.9972) 
		Proto:18 023:(0.9992) 024:(0.9999) 025:(0.9904) 
		Proto:19 023:(0.9889) 024:(0.9981) 025:(0.983) 
	 Child: 101+100
		Proto:5 100:(0.9682) 101:(0.963) 
		Proto:6 100:(0.9994) 101:(0.9997) 
		Proto:7 100:(0.9998) 101:(0.9992) 
		Proto:14 100:(0.9957) 101:(0.9989) 
		Proto:15 100:(0.973) 101:(1.0) 


Collecting topk: 90it [00:03, 25.72it/s]   


Node 003+002
	 Child: 002+001
		Proto:0 001:(0.9997) 002:(0.9968) 
		Proto:3 001:(0.9994) 002:(0.9977) 
		Proto:8 001:(0.9884) 002:(0.9832) 
		Proto:9 001:(0.9913) 002:(0.9616) 
		Proto:12 001:(0.9986) 002:(0.9928) 
		Proto:13 001:(0.9929) 002:(0.9881) 
		Proto:17 001:(0.9883) 002:(0.9697) 
		Proto:18 001:(0.967) 002:(0.9936) 


Collecting topk: 90it [00:03, 25.78it/s]   

Node 023+025
	 Child: 025+024
		Proto:5 024:(0.8877) 025:(0.9721) 
		Proto:7 024:(0.9997) 025:(0.9806) 
		Proto:9 024:(0.9961) 025:(0.9071) 
		Proto:10 024:(0.9947) 025:(0.9869) 
		Proto:14 024:(0.9997) 025:(0.9982) 
		Proto:15 024:(0.9834) 025:(0.8558) 
		Proto:16 024:(0.9996) 025:(0.9967) 





In [5]:
# Proto activations on leaf descendents - topk images

from util.data import ModifiedLabelLoader
from collections import defaultdict
import heapq
import pdb
from util.vis_pipnet import get_img_coordinates
import torchvision.transforms as transforms
from PIL import Image, ImageDraw as D
import torchvision

topk = 10
save_images = True

def get_heap():
    list_ = []
    heapq.heapify(list_)
    return list_

patchsize, skip = get_patch_size(args)

for node in root.nodes_with_children():
    if node.name == 'root':
        continue
    non_leaf_children_names = [child.name for child in node.children if not child.is_leaf()]
    if len(non_leaf_children_names) == 0: # if all the children are leaf nodes then skip this node
        continue

    name2label = projectloader.dataset.class_to_idx
    label2name = {label:name for name, label in name2label.items()}
    modifiedLabelLoader = ModifiedLabelLoader(projectloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}
    
    imgs = modifiedLabelLoader.filtered_imgs

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)

    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
    
    # maps proto_number -> grand_child_name (or descendant leaf name) -> list of top-k activations
    proto_mean_activations = defaultdict(lambda: defaultdict(get_heap))

    # maps class names to the prototypes that belong to that
    class_and_prototypes = defaultdict(set)

    for i, (xs, orig_y, ys) in img_iter:
        if coarse_label2name[ys.item()] not in non_leaf_children_names:
            continue

        xs, ys = xs.to(device), ys.to(device)

        with torch.no_grad():
            softmaxes, pooled, _ = net(xs, inference=False)
            pooled = pooled[node.name].squeeze(0) 
            softmaxes = softmaxes[node.name]#.squeeze(0)

            for p in range(pooled.shape[0]): # pooled.shape -> [768] (== num of prototypes)
                c_weight = torch.max(classification_weights[:,p]) # classification_weights[:,p].shape -> [200] (== num of classes)
                relevant_proto_classes = torch.nonzero(classification_weights[:, p] > 1e-3)
                relevant_proto_class_names = [node_label_to_children[class_idx.item()] for class_idx in relevant_proto_classes]
                
                # Take the max per prototype.                             
                max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0)
                max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1)
                max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes)
                
                h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]]
                w_idx = max_idx_per_prototype_w[p]

                if len(relevant_proto_class_names) == 0:
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (relevant_proto_class_names[0] not in non_leaf_children_names):
                    continue
                    
                h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx)
                
                if (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    leaf_descendent = label2name[orig_y.item()][4:7]
                    img_to_open = imgs[i][0] # it is a tuple of (path to image, lable)
                    if topk and (len(proto_mean_activations[p][leaf_descendent]) > topk):
                        heapq.heappushpop(proto_mean_activations[p][leaf_descendent], (pooled[p].item(), img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max)))
                    else:
                        heapq.heappush(proto_mean_activations[p][leaf_descendent], (pooled[p].item(), img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max)))

                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)

    
    print('Node', node.name)
    for child_classname in class_and_prototypes:
        
        print('\t'*1, 'Child:', child_classname)
        for p in class_and_prototypes[child_classname]:
            
            logstr = '\t'*2 + f'Proto:{p} '
            for leaf_descendent in proto_mean_activations[p]:
                mean_activation = round(np.mean([activation for activation, *_ in proto_mean_activations[p][leaf_descendent]]), 4)
                num_images = len(proto_mean_activations[p][leaf_descendent])
                logstr += f'{leaf_descendent}:({mean_activation}) '
            print(logstr)
            
            if save_images:
                patches = []
                right_descriptions = []
                text_region_width = 7 # 7x the width of a patch
                for leaf_descendent, heap in proto_mean_activations[p].items():
                    heap = sorted(heap)[::-1]
                    mean_activation = round(np.mean([activation for activation, *_ in proto_mean_activations[p][leaf_descendent]]), 4)
                    for ele in heap:
                        activation, img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max) = ele
                        image = transforms.Resize(size=(args.image_size, args.image_size))(Image.open(img_to_open))
                        img_tensor = transforms.ToTensor()(image)#.unsqueeze_(0) #shape (1, 3, h, w)
                        img_tensor_patch = img_tensor[:, h_coor_min:h_coor_max, w_coor_min:w_coor_max]
                        patches.append(img_tensor_patch)

                    # description on the right hand side
                    text = f'{mean_activation}, {leaf_descendent}'
                    txtimage = Image.new("RGB", (patches[0].shape[-2]*text_region_width,patches[0].shape[-1]), (0, 0, 0))
                    draw = D.Draw(txtimage)
                    draw.text((5, patches[0].shape[1]//2), text, anchor='mm', fill="white")
                    txttensor = transforms.ToTensor()(txtimage)#.unsqueeze_(0)
                    right_descriptions.append(txttensor)

                grid = torchvision.utils.make_grid(patches, nrow=topk+1, padding=1)
                grid_right_descriptions = torchvision.utils.make_grid(right_descriptions, nrow=1, padding=1)

                # merging right description with the grid of images
                grid = torch.cat([grid, grid_right_descriptions], dim=-1)

                # description on the top
                text = f'Node:{node.name}, p{p}, Child:{child_classname}'
                txtimage = Image.new("RGB", (grid.shape[-1], args.wshape), (0, 0, 0))
                draw = D.Draw(txtimage)
                draw.text((5, patches[0].shape[1]//2), text, anchor='mm', fill="white")
                txttensor = transforms.ToTensor()(txtimage)#.unsqueeze_(0)

                # merging top description with the grid of images
                grid = torch.cat([grid, txttensor], dim=1)

                os.makedirs(os.path.join(run_path, 'descendent_specific_topk', node.name), exist_ok=True)
                torchvision.utils.save_image(grid, os.path.join(run_path, 'descendent_specific_topk', node.name, f'{child_classname}-p{p}.png'))
            

Collecting topk: 120it [00:06, 18.45it/s]  


Node 052+053
	 Child: 053+050
		Proto:3 050:(0.3464) 051:(0.6158) 053:(0.9803) 
		Proto:4 050:(0.4718) 051:(0.9664) 053:(0.7302) 
		Proto:9 050:(0.9336) 051:(0.9558) 053:(0.869) 
		Proto:13 050:(0.8444) 051:(0.8137) 053:(0.9754) 
		Proto:16 050:(0.5757) 051:(0.9762) 053:(0.3438) 
		Proto:17 050:(0.3427) 051:(0.6701) 053:(0.9977) 
		Proto:19 050:(0.8315) 051:(0.9786) 053:(0.9491) 


Collecting topk: 420it [00:07, 52.73it/s]


Node 004+086
	 Child: 086+045
		Proto:1 001:(0.997) 002:(0.9861) 003:(0.9936) 023:(0.9999) 024:(0.9999) 025:(0.9987) 045:(0.2214) 086:(0.9653) 100:(0.9999) 101:(0.9995) 
		Proto:18 001:(0.9511) 002:(0.9459) 003:(0.8654) 023:(0.0904) 024:(0.9998) 025:(0.061) 045:(0.9785) 086:(0.1054) 100:(0.247) 101:(0.9468) 
		Proto:11 001:(0.8843) 002:(0.9048) 003:(0.7592) 023:(0.9984) 024:(0.9844) 025:(0.9582) 045:(0.0544) 086:(0.9767) 100:(0.9973) 101:(0.9909) 
		Proto:17 001:(0.9986) 002:(0.9951) 003:(0.9996) 023:(0.995) 024:(0.9408) 025:(0.9991) 045:(0.9894) 086:(0.999) 100:(0.9932) 101:(0.9898) 
	 Child: 004+032
		Proto:3 004:(0.936) 031:(0.9947) 032:(0.9794) 033:(0.9955) 
		Proto:6 004:(0.0015) 031:(0.9117) 032:(0.0052) 033:(0.5998) 
		Proto:8 004:(0.297) 031:(0.9245) 032:(0.9603) 033:(0.9758) 
		Proto:9 004:(0.8504) 031:(0.9794) 032:(0.996) 033:(0.9823) 
		Proto:14 004:(0.0786) 031:(0.8891) 032:(0.8984) 033:(0.8614) 


Collecting topk: 90it [00:01, 48.00it/s]   


Node 053+050
	 Child: 050+051
		Proto:0 050:(0.5434) 051:(0.8272) 
		Proto:3 050:(0.6736) 051:(0.8731) 
		Proto:6 050:(0.9505) 051:(0.9439) 
		Proto:16 050:(0.4022) 051:(0.7834) 
		Proto:18 050:(0.8933) 051:(0.9742) 
		Proto:19 050:(0.64) 051:(0.9804) 


Collecting topk: 120it [00:02, 49.93it/s]  


Node 004+032
	 Child: 032+033
		Proto:0 031:(0.9946) 032:(0.2255) 033:(0.9786) 
		Proto:2 031:(0.6317) 032:(0.921) 033:(0.891) 
		Proto:4 031:(0.3886) 032:(0.9971) 033:(0.9556) 
		Proto:6 031:(0.9643) 032:(0.9706) 033:(0.9453) 
		Proto:8 031:(0.9422) 032:(0.9908) 033:(0.5566) 
		Proto:12 031:(0.9926) 032:(0.2122) 033:(0.9493) 
		Proto:18 031:(0.999) 032:(0.9799) 033:(0.9909) 
		Proto:19 031:(0.9511) 032:(0.9913) 033:(0.994) 


Collecting topk: 300it [00:05, 55.90it/s]  


Node 086+045
	 Child: 045+101
		Proto:1 001:(0.9694) 002:(0.9954) 003:(0.4557) 023:(0.1847) 024:(0.1629) 025:(0.1749) 045:(0.0118) 100:(0.9997) 101:(0.9998) 
		Proto:3 001:(0.9738) 002:(0.6426) 003:(0.8052) 023:(0.0736) 024:(0.0238) 025:(0.3743) 045:(0.9203) 100:(0.0617) 101:(0.1105) 
		Proto:5 001:(0.9981) 002:(0.9853) 003:(0.9952) 023:(0.0175) 024:(0.0031) 025:(0.019) 045:(0.9933) 100:(0.1009) 101:(0.0401) 
		Proto:7 001:(0.1498) 002:(0.7044) 003:(0.2562) 023:(0.0846) 024:(0.0271) 025:(0.0557) 045:(0.1921) 100:(0.9525) 101:(0.9796) 
		Proto:8 001:(0.4193) 002:(0.3736) 003:(0.5577) 023:(0.981) 024:(0.9617) 025:(0.9983) 045:(0.2218) 100:(0.1905) 101:(0.0393) 
		Proto:10 001:(0.0179) 002:(0.0211) 003:(0.114) 023:(0.9985) 024:(0.9984) 025:(0.9936) 045:(0.0038) 100:(0.8938) 101:(0.1344) 
		Proto:12 001:(0.9929) 002:(0.986) 003:(0.9208) 023:(0.8328) 024:(0.4864) 025:(0.7231) 045:(0.1525) 100:(0.2511) 101:(0.0645) 
		Proto:13 001:(0.1469) 002:(0.019) 003:(0.0366) 023:(0.9996) 024:(0.9982) 0

Collecting topk: 90it [00:01, 47.84it/s]   


Node 032+033
	 Child: 033+031
		Proto:0 031:(0.9944) 033:(0.9712) 
		Proto:1 031:(0.9883) 033:(0.9595) 
		Proto:3 031:(0.9997) 033:(0.9985) 
		Proto:19 031:(0.9955) 033:(0.9681) 


Collecting topk: 270it [00:05, 52.25it/s]


Node 045+101
	 Child: 101+023
		Proto:1 023:(0.9774) 024:(0.9969) 025:(0.9851) 100:(0.9916) 101:(0.9987) 
		Proto:2 023:(0.206) 024:(0.0357) 025:(0.4072) 100:(0.9885) 101:(0.9857) 
		Proto:6 023:(0.9972) 024:(0.9954) 025:(0.9729) 100:(0.9892) 101:(0.997) 
		Proto:7 023:(0.0921) 024:(0.1993) 025:(0.087) 100:(0.9921) 101:(0.9981) 
		Proto:8 023:(0.9961) 024:(0.9862) 025:(0.9989) 100:(0.9986) 101:(0.9994) 
		Proto:19 023:(0.9874) 024:(0.9995) 025:(0.9948) 100:(0.0276) 101:(0.0021) 
	 Child: 045+003
		Proto:5 001:(0.9985) 002:(0.9964) 003:(0.9994) 045:(0.992) 
		Proto:9 001:(0.9865) 002:(0.7353) 003:(0.9954) 045:(0.9459) 
		Proto:10 001:(0.979) 002:(0.9684) 003:(0.9944) 045:(0.8819) 
		Proto:13 001:(0.9649) 002:(0.9965) 003:(0.9903) 045:(0.9912) 
		Proto:16 001:(0.9965) 002:(0.9996) 003:(0.915) 045:(0.7945) 


Collecting topk: 120it [00:02, 49.16it/s]  


Node 045+003
	 Child: 003+002
		Proto:1 001:(0.8833) 002:(0.9829) 003:(0.9798) 
		Proto:2 001:(0.9785) 002:(0.8876) 003:(0.8433) 
		Proto:6 001:(0.7035) 002:(0.8692) 003:(0.8956) 
		Proto:7 001:(0.8719) 002:(0.9192) 003:(0.7222) 
		Proto:9 001:(0.9802) 002:(0.9616) 003:(0.8611) 
		Proto:11 001:(0.7903) 002:(0.8737) 003:(0.9919) 
		Proto:13 001:(0.9946) 002:(0.6873) 003:(0.912) 
		Proto:17 001:(0.9968) 002:(0.9975) 003:(0.9804) 
		Proto:18 001:(0.9768) 002:(0.9807) 003:(0.9698) 


Collecting topk: 150it [00:03, 43.42it/s]


Node 101+023
	 Child: 023+025
		Proto:4 023:(0.9786) 024:(0.9956) 025:(0.9881) 
		Proto:9 023:(0.9653) 024:(0.9989) 025:(0.9683) 
		Proto:10 023:(0.9862) 024:(0.9992) 025:(0.9704) 
		Proto:11 023:(0.9759) 024:(0.9829) 025:(0.9649) 
		Proto:12 023:(0.87) 024:(0.8798) 025:(0.9419) 
		Proto:13 023:(0.9865) 024:(0.9798) 025:(0.9961) 
		Proto:18 023:(0.9989) 024:(0.9985) 025:(0.9949) 
	 Child: 101+100
		Proto:15 100:(0.9955) 101:(0.999) 
		Proto:5 100:(0.9992) 101:(0.9994) 
		Proto:6 100:(0.9998) 101:(1.0) 
		Proto:7 100:(0.9996) 101:(0.9992) 


Collecting topk: 90it [00:01, 48.76it/s]   


Node 003+002
	 Child: 002+001
		Proto:0 001:(0.9903) 002:(0.9978) 
		Proto:9 001:(0.9635) 002:(0.9451) 
		Proto:12 001:(0.9817) 002:(0.9971) 
		Proto:13 001:(0.9974) 002:(0.9977) 
		Proto:17 001:(0.9991) 002:(0.9098) 


Collecting topk: 90it [00:01, 47.20it/s]   


Node 023+025
	 Child: 025+024
		Proto:0 024:(0.9801) 025:(0.9024) 
		Proto:6 024:(0.9964) 025:(0.9902) 
		Proto:7 024:(0.9901) 025:(0.8848) 
		Proto:9 024:(0.9991) 025:(0.9839) 
		Proto:10 024:(0.9988) 025:(0.7854) 
		Proto:11 024:(0.9925) 025:(0.9013) 
		Proto:14 024:(0.9985) 025:(0.9131) 


In [6]:
# Proto activations on leaf descendents - topk images

from util.data import ModifiedLabelLoader
from collections import defaultdict
import heapq
import pdb
from util.vis_pipnet import get_img_coordinates
import torchvision.transforms as transforms
from PIL import Image, ImageDraw as D
import torchvision

topk = 10
save_images = True

def get_heap():
    list_ = []
    heapq.heapify(list_)
    return list_

patchsize, skip = get_patch_size(args)

for node in root.nodes_with_children():
    if node.name == 'root':
        continue
    non_leaf_children_names = [child.name for child in node.children if not child.is_leaf()]
    if len(non_leaf_children_names) == 0: # if all the children are leaf nodes then skip this node
        continue

    name2label = projectloader.dataset.class_to_idx
    label2name = {label:name for name, label in name2label.items()}
    modifiedLabelLoader = ModifiedLabelLoader(projectloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}
    
    imgs = modifiedLabelLoader.filtered_imgs

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)

    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
    
    # maps proto_number -> grand_child_name (or descendant leaf name) -> list of top-k activations
    proto_mean_activations = defaultdict(lambda: defaultdict(get_heap))

    # maps class names to the prototypes that belong to that
    class_and_prototypes = defaultdict(set)

    for i, (xs, orig_y, ys) in img_iter:
        if coarse_label2name[ys.item()] not in non_leaf_children_names:
            continue

        xs, ys = xs.to(device), ys.to(device)

        with torch.no_grad():
            softmaxes, pooled, _ = net(xs, inference=False)
            pooled = pooled[node.name].squeeze(0) 
            softmaxes = softmaxes[node.name]#.squeeze(0)

            for p in range(pooled.shape[0]): # pooled.shape -> [768] (== num of prototypes)
                c_weight = torch.max(classification_weights[:,p]) # classification_weights[:,p].shape -> [200] (== num of classes)
                relevant_proto_classes = torch.nonzero(classification_weights[:, p] > 1e-3)
                relevant_proto_class_names = [node_label_to_children[class_idx.item()] for class_idx in relevant_proto_classes]
                relevant_class_weight = classification_weights[:, p][classification_weights[:, p] > 1e-3]
                
                if len(relevant_proto_class_names) == 0:
                    continue
                
                if (len(relevant_proto_class_names) == 1) and (relevant_proto_class_names[0] not in non_leaf_children_names):
                    continue
                
                if len(relevant_proto_class_names) > 1:
                    raise Exception(f"P{p} of node {node.name} is relevant to more than one class {relevant_proto_class_names}")
                
                relevant_class_weight = relevant_class_weight.item()
                
                # Take the max per prototype.                             
                max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0)
                max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1)
                max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes)
                
                h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]]
                w_idx = max_idx_per_prototype_w[p]

                
                    
                h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx)
                
                if (coarse_label2name[ys.item()] in relevant_proto_class_names):
                    child_node = root.get_node(coarse_label2name[ys.item()])
                    leaf_descendent = label2name[orig_y.item()][4:7]
                    img_to_open = imgs[i][0] # it is a tuple of (path to image, lable)
                    if topk and (len(proto_mean_activations[p][leaf_descendent]) > topk):
                        heapq.heappushpop(proto_mean_activations[p][leaf_descendent], (pooled[p].item()*relevant_class_weight, img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max)))
                    else:
                        heapq.heappush(proto_mean_activations[p][leaf_descendent], (pooled[p].item()*relevant_class_weight, img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max)))

                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)

    
    print('Node', node.name)
    for child_classname in class_and_prototypes:
        
        print('\t'*1, 'Child:', child_classname)
        for p in class_and_prototypes[child_classname]:
            
            logstr = '\t'*2 + f'Proto:{p} '
            for leaf_descendent in proto_mean_activations[p]:
                mean_activation = round(np.mean([activation for activation, *_ in proto_mean_activations[p][leaf_descendent]]), 4)
                num_images = len(proto_mean_activations[p][leaf_descendent])
                logstr += f'{leaf_descendent}:({mean_activation}) '
            print(logstr)
            
            if save_images:
                patches = []
                right_descriptions = []
                text_region_width = 7 # 7x the width of a patch
                for leaf_descendent, heap in proto_mean_activations[p].items():
                    heap = sorted(heap)[::-1]
                    mean_activation = round(np.mean([activation for activation, *_ in proto_mean_activations[p][leaf_descendent]]), 4)
                    for ele in heap:
                        activation, img_to_open, (h_coor_min, h_coor_max, w_coor_min, w_coor_max) = ele
                        image = transforms.Resize(size=(args.image_size, args.image_size))(Image.open(img_to_open))
                        img_tensor = transforms.ToTensor()(image)#.unsqueeze_(0) #shape (1, 3, h, w)
                        img_tensor_patch = img_tensor[:, h_coor_min:h_coor_max, w_coor_min:w_coor_max]
                        patches.append(img_tensor_patch)

                    # description on the right hand side
                    text = f'{mean_activation}, {leaf_descendent}'
                    txtimage = Image.new("RGB", (patches[0].shape[-2]*text_region_width,patches[0].shape[-1]), (0, 0, 0))
                    draw = D.Draw(txtimage)
                    draw.text((5, patches[0].shape[1]//2), text, anchor='mm', fill="white")
                    txttensor = transforms.ToTensor()(txtimage)#.unsqueeze_(0)
                    right_descriptions.append(txttensor)

                grid = torchvision.utils.make_grid(patches, nrow=topk+1, padding=1)
                grid_right_descriptions = torchvision.utils.make_grid(right_descriptions, nrow=1, padding=1)

                # merging right description with the grid of images
                grid = torch.cat([grid, grid_right_descriptions], dim=-1)

                # description on the top
                text = f'Node:{node.name}, p{p}, Child:{child_classname}'
                txtimage = Image.new("RGB", (grid.shape[-1], args.wshape), (0, 0, 0))
                draw = D.Draw(txtimage)
                draw.text((5, patches[0].shape[1]//2), text, anchor='mm', fill="white")
                txttensor = transforms.ToTensor()(txtimage)#.unsqueeze_(0)

                # merging top description with the grid of images
                grid = torch.cat([grid, txttensor], dim=1)

                os.makedirs(os.path.join(run_path, 'descendent_specific_topk_scores', node.name), exist_ok=True)
                torchvision.utils.save_image(grid, os.path.join(run_path, 'descendent_specific_topk_scores', node.name, f'{child_classname}-p{p}.png'))
            

Collecting topk: 120it [00:02, 50.00it/s]  


Node 052+053
	 Child: 053+050
		Proto:3 050:(0.9529) 051:(1.6938) 053:(2.6966) 
		Proto:4 050:(0.967) 051:(1.9807) 053:(1.4966) 
		Proto:9 050:(2.6367) 051:(2.6995) 053:(2.4543) 
		Proto:13 050:(1.4765) 051:(1.4228) 053:(1.7056) 
		Proto:16 050:(1.8832) 051:(3.1936) 053:(1.1246) 
		Proto:17 050:(0.4404) 051:(0.8612) 053:(1.2822) 
		Proto:19 050:(3.4319) 051:(4.0389) 053:(3.9173) 


Collecting topk: 420it [00:07, 54.07it/s]


Node 004+086
	 Child: 086+045
		Proto:1 001:(4.3188) 002:(4.2716) 003:(4.3041) 023:(4.3316) 024:(4.3317) 025:(4.3261) 045:(0.9593) 086:(4.1818) 100:(4.3315) 101:(4.3298) 
		Proto:18 001:(2.3883) 002:(2.3751) 003:(2.1731) 023:(0.227) 024:(2.5106) 025:(0.1532) 045:(2.4572) 086:(0.2646) 100:(0.6201) 101:(2.3774) 
		Proto:11 001:(2.9054) 002:(2.9728) 003:(2.4945) 023:(3.2803) 024:(3.2345) 025:(3.1484) 045:(0.1788) 086:(3.209) 100:(3.2767) 101:(3.2558) 
		Proto:17 001:(4.8714) 002:(4.8543) 003:(4.8761) 023:(4.854) 024:(4.5893) 025:(4.8736) 045:(4.8267) 086:(4.8735) 100:(4.8449) 101:(4.8283) 
	 Child: 004+032
		Proto:3 004:(2.9379) 031:(3.1223) 032:(3.0742) 033:(3.1248) 
		Proto:6 004:(0.0001) 031:(0.0526) 032:(0.0003) 033:(0.0346) 
		Proto:8 004:(0.076) 031:(0.2366) 032:(0.2457) 033:(0.2497) 
		Proto:9 004:(3.3763) 031:(3.8886) 032:(3.9546) 033:(3.9003) 
		Proto:14 004:(0.2145) 031:(2.4269) 032:(2.4523) 033:(2.3512) 


Collecting topk: 90it [00:01, 47.21it/s]   


Node 053+050
	 Child: 050+051
		Proto:0 050:(1.2862) 051:(1.9582) 
		Proto:3 050:(2.0965) 051:(2.7172) 
		Proto:6 050:(3.141) 051:(3.1193) 
		Proto:16 050:(0.9455) 051:(1.8418) 
		Proto:18 050:(0.1036) 051:(0.113) 
		Proto:19 050:(2.3963) 051:(3.6708) 


Collecting topk: 120it [00:02, 43.39it/s]  


Node 004+032
	 Child: 032+033
		Proto:0 031:(1.0536) 032:(0.2389) 033:(1.0367) 
		Proto:2 031:(0.4491) 032:(0.6548) 033:(0.6335) 
		Proto:4 031:(1.3148) 032:(3.3738) 033:(3.2335) 
		Proto:6 031:(3.3762) 032:(3.3981) 033:(3.3096) 
		Proto:8 031:(1.38) 032:(1.4513) 033:(0.8152) 
		Proto:12 031:(2.8455) 032:(0.6083) 033:(2.7214) 
		Proto:18 031:(2.7017) 032:(2.6501) 033:(2.6797) 
		Proto:19 031:(3.3279) 032:(3.4688) 033:(3.478) 


Collecting topk: 300it [00:05, 53.96it/s]  


Node 086+045
	 Child: 045+101
		Proto:1 001:(2.2838) 002:(2.3451) 003:(1.0736) 023:(0.435) 024:(0.3838) 025:(0.412) 045:(0.0277) 100:(2.3552) 101:(2.3553) 
		Proto:3 001:(2.532) 002:(1.6709) 003:(2.0938) 023:(0.1914) 024:(0.0619) 025:(0.9732) 045:(2.3931) 100:(0.1605) 101:(0.2872) 
		Proto:5 001:(3.9396) 002:(3.8892) 003:(3.9283) 023:(0.0693) 024:(0.0122) 025:(0.075) 045:(3.9209) 100:(0.3982) 101:(0.1584) 
		Proto:7 001:(0.151) 002:(0.7098) 003:(0.2582) 023:(0.0852) 024:(0.0273) 025:(0.0562) 045:(0.1936) 100:(0.9599) 101:(0.9871) 
		Proto:8 001:(1.3479) 002:(1.201) 003:(1.7929) 023:(3.1537) 024:(3.0916) 025:(3.2093) 045:(0.7131) 100:(0.6123) 101:(0.1263) 
		Proto:10 001:(0.0676) 002:(0.0797) 003:(0.4305) 023:(3.7718) 024:(3.7714) 025:(3.7531) 045:(0.0143) 100:(3.3761) 101:(0.5077) 
		Proto:12 001:(2.1831) 002:(2.1678) 003:(2.0246) 023:(1.831) 024:(1.0694) 025:(1.5898) 045:(0.3354) 100:(0.5522) 101:(0.1418) 
		Proto:13 001:(0.4457) 002:(0.0577) 003:(0.1112) 023:(3.0335) 024:(3.0292) 025

Collecting topk: 90it [00:01, 46.99it/s]   


Node 032+033
	 Child: 033+031
		Proto:0 031:(2.9883) 033:(2.9185) 
		Proto:1 031:(1.2407) 033:(1.2046) 
		Proto:3 031:(3.8938) 033:(3.8892) 
		Proto:19 031:(2.1762) 033:(2.1162) 


Collecting topk: 270it [00:05, 50.67it/s]


Node 045+101
	 Child: 101+023
		Proto:1 023:(3.1222) 024:(3.1844) 025:(3.1468) 100:(3.1675) 101:(3.1903) 
		Proto:2 023:(0.4542) 024:(0.0787) 025:(0.8978) 100:(2.1796) 101:(2.1735) 
		Proto:6 023:(3.6944) 024:(3.6878) 025:(3.6042) 100:(3.6648) 101:(3.6934) 
		Proto:7 023:(0.0184) 024:(0.0398) 025:(0.0174) 100:(0.1979) 101:(0.1991) 
		Proto:8 023:(3.2148) 024:(3.1827) 025:(3.2237) 100:(3.223) 101:(3.2253) 
		Proto:19 023:(1.4577) 024:(1.4756) 025:(1.4686) 100:(0.0408) 101:(0.0031) 
	 Child: 045+003
		Proto:5 001:(3.222) 002:(3.2153) 003:(3.2248) 045:(3.2009) 
		Proto:9 001:(2.2476) 002:(1.6751) 003:(2.2679) 045:(2.1549) 
		Proto:10 001:(2.4256) 002:(2.3993) 003:(2.4638) 045:(2.1852) 
		Proto:13 001:(3.008) 002:(3.1065) 003:(3.0871) 045:(3.0902) 
		Proto:16 001:(1.4014) 002:(1.4057) 003:(1.2867) 045:(1.1173) 


Collecting topk: 120it [00:02, 48.80it/s]  


Node 045+003
	 Child: 003+002
		Proto:1 001:(1.487) 002:(1.6546) 003:(1.6494) 
		Proto:2 001:(1.7784) 002:(1.6133) 003:(1.5328) 
		Proto:6 001:(0.3732) 002:(0.4611) 003:(0.4751) 
		Proto:7 001:(0.1149) 002:(0.1212) 003:(0.0952) 
		Proto:9 001:(2.0008) 002:(1.9628) 003:(1.7577) 
		Proto:11 001:(1.625) 002:(1.7965) 003:(2.0395) 
		Proto:13 001:(2.5494) 002:(1.7616) 003:(2.3377) 
		Proto:17 001:(4.3339) 002:(4.337) 003:(4.2627) 
		Proto:18 001:(2.5026) 002:(2.5126) 003:(2.4846) 


Collecting topk: 150it [00:03, 42.30it/s]


Node 101+023
	 Child: 023+025
		Proto:4 023:(1.9598) 024:(1.9939) 025:(1.9788) 
		Proto:9 023:(1.4488) 024:(1.4992) 025:(1.4533) 
		Proto:10 023:(1.4015) 024:(1.4199) 025:(1.379) 
		Proto:11 023:(2.6012) 024:(2.6199) 025:(2.5719) 
		Proto:12 023:(0.1511) 024:(0.1528) 025:(0.1636) 
		Proto:13 023:(2.9055) 024:(2.8857) 025:(2.9337) 
		Proto:18 023:(3.3207) 024:(3.3193) 025:(3.3072) 
	 Child: 101+100
		Proto:15 100:(2.0381) 101:(2.0453) 
		Proto:5 100:(1.4792) 101:(1.4796) 
		Proto:6 100:(3.6827) 101:(3.6832) 
		Proto:7 100:(3.7665) 101:(3.7648) 


Collecting topk: 90it [00:01, 48.18it/s]   


Node 003+002
	 Child: 002+001
		Proto:0 001:(3.1762) 002:(3.2003) 
		Proto:9 001:(0.1948) 002:(0.1911) 
		Proto:12 001:(3.589) 002:(3.6451) 
		Proto:13 001:(3.5498) 002:(3.551) 
		Proto:17 001:(2.4816) 002:(2.26) 


Collecting topk: 90it [00:01, 46.66it/s]   


Node 023+025
	 Child: 025+024
		Proto:0 024:(0.8611) 025:(0.7929) 
		Proto:6 024:(0.5784) 025:(0.5748) 
		Proto:7 024:(2.3964) 025:(2.1413) 
		Proto:9 024:(2.9654) 025:(2.9203) 
		Proto:10 024:(1.6659) 025:(1.31) 
		Proto:11 024:(1.7408) 025:(1.5809) 
		Proto:14 024:(2.679) 025:(2.4499) 
