In [10]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import pickle
import os
import torchvision.transforms as transforms
from pipnet.pipnet import PIPNet, get_network
from util.data import get_dataloaders
from util.vis_pipnet import get_img_coordinates
from util.vis_pipnet import visualize, visualize_topk
from util.func import get_patch_size
from util.eval_cub_csv import get_topk_cub
from PIL import ImageFont, Image, ImageDraw as D
from pipnet.train import test_pipnet, train_pipnet
from omegaconf import OmegaConf
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.args import get_args, save_args, get_optimizer_nn
import wandb
from torchvision.datasets.folder import ImageFolder
import pdb
import math
print(torch.cuda.is_available())

!which python

True
/home/harishbabu/.conda/envs/hpnet1/bin/python


In [4]:
# run_path = '/home/harishbabu/projects/PIPNet/runs/004-CUB-27-imgnet_cnext26_img=224_nprotos=200'
# run_path = '/home/harishbabu/projects/PIPNet/runs/005-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/007-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/009-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/010-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/012-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/013-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
run_path = '/home/harishbabu/projects/PIPNet/runs/018-CUB-27-imgnet_cnext26_img=224_nprotos=20'

if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

args_file = open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb')
args = pickle.load(args_file)
args.OOD_dataset = 'CUB-163-OOD-imgnet-224'

ckpt_epoch = 10

ckpt_path = os.path.join(run_path, 'checkpoints', 'net_trained_'+str(ckpt_epoch))
checkpoint = torch.load(ckpt_path, map_location=device)

In [15]:
args.phylo_config = './configs/cub08_phylogeny.yaml'
if args.phylo_config:
    phylo_config = OmegaConf.load(args.phylo_config)

if args.phylo_config:
    # construct the phylo tree
    if phylo_config.phyloDistances_string == 'None':
        root = construct_phylo_tree(phylo_config.phylogeny_path)
        print('-'*25 + ' No discretization ' + '-'*25)
    else:
        root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
        print('-'*25 + ' Discretized ' + '-'*25)
else:
    # construct the tree (original hierarchy as described in the paper)
    root = Node("root")
    root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
    root.add_children_to('animal',['non_primate','primate'])
    root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
    root.add_children_to('primate',['capuchin','gibbon','orangutan'])
    root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
    root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
    root.add_children_to('weapon',['assault_rifle','rifle'])
    # flat root
    # root.add_children(['scuba_diver','African_elephant','giant_panda','lion','capuchin','gibbon','orangutan','ambulance','pickup','sports_car','laptop','sandal','wine_bottle','assault_rifle','rifle'])
root.assign_all_descendents()

print(root)

------------------------- No discretization -------------------------
root
	016+181
		cub_016_Painted_Bunting
		181+097
			181+161
				cub_181_Worm_eating_Warbler
				161+165
					cub_161_Blue_winged_Warbler
					cub_165_Chestnut_sided_Warbler
			097+122
				097+011
					cub_097_Orchard_Oriole
					cub_011_Rusty_Blackbird
				122+113
					cub_122_Harris_Sparrow
					cub_113_Baird_Sparrow



In [6]:
trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device, OOD=False)
trainloader_OOD, trainloader_pretraining_OOD, trainloader_normal_OOD, trainloader_normal_augment_OOD, projectloader_OOD, testloader_OOD, test_projectloader_OOD, _ = get_dataloaders(args, device, OOD=True)

Num classes (k) =  27 ['cub_001_Black_footed_Albatross', 'cub_011_Rusty_Blackbird', 'cub_016_Painted_Bunting', 'cub_019_Gray_Catbird', 'cub_030_Fish_Crow'] etc.
Dropping 21 samples from trainloader_pretraining
Num classes (k) =  163 ['cub_002_Laysan_Albatross', 'cub_003_Sooty_Albatross', 'cub_004_Groove_billed_Ani', 'cub_005_Crested_Auklet', 'cub_006_Least_Auklet'] etc.


In [7]:
# Create a convolutional network based on arguments and add 1x1 conv layer
feature_net, add_on_layers, pool_layer, classification_layers, num_prototypes = get_network(len(classes), args, root=root)
   
# Create a PIP-Net
net = PIPNet(num_classes=len(classes),
                    num_prototypes=num_prototypes,
                    feature_net = feature_net,
                    args = args,
                    add_on_layers = add_on_layers,
                    pool_layer = pool_layer,
                    classification_layers = classification_layers,
                    num_parent_nodes = len(root.nodes_with_children()),
                    root = root
                    )
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)    
net.load_state_dict(checkpoint['model_state_dict'],strict=True)
criterion = nn.NLLLoss(reduction='mean').to(device)

Number of prototypes:  20


In [12]:
with torch.no_grad():
        xs1, _, _ = next(iter(trainloader))
        xs1 = xs1.to(device)
        proto_features, _, _ = net(xs1)
        wshape = proto_features['root'].shape[-1]
        args.wshape = wshape #needed for calculating image patch size
        print("Output shape: ", proto_features['root'].shape, flush=True)

Output shape:  torch.Size([64, 20, 26, 26])


In [8]:
# SET SMALL WEIGHTS TO ZERO
with torch.no_grad():
    torch.set_printoptions(profile="full")
    for attr in dir(net.module):
        if attr.endswith('_classification'):
            getattr(net.module, attr).weight.copy_(torch.clamp(getattr(net.module, attr).weight.data - 0.001, min=0.)) 
            print(f"{attr} weights: ", getattr(net.module, attr).weight[getattr(net.module, attr).weight.nonzero(as_tuple=True)], \
                  (getattr(net.module, attr).weight[getattr(net.module, attr).weight.nonzero(as_tuple=True)]).shape, flush=True)
            if args.bias:
                print(f"{attr} bias: ", getattr(net.module, attr).bias, flush=True)

_001+033_classification weights:  tensor([1.2527, 0.5492, 0.6577, 0.5446, 0.6620, 1.0300, 0.8236, 0.8055, 0.8551,
        0.7457, 0.9521, 0.8871, 0.5308, 0.6835, 1.0988, 0.9804, 0.5986, 2.0822,
        0.2149, 0.8755, 1.7046, 0.6834, 1.1044, 0.7179, 0.4011, 2.3770, 0.7406,
        0.7218, 0.7058, 0.8697, 0.6373, 0.6767, 0.9792, 0.8126, 0.2888, 0.7840,
        0.8037], device='cuda:0') torch.Size([37])
_001+052_classification weights:  tensor([0.9740, 0.5849, 1.0898, 2.3066, 1.1345, 1.0837, 1.2009, 0.1623, 0.9263,
        1.4857, 0.7822, 0.8057, 1.0692, 0.5540, 0.9848, 1.2713, 0.9706, 0.7821,
        1.0589, 0.5269, 1.7319, 1.0501, 0.5331, 0.6632, 0.5687, 0.8421, 1.5807,
        0.4428, 0.1836, 0.7065, 0.9953, 0.5378, 1.1638, 0.4059, 0.3501, 0.7237,
        0.7092, 0.6144], device='cuda:0') torch.Size([38])
_011+097_classification weights:  tensor([1.0461, 0.8292, 0.6568, 0.1928, 1.4023, 0.7130, 0.6120, 0.5622, 0.5374,
        0.4139, 0.8673, 0.8223, 0.4411, 0.6061, 1.2021, 0.8795, 0.83

_165+161_classification weights:  tensor([1.0083, 1.0126, 0.8060, 0.6430, 0.8119, 1.4438, 0.9061, 0.7766, 1.4878,
        0.2954, 1.3330, 0.7098, 1.0528, 0.5782, 0.2641, 0.7989, 1.1762, 0.9185,
        0.8117, 0.7630, 0.7129, 0.8642, 0.9697, 0.7509, 0.3542, 0.6390, 0.9533,
        0.2624, 1.2297, 0.3526, 0.9480, 0.6770, 1.0599, 1.2184, 0.8308, 0.5280,
        0.3890, 2.3402, 0.8102], device='cuda:0') torch.Size([39])
_165+181_classification weights:  tensor([0.9386, 1.1016, 1.0464, 0.9703, 1.1304, 0.9817, 1.0875, 1.0224, 0.8527,
        1.1275, 0.8968, 1.0555, 1.2416, 0.8093, 0.9245, 1.1461, 1.0536, 1.0085,
        0.2914, 0.3324, 0.5007, 0.4102, 0.4447, 0.4389, 0.7309, 2.8073, 0.4853,
        0.4585, 0.6047, 0.2998, 0.5373, 0.6694, 0.1323, 0.7068, 0.9621, 0.4383,
        0.5932, 0.4226, 1.4222], device='cuda:0') torch.Size([39])
_187+079_classification weights:  tensor([1.0967, 0.5341, 0.8547, 0.9251, 1.4282, 0.8067, 0.8161, 0.8277, 0.4337,
        0.6162, 0.5845, 0.7878, 0.6150, 1.27

In [14]:
for node in root.nodes_with_children():
    topks = visualize_topk(net, projectloader, node.num_children(), device, f'visualised_prototypes_topk_ep={ckpt_epoch}/{node.name}', args, node=node, wandb_logging=False)
    # set weights of prototypes that are never really found in projection set to 0
    set_to_zero = []
    classification_layer = getattr(net.module, '_'+node.name+'_classification')
    if topks:
        for prot in topks.keys():
            found = False
            for (i_id, score) in topks[prot]:
                if score > 0.1:
                    found = True
            if not found:
                torch.nn.init.zeros_(classification_layer.weight[:,prot])
                set_to_zero.append(prot)
        print(f"Weights of prototypes of node {node.name}", set_to_zero, "are set to zero because it is never detected with similarity>0.1 in the training set", flush=True)

Visualizing prototypes for topk of node root ...


Collecting topk: 810it [00:13, 61.88it/s]  


16 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 810it [00:02, 299.58it/s] 

Abstained:  0





Weights of prototypes of node root [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 14, 15, 16, 17, 18] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+001+068 ...


Collecting topk: 780it [00:12, 63.79it/s]  


13 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 780it [00:02, 269.61it/s] 

Abstained:  0





Weights of prototypes of node 113+001+068 [0, 1, 2, 4, 5, 7, 9, 11, 14, 15, 17, 18, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+060 ...


Collecting topk: 660it [00:11, 59.66it/s]


13 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 660it [00:02, 240.29it/s]

Abstained:  0





Weights of prototypes of node 113+060 [0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 13, 15, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 001+052 ...


Collecting topk: 90it [00:02, 36.83it/s]   


0 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 90it [00:04, 21.89it/s]   

Abstained:  0





Weights of prototypes of node 001+052 [] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+187 ...


Collecting topk: 570it [00:09, 57.38it/s]


11 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 570it [00:03, 182.37it/s]

Abstained:  0





Weights of prototypes of node 113+187 [1, 3, 5, 7, 10, 11, 15, 16, 17, 18, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 060+071 ...


Collecting topk: 90it [00:02, 36.71it/s]   


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 90it [00:03, 25.57it/s]   

Abstained:  0





Weights of prototypes of node 060+071 [0, 7, 13, 14] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 001+033 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.12it/s]


7 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 18.61it/s]

Abstained:  0





Weights of prototypes of node 001+033 [1, 8, 9, 11, 13, 14, 17] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+037 ...


Collecting topk: 510it [00:08, 58.58it/s]


0 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 510it [00:04, 112.17it/s]

Abstained:  0





Weights of prototypes of node 113+037 [] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 187+079 ...


Collecting topk: 100% 60/60 [00:02<00:00, 28.25it/s]


3 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 16.18it/s]

Abstained:  0





Weights of prototypes of node 187+079 [7, 12, 16] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 060+143 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.29it/s]


2 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 15.87it/s]

Abstained:  0





Weights of prototypes of node 060+143 [1, 17] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+030 ...


Collecting topk: 450it [00:07, 56.82it/s]


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 450it [00:03, 113.13it/s]

Abstained:  0





Weights of prototypes of node 113+030 [1, 6, 13, 18] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 037+077 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.82it/s]


0 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:04<00:00, 14.85it/s]

Abstained:  0





Weights of prototypes of node 037+077 [] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+085 ...


Collecting topk: 390it [00:06, 56.91it/s]  


6 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 390it [00:03, 105.40it/s] 

Abstained:  0





Weights of prototypes of node 113+085 [3, 6, 10, 13, 14, 18] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 030+156 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.52it/s]


1 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 15.17it/s]

Abstained:  0





Weights of prototypes of node 030+156 [16] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+194 ...


Collecting topk: 360it [00:06, 56.89it/s]


5 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 360it [00:03, 96.19it/s]

Abstained:  0





Weights of prototypes of node 113+194 [1, 3, 4, 7, 12] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+118 ...


Collecting topk: 300it [00:05, 52.70it/s]  


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 300it [00:03, 80.12it/s]  

Abstained:  0





Weights of prototypes of node 113+118 [3, 9, 11, 13] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 194+019 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.20it/s]


3 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 16.16it/s]

Abstained:  0





Weights of prototypes of node 194+019 [2, 8, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+034 ...


Collecting topk: 270it [00:05, 52.82it/s]  


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 270it [00:03, 68.79it/s]  

Abstained:  0





Weights of prototypes of node 113+034 [4, 12, 15, 17] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+016 ...


Collecting topk: 240it [00:04, 49.32it/s]  


9 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 240it [00:03, 77.00it/s]  

Abstained:  0





Weights of prototypes of node 113+016 [0, 3, 4, 7, 10, 11, 13, 14, 18] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+165 ...


Collecting topk: 210it [00:04, 50.02it/s]


2 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 210it [00:03, 52.68it/s]

Abstained:  0





Weights of prototypes of node 113+165 [10, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+011 ...


Collecting topk: 120it [00:02, 41.64it/s]


6 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 120it [00:03, 35.06it/s]

Abstained:  0





Weights of prototypes of node 113+011 [6, 8, 9, 11, 16, 19] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 165+181 ...


Collecting topk: 90it [00:02, 35.93it/s]   


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 90it [00:03, 25.79it/s]   

Abstained:  0





Weights of prototypes of node 165+181 [8, 9, 13, 14] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 113+122 ...


Collecting topk: 100% 60/60 [00:02<00:00, 29.33it/s]


7 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 19.45it/s]

Abstained:  0





Weights of prototypes of node 113+122 [1, 5, 6, 8, 11, 12, 18] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 011+097 ...


Collecting topk: 100% 60/60 [00:02<00:00, 28.57it/s]


4 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 16.96it/s]

Abstained:  0





Weights of prototypes of node 011+097 [0, 1, 8, 10] are set to zero because it is never detected with similarity>0.1 in the training set
Visualizing prototypes for topk of node 165+161 ...


Collecting topk: 100% 60/60 [00:02<00:00, 27.66it/s]


1 prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.


Visualizing topk: 100% 60/60 [00:03<00:00, 15.25it/s]

Abstained:  0





Weights of prototypes of node 165+161 [7] are set to zero because it is never detected with similarity>0.1 in the training set
