In [20]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import pickle
import os
import torchvision.transforms as transforms
from pipnet.pipnet import PIPNet, get_network
from util.data import get_dataloaders
from util.vis_pipnet import get_img_coordinates
from util.func import get_patch_size
from util.eval_cub_csv import get_topk_cub
from PIL import ImageFont, Image, ImageDraw as D
from pipnet.train import test_pipnet, train_pipnet
from omegaconf import OmegaConf
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.args import get_args, save_args, get_optimizer_nn
import wandb
from torchvision.datasets.folder import ImageFolder
import pdb
import math
print(torch.cuda.is_available())

!which python

True
/home/harishbabu/.conda/envs/hpnet1/bin/python


In [12]:
# Candidate training runs; uncomment exactly one `run_path` to switch.
# run_path = '/home/harishbabu/projects/PIPNet/runs/004-CUB-27-imgnet_cnext26_img=224_nprotos=200'
# run_path = '/home/harishbabu/projects/PIPNet/runs/005-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/007-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/009-CUB-27-imgnet_cnext26_img=224_nprotos=50'
# run_path = '/home/harishbabu/projects/PIPNet/runs/010-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/012-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/013-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
# run_path = '/home/harishbabu/projects/PIPNet/runs/022-CUB-27-imgnet_OOD_cnext26_img=224_nprotos=20'
run_path = '/home/harishbabu/projects/PIPNet/runs/022-CUB-27-imgnet_cnext26_img=224_nprotos=20'

# Use the current GPU when available, otherwise fall back to CPU (no DataParallel devices).
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

# Restore the argument namespace saved with the run.
# SECURITY: pickle.load can execute arbitrary code -- only load args files from runs you trust.
# The `with` block closes the file handle (the previous bare open() leaked it).
with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:
    args = pickle.load(args_file)
# Override the OOD dataset; presumably consumed by get_dataloaders(..., OOD=True) later -- verify.
args.OOD_dataset = 'CUB-163-OOD-imgnet-224'

# Final checkpoint of the run, mapped onto the selected device.
ckpt_path = os.path.join(run_path, 'checkpoints', 'net_trained_last')
checkpoint = torch.load(ckpt_path, map_location=device)

In [13]:
# Build the class hierarchy (`root`) that the per-node classifiers are attached to.
# The two separate `if args.phylo_config:` checks were merged into one branch.
if args.phylo_config:
    phylo_config = OmegaConf.load(args.phylo_config)
    # Construct the phylo tree; optionally discretize it using the configured distance string.
    if phylo_config.phyloDistances_string == 'None':
        root = construct_phylo_tree(phylo_config.phylogeny_path)
        print('-'*25 + ' No discretization ' + '-'*25)
    else:
        root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
        print('-'*25 + ' Discretized ' + '-'*25)
else:
    # FIXME(review): `Node` is not imported in any cell visible here, so this branch raises
    # NameError. Import it from the module that defines it before relying on this path.
    # Construct the tree (original hierarchy as described in the paper).
    root = Node("root")
    root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
    root.add_children_to('animal',['non_primate','primate'])
    root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
    root.add_children_to('primate',['capuchin','gibbon','orangutan'])
    root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
    root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
    root.add_children_to('weapon',['assault_rifle','rifle'])
    # flat root
    # root.add_children(['scuba_diver','African_elephant','giant_panda','lion','capuchin','gibbon','orangutan','ambulance','pickup','sports_car','laptop','sandal','wine_bottle','assault_rifle','rifle'])
root.assign_all_descendents()

------------------------- No discretization -------------------------


In [14]:
# In-distribution loaders first, then the OOD counterparts; the OOD class list is discarded.
(
    trainloader,
    trainloader_pretraining,
    trainloader_normal,
    trainloader_normal_augment,
    projectloader,
    testloader,
    test_projectloader,
    classes,
) = get_dataloaders(args, device, OOD=False)
(
    trainloader_OOD,
    trainloader_pretraining_OOD,
    trainloader_normal_OOD,
    trainloader_normal_augment_OOD,
    projectloader_OOD,
    testloader_OOD,
    test_projectloader_OOD,
    _,
) = get_dataloaders(args, device, OOD=True)

Num classes (k) =  27 ['cub_001_Black_footed_Albatross', 'cub_011_Rusty_Blackbird', 'cub_016_Painted_Bunting', 'cub_019_Gray_Catbird', 'cub_030_Fish_Crow'] etc.
Dropping 21 samples from trainloader_pretraining
Num classes (k) =  163 ['cub_002_Laysan_Albatross', 'cub_003_Sooty_Albatross', 'cub_004_Groove_billed_Ani', 'cub_005_Crested_Auklet', 'cub_006_Least_Auklet'] etc.


In [15]:
# Create a convolutional network based on arguments and add 1x1 conv layer
feature_net, add_on_layers, pool_layer, classification_layers, num_prototypes = get_network(len(classes), args, root=root)
   
# Create a PIP-Net
net = PIPNet(num_classes=len(classes),
                    num_prototypes=num_prototypes,
                    feature_net = feature_net,
                    args = args,
                    add_on_layers = add_on_layers,
                    pool_layer = pool_layer,
                    classification_layers = classification_layers,
                    num_parent_nodes = len(root.nodes_with_children()),
                    root = root
                    )
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)    
net.load_state_dict(checkpoint['model_state_dict'],strict=True)
criterion = nn.NLLLoss(reduction='mean').to(device)

Number of prototypes:  20


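# Confidence: max softmax probability at each internal tree node (non-descendant vs. OOD vs. descendant classes)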

In [16]:
# Evaluate on the in-distribution test split and on the OOD test split.
data_loader = testloader
data_loader_OOD = testloader_OOD
# data_loader = trainloader_normal

# torch's DataLoader has no `shuffle` attribute: setting `loader.shuffle = False` after
# construction is a silent no-op. Shuffling is fixed by the sampler built inside
# get_dataloaders, so check that instead. With a shuffling sampler, the "first image per
# class" chosen below changes from run to run.
for loader_name, loader in (('data_loader', data_loader), ('data_loader_OOD', data_loader_OOD)):
    if not isinstance(getattr(loader, 'sampler', None), torch.utils.data.SequentialSampler):
        print(f'WARNING: {loader_name} is not sequential; per-class sample images are not reproducible')

net.eval()

# Follow `.dataset` through any wrapper datasets (presumably Subset-like) to the ImageFolder,
# which owns the class-name <-> label mapping.
dataset = data_loader.dataset
while type(dataset) != ImageFolder:
    dataset = dataset.dataset
name2label = dataset.class_to_idx
label2name = {label: name for name, label in name2label.items()}


def _first_image_per_class(loader):
    """Return {label: image} holding the first image `loader` yields for each label.

    Insertion order follows first appearance in the loader. Each image is cloned so the
    whole batch tensor it was sliced from can be freed.
    """
    first_images = {}
    for xs, ys in loader:
        for x, y in zip(xs, ys):
            y = y.item()
            if y not in first_images:
                first_images[y] = x.clone()
    return first_images


def _max_softmax_probs(images, node):
    """Return, per image, the max softmax probability over `node`'s classifier outputs.

    Logits are normalized as log1p(logits ** net.module._multiplier) before the softmax.
    """
    max_probs = []
    for image in images:
        _, _, out = net(image.unsqueeze(0).to(device))
        node_logits = out[node.name]
        normalized_score = torch.log1p(node_logits ** net.module._multiplier)
        prob = torch.nn.functional.softmax(normalized_score, 1)
        max_probs.append(torch.max(prob).item())
    return max_probs


def _print_stats(label, values):
    """Print mean/std of `values`; report empty groups instead of printing nan."""
    if len(values) == 0:
        print(label, 'no samples')
    else:
        print(label, 'mean:', np.mean(values), 'std:', np.std(values))


def confidences(node):
    """Print max-softmax confidence statistics of `node`'s classifier for three image groups.

    - non_child: one ID test image per class whose name is NOT in node.descendents
    - OOD:       one image per class from the OOD test loader
    - ID:        one ID test image per class whose name IS in node.descendents

    Uses the module-level `ID_first_images` / `OOD_first_images` collected below.
    """
    print('-'*25, node.name, '-'*25)
    non_child_images = [img for y, img in ID_first_images.items() if label2name[y] not in node.descendents]
    ID_images = [img for y, img in ID_first_images.items() if label2name[y] in node.descendents]
    with torch.no_grad():
        _print_stats('non_child_confidences', _max_softmax_probs(non_child_images, node))
        _print_stats('OOD_confidences', _max_softmax_probs(OOD_first_images.values(), node))
        _print_stats('ID_confidences', _max_softmax_probs(ID_images, node))
    print('\n')


# Collect one image per class once; previously both loaders were re-read for every node.
ID_first_images = _first_image_per_class(data_loader)
OOD_first_images = _first_image_per_class(data_loader_OOD)

# node = root.get_node('113+011')
for node in root.nodes_with_children():
    confidences(node)
   
    

------------------------- root -------------------------
non_child_confidences mean: nan std: nan
OOD_confidences mean: 0.9683977815271155 std: 0.08407312034167763
ID_confidences mean: 0.9756553813263222 std: 0.08294541915407387


------------------------- 113+001+068 -------------------------
non_child_confidences mean: 0.8014857769012451 std: 0.0
OOD_confidences mean: 0.9321270318119073 std: 0.13849110488238486
ID_confidences mean: 0.9082193133922724 std: 0.20898776455961346


------------------------- 113+060 -------------------------
non_child_confidences mean: 0.9584035396575927 std: 0.016433675293436416
OOD_confidences mean: 0.9670461997664048 std: 0.07794755846479542
ID_confidences mean: 0.9760993881659075 std: 0.0887684807139982


------------------------- 001+052 -------------------------
non_child_confidences mean: 0.8680721173683802 std: 0.13703889879862605
OOD_confidences mean: 0.8468779175559435 std: 0.14109903076196142
ID_confidences mean: 0.9882462620735168 std: 0.007019

# Logit: max raw logit at each internal tree node (non-descendant vs. OOD vs. descendant classes)

In [7]:
# Evaluate on the in-distribution test split and on the OOD test split.
data_loader = testloader
data_loader_OOD = testloader_OOD
# data_loader = trainloader_normal

# torch's DataLoader has no `shuffle` attribute: setting `loader.shuffle = False` after
# construction is a silent no-op. Shuffling is fixed by the sampler built inside
# get_dataloaders, so check that instead. With a shuffling sampler, the "first image per
# class" chosen below changes from run to run.
for loader_name, loader in (('data_loader', data_loader), ('data_loader_OOD', data_loader_OOD)):
    if not isinstance(getattr(loader, 'sampler', None), torch.utils.data.SequentialSampler):
        print(f'WARNING: {loader_name} is not sequential; per-class sample images are not reproducible')

net.eval()

# Follow `.dataset` through any wrapper datasets (presumably Subset-like) to the ImageFolder,
# which owns the class-name <-> label mapping.
dataset = data_loader.dataset
while type(dataset) != ImageFolder:
    dataset = dataset.dataset
name2label = dataset.class_to_idx
label2name = {label: name for name, label in name2label.items()}


def _first_image_per_class(loader):
    """Return {label: image} holding the first image `loader` yields for each label.

    Insertion order follows first appearance in the loader. Each image is cloned so the
    whole batch tensor it was sliced from can be freed.
    """
    first_images = {}
    for xs, ys in loader:
        for x, y in zip(xs, ys):
            y = y.item()
            if y not in first_images:
                first_images[y] = x.clone()
    return first_images


def _max_logits(images, node):
    """Return, per image, the largest raw logit of `node`'s classifier (no normalization or softmax)."""
    values = []
    for image in images:
        _, _, out = net(image.unsqueeze(0).to(device))
        values.append(torch.max(out[node.name]).item())
    return values


def _print_stats(label, values):
    """Print mean/std of `values`; report empty groups instead of printing nan."""
    if len(values) == 0:
        print(label, 'no samples')
    else:
        print(label, 'mean:', np.mean(values), 'std:', np.std(values))


def logit_stats(node):
    """Print max-logit statistics of `node`'s classifier for three image groups.

    - non_child: one ID test image per class whose name is NOT in node.descendents
    - OOD:       one image per class from the OOD test loader
    - ID:        one ID test image per class whose name IS in node.descendents

    Renamed from a second `confidences` definition, which silently shadowed the
    softmax-confidence version. Uses the module-level `ID_first_images` /
    `OOD_first_images` collected below.
    """
    print('-'*25, node.name, '-'*25)
    non_child_images = [img for y, img in ID_first_images.items() if label2name[y] not in node.descendents]
    ID_images = [img for y, img in ID_first_images.items() if label2name[y] in node.descendents]
    with torch.no_grad():
        _print_stats('non_child_max_logits', _max_logits(non_child_images, node))
        _print_stats('OOD_max_logits', _max_logits(OOD_first_images.values(), node))
        _print_stats('ID_max_logits', _max_logits(ID_images, node))
    print('\n')


# Collect one image per class once; previously both loaders were re-read for every node.
ID_first_images = _first_image_per_class(data_loader)
OOD_first_images = _first_image_per_class(data_loader_OOD)

# node = root.get_node('113+011')
for node in root.nodes_with_children():
    logit_stats(node)
   
    

------------------------- root -------------------------
non_child_max_logits mean: nan std: nan
OOD_max_logits mean: 0.6082209549682498 std: 1.592678779783133
ID_max_logits mean: 6.032765225265865 std: 4.764022885239978


------------------------- 113+001+068 -------------------------
non_child_max_logits mean: 0.0005412555183283985 std: 0.0
OOD_max_logits mean: 1.0353187390026761 std: 2.5701152959573905
ID_max_logits mean: 9.50496085258559 std: 6.211165565225364


------------------------- 113+060 -------------------------
non_child_max_logits mean: 0.025293122371658684 std: 0.04379857865999187
OOD_max_logits mean: 0.6131174176409298 std: 1.7258463005296325
ID_max_logits mean: 6.395429765572771 std: 5.211234176376284


------------------------- 001+052 -------------------------
non_child_max_logits mean: 0.10528445018765827 std: 0.1596867377259321
OOD_max_logits mean: 0.23528782221948602 std: 0.838539883743166
ID_max_logits mean: 10.873497009277344 std: 3.2598831119939735


---------

In [11]:
# The attribute name contains '+', so it must be fetched with getattr rather than dot access.
classification_layer = getattr(net.module, "_113+165_classification")
print(classification_layer.bias)

None


In [22]:
import torchvision
import torch
import torch.nn as nn

# Scratch model for inspecting layer structure: an untrained ResNet-18 wrapped like PIP-Net.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
device_ids = [torch.cuda.current_device()] if use_cuda else []

model = nn.DataParallel(torchvision.models.resnet18().to(device=device), device_ids=device_ids)


In [24]:
# Conv2d weight layout is (out_channels, in_channels, kH, kW).
model.module.layer2[0].downsample[0].weight.size()

torch.Size([128, 64, 1, 1])

In [28]:
# Show the 1x1 stride-2 downsample conv of layer2's first block
# (the bare `model` line that preceded this was dead: only the last expression is displayed).
model.module.layer2[0].downsample[0]

Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)

In [30]:
# Display the full architecture of the DataParallel-wrapped ResNet-18.
model

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track