In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import pickle
import os
import torchvision.transforms as transforms
from pipnet.pipnet import PIPNet, get_network
from util.data import get_dataloaders
from util.vis_pipnet import get_img_coordinates
from util.func import get_patch_size
from util.eval_cub_csv import get_topk_cub
from PIL import ImageFont, Image, ImageDraw as D
from pipnet.train import test_pipnet, train_pipnet
from omegaconf import OmegaConf
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.args import get_args, save_args, get_optimizer_nn
import wandb
from torchvision.datasets.folder import ImageFolder
import pdb
import math
print(torch.cuda.is_available())

!which python

True
/home/harishbabu/.conda/envs/hpnet1/bin/python


In [2]:
# run_path = '/home/harishbabu/projects/PIPNet/runs/004-CUB-27-imgnet_cnext26_img=224_nprotos=200'
# run_path = '/home/harishbabu/projects/PIPNet/runs/005-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/007-CUB-27-imgnet_cnext26_img=224_nprotos=50'
run_path = '/home/harishbabu/projects/PIPNet/runs/010-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'

if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

args_file = open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb')
args = pickle.load(args_file)

ckpt_path = os.path.join(run_path, 'checkpoints', 'net_trained_last')
checkpoint = torch.load(ckpt_path, map_location=device)

In [3]:
if args.phylo_config:
    phylo_config = OmegaConf.load(args.phylo_config)

if args.phylo_config:
    # construct the phylo tree
    if phylo_config.phyloDistances_string == 'None':
        root = construct_phylo_tree(phylo_config.phylogeny_path)
        print('-'*25 + ' No discretization ' + '-'*25)
    else:
        root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
        print('-'*25 + ' Discretized ' + '-'*25)
else:
    # construct the tree (original hierarchy as described in the paper)
    root = Node("root")
    root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
    root.add_children_to('animal',['non_primate','primate'])
    root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
    root.add_children_to('primate',['capuchin','gibbon','orangutan'])
    root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
    root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
    root.add_children_to('weapon',['assault_rifle','rifle'])
    # flat root
    # root.add_children(['scuba_diver','African_elephant','giant_panda','lion','capuchin','gibbon','orangutan','ambulance','pickup','sports_car','laptop','sandal','wine_bottle','assault_rifle','rifle'])
root.assign_all_descendents()

------------------------- No discretization -------------------------


In [4]:
# Obtain the dataset and dataloaders
trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device)
if len(classes)<=20:
    if args.validation_size == 0.:
        print("Classes: ", testloader.dataset.class_to_idx, flush=True)
    else:
        print("Classes: ", str(classes), flush=True)

Num classes (k) =  27 ['cub_001_Black_footed_Albatross', 'cub_011_Rusty_Blackbird', 'cub_016_Painted_Bunting', 'cub_019_Gray_Catbird', 'cub_030_Fish_Crow'] etc.


In [5]:
# Create a convolutional network based on arguments and add 1x1 conv layer
feature_net, add_on_layers, pool_layer, classification_layers, num_prototypes = get_network(len(classes), args, root=root)
   
# Create a PIP-Net
net = PIPNet(num_classes=len(classes),
                    num_prototypes=num_prototypes,
                    feature_net = feature_net,
                    args = args,
                    add_on_layers = add_on_layers,
                    pool_layer = pool_layer,
                    classification_layers = classification_layers,
                    num_parent_nodes = len(root.nodes_with_children()),
                    root = root
                    )
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)    
net.load_state_dict(checkpoint['model_state_dict'],strict=True)
criterion = nn.NLLLoss(reduction='mean').to(device)

Number of prototypes:  20


In [40]:
# print(root)
epoch = 0
# run = wandb.init(project="pipnet", name=os.path.basename(args.log_dir), config=vars(args), reinit=False)
info = test_pipnet(net, testloader, criterion, epoch, device, progress_prefix= 'Test Epoch', wandb_logging=False, wandb_log_subdir = 'test')
print('test', info['fine_accuracy'])
info = test_pipnet(net, trainloader, criterion, epoch, device, progress_prefix= 'Train Epoch', wandb_logging=False, wandb_log_subdir = 'train')
print('train', info['fine_accuracy'])

Test Epoch0: 100% 13/13 [00:07<00:00,  1.69it/s, L:25.841,LC:0.018, LA:0.02, LT:0.556]

	Fine accuracy: 0.91
	Node name: root, acc: 99.49, f1:99.48, samples: 1562, 113+001+068=1500/1502=1.0, cub_090_Red_breasted_Merganser=54/60=0.9
	Node name: 113+001+068, acc: 98.54, f1:98.52, samples: 1502, 113+060=1258/1264=1.0, 001+052=162/178=0.91, cub_068_Ruby_throated_Hummingbird=60/60=1.0
	Node name: 113+060, acc: 99.21, f1:99.2, samples: 1264, 113+187=1084/1086=1.0, 060+071=170/178=0.96
	Node name: 001+052, acc: 98.88, f1:98.88, samples: 178, 001+033=116/118=0.98, cub_052_Pied_billed_Grebe=60/60=1.0
	Node name: 113+187, acc: 99.08, f1:99.09, samples: 1086, 113+037=978/986=0.99, 187+079=98/100=0.98
	Node name: 060+071, acc: 94.38, f1:94.36, samples: 178, 060+143=114/118=0.97, cub_071_Long_tailed_Jaeger=54/60=0.9
	Node name: 001+033, acc: 98.31, f1:98.31, samples: 118, cub_001_Black_footed_Albatross=58/60=0.97, cub_033_Yellow_billed_Cuckoo=58/58=1.0
	Node name: 113+037, acc: 99.59, f1:99.59, samples: 986, 113+030=868/868=1.0, 037+077=114/118=0.97
	Node name: 187+079, acc: 94.0, f1:


Train Epoch0: 100% 13/13 [00:08<00:00,  1.57it/s, L:5.325,LC:0.004, LA:0.03, LT:0.026]

	Fine accuracy: 0.99
	Node name: root, acc: 100.0, f1:100.0, samples: 1620, 113+001+068=1560/1560=1.0, cub_090_Red_breasted_Merganser=60/60=1.0
	Node name: 113+001+068, acc: 99.87, f1:99.87, samples: 1560, 113+060=1318/1320=1.0, 001+052=180/180=1.0, cub_068_Ruby_throated_Hummingbird=60/60=1.0
	Node name: 113+060, acc: 99.77, f1:99.77, samples: 1320, 113+187=1138/1140=1.0, 060+071=179/180=0.99
	Node name: 001+052, acc: 98.89, f1:98.88, samples: 180, 001+033=120/120=1.0, cub_052_Pied_billed_Grebe=58/60=0.97
	Node name: 113+187, acc: 100.0, f1:100.0, samples: 1140, 113+037=1020/1020=1.0, 187+079=120/120=1.0
	Node name: 060+071, acc: 98.89, f1:98.88, samples: 180, 060+143=120/120=1.0, cub_071_Long_tailed_Jaeger=58/60=0.97
	Node name: 001+033, acc: 97.5, f1:97.5, samples: 120, cub_001_Black_footed_Albatross=57/60=0.95, cub_033_Yellow_billed_Cuckoo=60/60=1.0
	Node name: 113+037, acc: 99.71, f1:99.7, samples: 1020, 113+030=900/900=1.0, 037+077=117/120=0.97
	Node name: 187+079, acc: 99.17, f1:




In [58]:
data_loader = testloader
# data_loader = trainloader_normal
node = root.get_node('113+011')

net.eval()

dataset = data_loader.dataset
while type(dataset) != ImageFolder:
    dataset = dataset.dataset
name2label = dataset.class_to_idx
label2name = {label:name for name, label in name2label.items()}

with torch.no_grad():
    for xs, ys1 in data_loader:
        batch_names = [label2name[y.item()] for y in ys1]
        children_idx = torch.tensor([name in node.descendents for name in batch_names])
        batch_names_coarsest = [node.closest_descendent_for(name).name for name in batch_names if name in node.descendents]
        node_y = torch.tensor([node.children_to_labels[name] for name in batch_names_coarsest]).cuda()
        if len(node_y) == 0:
                continue
        
        for image, target in zip(xs[children_idx], node_y):
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item(), 'gt', target.item())  
            
            pdb.set_trace()
        
#         not_children_idx = torch.logical_not(children_idx)
#         for image in xs[not_children_idx]:
#             proto_features, pooled, out = net(image.unsqueeze(0))
#             node_logits = out[node.name]
#             normalized_score = torch.log1p(node_logits**net.module._multiplier)
#             prob = torch.nn.functional.softmax(normalized_score,1)
#             print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item())   
    

0.9970056414604187 True pred 0 gt 0
> [0;32m/tmp/ipykernel_178025/3540772027.py[0m(22)[0;36m<module>[0;34m()[0m
[0;32m     20 [0;31m                [0;32mcontinue[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m        [0;32mfor[0m [0mimage[0m[0;34m,[0m [0mtarget[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mxs[0m[0;34m[[0m[0mchildren_idx[0m[0;34m][0m[0;34m,[0m [0mnode_y[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m            [0mproto_features[0m[0;34m,[0m [0mpooled[0m[0;34m,[0m [0mout[0m [0;34m=[0m [0mnet[0m[0;34m([0m[0mimage[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mnode_logits[0m [0;34m=[0m [0mout[0m[0;34m[[0m[0mnode[0m[0;34m.[0m[0mname[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> getattr(net, "_" + node.name + "_classification").sha

ipdb> c_weights[0].unsqueeze(0) * pooled[node.name]
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 3.3782,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0289, 0.0000, 0.0000, 0.0000,
         0.0000, 4.6081, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 5.7454, 0.0000, 0.0000,
         4.4652, 0.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0')
ipdb> pooled[node.name]
tensor([[1.2858e-01, 1.3992e-04, 1.7866e-02, 3.3596e-05, 4.1091e-05, 9.9386e-01,
         1.0000e+00, 3.9644e-01, 9.9908e-01, 9.8449e-01, 3.6535e-05, 4.9866e-01,
         1.0000e+00, 1.0000e+00, 5.5580e-01, 9.9982e-01, 9.9995e-01, 1.0000e+00,
         6.8652e-01, 9.1193e-01, 2.4491e-04, 8.9950e-01, 9.9998e-01, 9.8926e-01,
         5.1049e-05, 1.0000e+00, 8.7759e-01, 3.6892e-01, 9.9999e-01, 3.0368e-04,
         9.7374e-01, 9.9593e-01, 1.9587e-04, 7

In [None]:
c_weights = getattr(net.module, "_" + node.name + "_classification").weight
logit1 = torch.log1p(torch.sum(c_weights[0].unsqueeze(0) * pooled[node.name]) * net.module._multiplier)
logit2 = torch.log1p(torch.sum(c_weights[1].unsqueeze(0) * pooled[node.name]) * net.module._multiplier)
torch.nn.functional.softmax(torch.tensor([[logit1, logit2]]),1)
c_weights[0].unsqueeze(0) * pooled[node.name]

In [65]:
data_loader = testloader
# data_loader = trainloader_normal
node = root.get_node('113+011')

net.eval()

dataset = data_loader.dataset
while type(dataset) != ImageFolder:
    dataset = dataset.dataset
name2label = dataset.class_to_idx
label2name = {label:name for name, label in name2label.items()}

with torch.no_grad():
    for xs, ys1 in data_loader:
        batch_names = [label2name[y.item()] for y in ys1]
        children_idx = torch.tensor([name in node.descendents for name in batch_names])
        batch_names_coarsest = [node.closest_descendent_for(name).name for name in batch_names if name in node.descendents]
        node_y = torch.tensor([node.children_to_labels[name] for name in batch_names_coarsest]).cuda()
        if len(node_y) == 0:
                continue
        
#         for image, target in zip(xs[children_idx], node_y):
#             proto_features, pooled, out = net(image.unsqueeze(0))
#             node_logits = out[node.name]
#             normalized_score = torch.log1p(node_logits**net.module._multiplier)
#             prob = torch.nn.functional.softmax(normalized_score,1)
#             print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item(), 'gt', target.item())  
            
#             pdb.set_trace()
        
        not_children_idx = torch.logical_not(children_idx)
        for image in xs[not_children_idx]:
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item())  
            pdb.set_trace()
            
            
    

0.9602632522583008 False pred 1
> [0;32m/tmp/ipykernel_178025/241351276.py[0m(32)[0;36m<module>[0;34m()[0m
[0;32m     30 [0;31m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mnot_children_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mlogical_not[0m[0;34m([0m[0mchildren_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 32 [0;31m        [0;32mfor[0m [0mimage[0m [0;32min[0m [0mxs[0m[0;34m[[0m[0mnot_children_idx[0m[0;34m][0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     33 [0;31m            [0mproto_features[0m[0;34m,[0m [0mpooled[0m[0;34m,[0m [0mout[0m [0;34m=[0m [0mnet[0m[0;34m([0m[0mimage[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m            [0mnode_logits[0m [0;34m=[0m [0mout[0m[0;34m[[0m[0mnode[0m[0;34m.[0m[0mname[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> c_weights = getattr(net.module, "_" + node.n

In [64]:
from torchvision.transforms.functional import normalize
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
node = root.get_node('113+011')

net.eval()

with torch.no_grad():
    noise = torch.stack([normalize(torch.rand((3,224,224)),mean,std) for n in range(64)]).cuda()
    for image in noise:
        proto_features, pooled, out = net(image.unsqueeze(0))
        node_logits = out[node.name]
        normalized_score = torch.log1p(node_logits**net.module._multiplier)
        prob = torch.nn.functional.softmax(normalized_score,1)
        print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item())
        pdb.set_trace()


0.5 True pred 0
> [0;32m/tmp/ipykernel_178025/480639191.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0;32mwith[0m [0mtorch[0m[0;34m.[0m[0mno_grad[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m    [0mnoise[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mstack[0m[0;34m([0m[0;34m[[0m[0mnormalize[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0;34m([0m[0;36m3[0m[0;34m,[0m[0;36m224[0m[0;34m,[0m[0;36m224[0m[0;34m)[0m[0;34m)[0m[0;34m,[0m[0mmean[0m[0;34m,[0m[0mstd[0m[0;34m)[0m [0;32mfor[0m [0mn[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m64[0m[0;34m)[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mcuda[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m    [0;32mfor[0m [0mimage[0m [0;32min[0m [0mnoise[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m        [0mproto_features[0m[0;34m,[0m [0mpooled[0m[0;34m,[

In [7]:
train_iter = iter(trainloader)
x = next(train_iter)
y = next(train_iter)

In [11]:
x[2]

tensor([17, 26, 22, 17, 16, 20, 15, 11,  5,  8, 19, 12,  7,  8, 12, 12, 20, 22,
        22, 20, 18,  1, 23, 14, 11, 18,  2, 13, 21, 16, 16, 10, 18,  1, 18,  2,
         4,  1, 20,  6,  1, 15,  7, 11, 25,  4, 14,  4, 11, 13,  7, 20,  6,  6,
        10, 18,  1, 18,  2, 11, 22, 15,  6, 13])

In [12]:
y[2]

tensor([12,  0, 18,  6,  4, 11,  0,  5, 14, 12,  8, 25, 18,  0, 20, 21, 12, 25,
        26, 19, 12, 25, 23,  2, 15,  9, 24, 13, 15, 26, 10,  3,  2, 10, 26,  5,
        13, 23, 10, 24, 20,  5, 25,  9,  0, 14,  5, 19,  9, 14, 17, 21,  0, 11,
         1,  7, 18,  7, 25,  2, 10,  1, 10, 12])

In [6]:
data_loader = testloader
# data_loader = trainloader_normal
node = root.get_node('113+011')

net.eval()

dataset = data_loader.dataset
while type(dataset) != ImageFolder:
    dataset = dataset.dataset
name2label = dataset.class_to_idx
label2name = {label:name for name, label in name2label.items()}

with torch.no_grad():
    for xs, ys1 in data_loader:
        batch_names = [label2name[y.item()] for y in ys1]
        children_idx = torch.tensor([name in node.descendents for name in batch_names])
        batch_names_coarsest = [node.closest_descendent_for(name).name for name in batch_names if name in node.descendents]
        node_y = torch.tensor([node.children_to_labels[name] for name in batch_names_coarsest]).cuda()
        if len(node_y) == 0:
                continue
        
        for image, target in zip(xs[children_idx], node_y):
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item(), 'gt', target.item())  
            
            pdb.set_trace()
        
#         not_children_idx = torch.logical_not(children_idx)
#         for image in xs[not_children_idx]:
#             proto_features, pooled, out = net(image.unsqueeze(0))
#             node_logits = out[node.name]
#             normalized_score = torch.log1p(node_logits**net.module._multiplier)
#             prob = torch.nn.functional.softmax(normalized_score,1)
#             print(torch.max(prob).item(), torch.argmax(prob).item() == target.item(), 'pred', torch.argmax(prob).item())  
#             pdb.set_trace()
            
            
    

0.5 True pred 0 gt 0
> [0;32m/tmp/ipykernel_19088/60307923.py[0m(22)[0;36m<module>[0;34m()[0m
[0;32m     20 [0;31m                [0;32mcontinue[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m        [0;32mfor[0m [0mimage[0m[0;34m,[0m [0mtarget[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mxs[0m[0;34m[[0m[0mchildren_idx[0m[0;34m][0m[0;34m,[0m [0mnode_y[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m            [0mproto_features[0m[0;34m,[0m [0mpooled[0m[0;34m,[0m [0mout[0m [0;34m=[0m [0mnet[0m[0;34m([0m[0mimage[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mnode_logits[0m [0;34m=[0m [0mout[0m[0;34m[[0m[0mnode[0m[0;34m.[0m[0mname[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> q
