In [2]:
import functools
import math
import os
import pdb
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import wandb
from omegaconf import OmegaConf
from PIL import ImageFont, Image, ImageDraw as D
from torchvision.datasets.folder import ImageFolder

from pipnet.pipnet import PIPNet, get_network
from pipnet.train import test_pipnet, train_pipnet
from util.args import get_args, save_args, get_optimizer_nn
from util.data import get_dataloaders
from util.eval_cub_csv import get_topk_cub
from util.func import get_patch_size
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.vis_pipnet import get_img_coordinates
# Environment sanity check: is a CUDA device visible, and which Python binary is on PATH for the kernel?
print(torch.cuda.is_available())

!which python

True
/home/harishbabu/.conda/envs/hpnet1/bin/python


In [35]:
# Run directory whose trained checkpoint is evaluated. Earlier runs, kept for reference:
# run_path = '/home/harishbabu/projects/PIPNet/runs/004-CUB-27-imgnet_cnext26_img=224_nprotos=200'
# run_path = '/home/harishbabu/projects/PIPNet/runs/005-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/007-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/009-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/010-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/012-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
run_path = '/home/harishbabu/projects/PIPNet/runs/013-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'

# Use the current GPU if there is one; an empty device_ids list means CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

# Arguments pickled in the run's metadata directory.
# NOTE: pickle.load can execute arbitrary code -- only load run directories you trust.
# The `with` block closes the file handle, which the previous bare open() leaked.
with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:
    args = pickle.load(args_file)
# Override the OOD dataset recorded in the pickled args with the one evaluated here.
args.OOD_dataset = 'CUB-163-OOD-imgnet-224'

ckpt_path = os.path.join(run_path, 'checkpoints', 'net_trained_last')
checkpoint = torch.load(ckpt_path, map_location=device)

In [36]:
# Build the class hierarchy `root` that the network's per-node heads are organised around.
if args.phylo_config:
    # Phylogeny-driven tree, read from the OmegaConf file referenced by the run args.
    phylo_config = OmegaConf.load(args.phylo_config)
    # The literal string 'None' means: use the phylogeny as-is, without discretization.
    if phylo_config.phyloDistances_string == 'None':
        root = construct_phylo_tree(phylo_config.phylogeny_path)
        print('-'*25 + ' No discretization ' + '-'*25)
    else:
        root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
        print('-'*25 + ' Discretized ' + '-'*25)
else:
    # construct the tree (original hierarchy as described in the paper)
    # NOTE(review): `Node` is never imported in this notebook, so this branch raises
    # NameError as written. TODO import it from the module that defines the tree node
    # class (presumably the same one construct_phylo_tree builds on -- confirm).
    root = Node("root")
    root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
    root.add_children_to('animal',['non_primate','primate'])
    root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
    root.add_children_to('primate',['capuchin','gibbon','orangutan'])
    root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
    root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
    root.add_children_to('weapon',['assault_rifle','rifle'])
    # flat root
    # root.add_children(['scuba_diver','African_elephant','giant_panda','lion','capuchin','gibbon','orangutan','ambulance','pickup','sports_car','laptop','sandal','wine_bottle','assault_rifle','rifle'])
# Populate each node's `descendents`, which `confidences` later reads.
root.assign_all_descendents()

------------------------- No discretization -------------------------


In [37]:
# In-distribution loaders plus the class-name list for the training dataset.
(
    trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment,
    projectloader, testloader, test_projectloader, classes,
) = get_dataloaders(args, device, OOD=False)

# The same set of loaders for the OOD dataset (presumably args.OOD_dataset -- confirm in
# util.data); its class list is not needed here, so it is discarded.
(
    trainloader_OOD, trainloader_pretraining_OOD, trainloader_normal_OOD, trainloader_normal_augment_OOD,
    projectloader_OOD, testloader_OOD, test_projectloader_OOD, _,
) = get_dataloaders(args, device, OOD=True)

Num classes (k) =  27 ['cub_001_Black_footed_Albatross', 'cub_011_Rusty_Blackbird', 'cub_016_Painted_Bunting', 'cub_019_Gray_Catbird', 'cub_030_Fish_Crow'] etc.
Dropping 21 samples from trainloader_pretraining
Num classes (k) =  163 ['cub_002_Laysan_Albatross', 'cub_003_Sooty_Albatross', 'cub_004_Groove_billed_Ani', 'cub_005_Crested_Auklet', 'cub_006_Least_Auklet'] etc.


In [38]:
# Backbone, 1x1 add-on layers, pooling and classification layers, sized from the tree.
feature_net, add_on_layers, pool_layer, classification_layers, num_prototypes = get_network(
    len(classes), args, root=root
)

# Assemble the PIP-Net; num_parent_nodes is the number of internal tree nodes.
net = PIPNet(
    num_classes=len(classes),
    num_prototypes=num_prototypes,
    feature_net=feature_net,
    args=args,
    add_on_layers=add_on_layers,
    pool_layer=pool_layer,
    classification_layers=classification_layers,
    num_parent_nodes=len(root.nodes_with_children()),
    root=root,
)

# Move to the target device and wrap in DataParallel before loading, so the wrapped
# parameter names match the checkpoint (presumably saved from a DataParallel model).
net = nn.DataParallel(net.to(device=device), device_ids=device_ids)
net.load_state_dict(checkpoint['model_state_dict'], strict=True)
criterion = nn.NLLLoss(reduction='mean').to(device)

Number of prototypes:  20


In [39]:
# Evaluate on the in-distribution test split and the OOD test split.
data_loader = testloader
data_loader_OOD = testloader_OOD
# data_loader = trainloader_normal


def _unshuffled(loader):
    """Return `loader` if it already iterates in dataset order, else a sequential copy.

    Assigning `loader.shuffle = False` on an existing DataLoader has no effect: the
    iteration order is fixed by the sampler chosen when the loader was constructed.
    A fixed order matters because the confidence analysis keeps the first image it
    sees for each class.
    """
    if isinstance(loader.sampler, torch.utils.data.SequentialSampler):
        return loader
    if loader.batch_size is None:
        # Built from a custom batch_sampler: it cannot be rebuilt faithfully here, so leave it as is.
        return loader
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )


data_loader = _unshuffled(data_loader)
data_loader_OOD = _unshuffled(data_loader_OOD)

net.eval()


def _find_image_folder(ds):
    """Follow `.dataset` attributes of wrapper datasets until an ImageFolder is reached.

    Raises:
        TypeError: if the chain of `.dataset` attributes ends before any ImageFolder.
    """
    while not isinstance(ds, ImageFolder):
        inner = getattr(ds, 'dataset', None)
        if inner is None:
            raise TypeError(f'no ImageFolder found under {type(ds).__name__}')
        ds = inner
    return ds


# Label <-> class-name maps come from the underlying ImageFolder.
dataset = _find_image_folder(data_loader.dataset)
name2label = dataset.class_to_idx
label2name = {label: name for name, label in name2label.items()}

def confidences(node):
    print('-'*25, node.name, '-'*25)
    with torch.no_grad():
        non_child_images = {}
        ID_images = {}
        for xs, ys1 in data_loader:        
            for i in range(ys1.shape[0]):
                y = ys1[i].item()
                if (label2name[y] not in node.descendents) and (y not in non_child_images):
                    non_child_images[y] = xs[i]
                if (label2name[y] in node.descendents) and (y not in ID_images):
                    ID_images[y] = xs[i]

        OOD_images = {}
        for xs, ys1 in data_loader_OOD:
            for i in range(ys1.shape[0]):
                y = ys1[i].item()
                if (y not in OOD_images):
                    OOD_images[y] = xs[i]

        non_child_confidences = []
        for y, image in non_child_images.items():
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            non_child_confidences.append(torch.max(prob).item())
        # print('non_child_confidences', non_child_confidences, '\n')
        print('non_child_confidences', 'mean:', np.mean(non_child_confidences), 'std:', np.std(non_child_confidences))

        OOD_confidences = []
        for y, image in OOD_images.items():
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            OOD_confidences.append(torch.max(prob).item())
        # print('OOD_confidences', OOD_confidences, '\n')
        print('OOD_confidences', 'mean:', np.mean(OOD_confidences), 'std:', np.std(OOD_confidences))

        ID_confidences = []
        for y, image in ID_images.items():
            proto_features, pooled, out = net(image.unsqueeze(0))
            node_logits = out[node.name]
            normalized_score = torch.log1p(node_logits**net.module._multiplier)
            prob = torch.nn.functional.softmax(normalized_score,1)
            ID_confidences.append(torch.max(prob).item())
        # print('ID_confidences', ID_confidences, '\n')
        print('ID_confidences', 'mean:', np.mean(ID_confidences), 'std:', np.std(ID_confidences))
        
        print('\n')
                

# Report confidences for every internal node of the tree.
# To inspect a single node instead: confidences(root.get_node('113+011'))
parent_nodes = root.nodes_with_children()
for node in parent_nodes:
    confidences(node)
   
    

------------------------- root -------------------------
non_child_confidences mean: nan std: nan
OOD_confidences mean: 0.5472195089960391 std: 0.12615311142344338
ID_confidences mean: 0.9452682336171468 std: 0.10516166460434276


------------------------- 113+001+068 -------------------------
non_child_confidences mean: 0.3333337604999542 std: 0.0
OOD_confidences mean: 0.418114933865202 std: 0.20068160482975095
ID_confidences mean: 0.8485037157168756 std: 0.2570069081898227


------------------------- 113+060 -------------------------
non_child_confidences mean: 0.5000025868415833 std: 2.5270120052787983e-06
OOD_confidences mean: 0.5459567096335757 std: 0.12978992823200067
ID_confidences mean: 0.9083510149608959 std: 0.15305573435113737


------------------------- 001+052 -------------------------
non_child_confidences mean: 0.5003292908271154 std: 0.001042086826272856
OOD_confidences mean: 0.5033330961239119 std: 0.03570649191834405
ID_confidences mean: 0.7995758454004923 std: 0.2118