In [None]:
import torch.nn as nn
import torch
import sys, os
import random
import csv
import pandas as pd
import numpy as np
from shutil import copy
import matplotlib.pyplot as plt
from copy import deepcopy
from omegaconf import OmegaConf
import shutil
import pickle
import random
from PIL import Image
from tqdm import tqdm
from torchvision.datasets.folder import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
import torch.nn.functional as F
import pdb
from collections import defaultdict
import ntpath

from util.func import get_patch_size
from hcompnet.model import HComPNet, get_network
from util.log import Log
from util.args import get_args, save_args, get_optimizer_nn
from util.data import get_dataloaders
from util.func import init_weights_xavier
from util.node import Node
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.func import get_patch_size
from util.evaluation import get_topk_cub_nodewise, eval_prototypes_cub_parts_csv_nodewise_maxmin

In [None]:
# Set path to the experiment run directory.
# Later cells read <run_path>/metadata/args.pickle and
# <run_path>/checkpoints/<checkpoint file name> from this directory.
run_path = 'runs/experiment'

# Load model

In [None]:
# Pick the compute device. `device_ids` is handed to nn.DataParallel when the
# model is wrapped; it stays empty when no GPU is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

# Restore the arguments that were pickled with the run.
# The context manager closes the file handle even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code -- only point `run_path` at
# runs you produced yourself or otherwise trust.
with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:
    args = pickle.load(args_file)

# ------------ Load the phylogeny tree ------------
# The phylo config decides between the discretized tree and the plain tree:
# the literal string 'None' in `phyloDistances_string` selects the plain one.
phylo_config = OmegaConf.load(args.phylo_config)
if phylo_config.phyloDistances_string != 'None':
    root = construct_discretized_phylo_tree(
        phylo_config.phylogeny_path,
        phylo_config.phyloDistances_string,
    )
    print('-'*25 + ' Discretized ' + '-'*25)
else:
    root = construct_phylo_tree(phylo_config.phylogeny_path)
    print('-'*25 + ' No discretization ' + '-'*25)

root.assign_all_descendents()

# Configure the prototype counts of every internal node from the run arguments.
for node in root.nodes_with_children():
    node.set_num_protos(
        num_protos_per_descendant=args.num_protos_per_descendant,
        num_protos_per_child=args.num_protos_per_child,
        min_protos_per_child=args.min_protos_per_child,
    )
    
# ------------ Load the train and test datasets ------------
# get_dataloaders returns seven loaders followed by the class list;
# they are unpacked in the order the helper returns them.
(
    trainloader,
    trainloader_pretraining,
    trainloader_normal,
    trainloader_normal_augment,
    projectloader,
    testloader,
    test_projectloader,
    classes,
) = get_dataloaders(args, device)

# ------------ Load the model checkpoint ------------
ckpt_file_name = 'net_trained_last'
# The suffix after the last underscore ('last' here) is used as the epoch tag
# passed to the evaluation helpers.
epoch = ckpt_file_name.split('_')[-1]
ckpt_path = os.path.join(run_path, 'checkpoints', ckpt_file_name)
checkpoint = torch.load(ckpt_path, map_location=device)

# Rebuild the architecture from the stored args and the phylogeny tree, then
# restore the trained weights into it.
feature_net, add_on_layers, pool_layer, classification_layers = get_network(args, root=root)
net = HComPNet(feature_net = feature_net,
                args = args,
                add_on_layers = add_on_layers,
                pool_layer = pool_layer,
                classification_layers = classification_layers,
                num_parent_nodes = len(root.nodes_with_children()),
                root = root)
net = net.to(device=device)
# The checkpoint is loaded into the DataParallel wrapper, so it is loaded
# with the wrapper in place (strict=True requires the key names to match).
net = nn.DataParallel(net, device_ids = device_ids)
net.load_state_dict(checkpoint['model_state_dict'],strict=True)

# This notebook only evaluates the trained model. Switch to eval mode before
# any forward pass so layers such as BatchNorm/Dropout do not behave as in
# training (a train-mode forward pass would update BatchNorm running
# statistics of the restored model, even under torch.no_grad()).
net.eval()

# Forward one batch through the backbone to get the latent output size
with torch.no_grad():
    xs1, _, _ = next(iter(trainloader))
    xs1 = xs1.to(device)
    _, proto_features, _, _ = net(xs1)
    wshape = proto_features['root'].shape[-1]
    args.wshape = wshape #needed for calculating image patch size
    print("Output shape: ", proto_features['root'].shape, flush=True)
print(args.wshape)

# Helper functions

In [None]:
# Map a location in the latent grid to the pixel box of the matching image patch
def get_img_coordinates(img_size, softmaxes_shape, patchsize, skip, h_idx, w_idx):
    """Return the pixel bounds of the image patch for latent cell (h_idx, w_idx).

    Args:
        img_size: Side length of the input image in pixels.
        softmaxes_shape: Shape of the latent map; index 1 is read as the
            number of rows and index 2 as the number of columns.
        patchsize: Side length of an image patch in pixels.
        skip: Pixel step between neighbouring latent cells.
        h_idx: Row index in the latent map (number or single-element tensor).
        w_idx: Column index in the latent map (number or single-element tensor).

    Returns:
        Tuple ``(h_coor_min, h_coor_max, w_coor_min, w_coor_max)``.
    """
    # Indices may arrive as tensors; reduce them to plain Python numbers.
    if torch.is_tensor(w_idx):
        w_idx = w_idx.item()
    if torch.is_tensor(h_idx):
        h_idx = h_idx.item()

    n_rows, n_cols = softmaxes_shape[1], softmaxes_shape[2]
    # Latent output of 26x26: the ConvNeXt variant with smaller strides.
    small_stride_grid = n_rows == 26 and n_cols == 26

    def _axis_bounds(idx, n_cells):
        # Compute (min, max) pixel coordinates along one axis.
        if small_stride_grid:
            # Since the outer latent patches have a smaller receptive field,
            # skip size is set to 4 for the first and last patch, 8 for rest.
            lo = max(0, (idx - 1) * skip + 4)
            if not idx < softmaxes_shape[-1] - 1:
                lo -= 4
            hi = lo + patchsize
        else:
            lo = idx * skip
            hi = min(img_size, lo + patchsize)

        # The last cell along the axis always reaches the image border.
        if idx == n_cells - 1:
            hi = img_size
        # A box that ends at the border is shifted back to keep the full patch size.
        if hi == img_size:
            lo = img_size - patchsize
        return lo, hi

    h_coor_min, h_coor_max = _axis_bounds(h_idx, n_rows)
    w_coor_min, w_coor_max = _axis_bounds(w_idx, n_cols)

    return h_coor_min, h_coor_max, w_coor_min, w_coor_max


def unshuffle_dataloader(dataloader, batch_size=1):
    """Build a non-shuffling DataLoader that mirrors ``dataloader``.

    Args:
        dataloader: Existing DataLoader whose dataset and worker settings are reused.
        batch_size: Batch size of the new loader (defaults to 1).

    Returns:
        A new DataLoader with ``shuffle=False``. The worker-related settings
        listed below are copied from ``dataloader``; other settings (for
        example ``collate_fn`` or a custom sampler) are not carried over.
    """
    source = dataloader.dataset
    if type(source) is not ImageFolder:
        # Anything that is not exactly an ImageFolder is treated as a wrapper and
        # its inner `.dataset` is used instead.
        # NOTE(review): if the wrapper restricts the samples (e.g. a Subset), that
        # restriction is dropped here -- confirm this is intended.
        source = source.dataset

    carried_over = (
        'num_workers',
        'pin_memory',
        'drop_last',
        'timeout',
        'worker_init_fn',
        'multiprocessing_context',
        'generator',
        'prefetch_factor',
        'persistent_workers',
    )
    settings = {name: getattr(dataloader, name) for name in carried_over}
    return DataLoader(dataset=source, batch_size=batch_size, shuffle=False, **settings)

# Find subtree root (lookup only: this does not affect the run; use the value found here in the visualization block)

In [None]:
# If only a subtree is of interest, list some of its leaf nodes below.
# The loop then finds the root of the smallest subtree containing all of them:
# the internal node with the fewest leaf descendants that still covers every
# listed leaf (the full tree root is the fallback).

leaf_descendents = {'cub_052_Pied_billed_Grebe', 'cub_004_Groove_billed_Ani'}
subtree_root = root
for node in root.nodes_with_children():
    if not leaf_descendents.issubset(node.leaf_descendents):
        continue
    if len(node.leaf_descendents) < len(subtree_root.leaf_descendents):
        subtree_root = node
print(subtree_root.name)

# Hyperparameters

In [None]:
# TOPK is forwarded to get_topk_cub_nodewise in the part-purity cell.
TOPK = 10
# Loader the evaluation runs on; it is rebuilt without shuffling and with
# batch_size=1 in the data setup cell.
maindataloader = testloader # other options: projectloader, trainloader_normal, trainloader_normal_augment, test_projectloader
# Nodes whose name is not in subtree_root.descendents are skipped in the
# part-purity cell. The trailing comment shows how to pick a specific node.
subtree_root = root#.get_node('024+051')

# Setup data

In [None]:
cub_meta_path = "" # Update the path for CUB meta file path
part_locs_file = os.path.join(cub_meta_path, 'parts', 'part_locs_normalized_after_cropped_after_padded.txt')
images_file = os.path.join(cub_meta_path, 'images_cub.txt')
NUM_PARTS = 15

# Read the image index to filename mapping.
# Each non-empty line is expected to hold "<index> <path>"; only the base name
# of the path is used as the key.
img_filename_to_index = {} # image filename to image index
with open(images_file, 'r') as file:
    for line in file:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. at the end of the file) instead of
            # failing on the tuple unpacking below.
            continue
        index, filename = line.split()
        img_filename = ntpath.basename(filename)
        img_filename_to_index[img_filename] = int(index)

# Load part locations.
# Each non-empty line is parsed as "<image index> <part id> <x> <y> <visible>".
image_part_locs = defaultdict(list)
with open(part_locs_file, 'r') as file:
    for line in file:
        parts = line.strip().split()
        if not parts:
            # Skip blank lines rather than raising IndexError.
            continue
        image_index, part_id, x, y, visible = int(parts[0]), int(parts[1]), float(parts[2]), float(parts[3]), bool(float(parts[4]))
        image_part_locs[image_index].append((part_id, x, y, visible))

parts_name_path = os.path.join(cub_meta_path, 'parts', 'parts.txt')
# NOTE(review): `images_file` above points at <cub_meta_path>/images_cub.txt while
# `imgs_id_path` points at <cub_meta_path>/parts/images_cub.txt -- confirm that
# both locations are intended.
imgs_id_path = os.path.join(cub_meta_path, 'parts', 'images_cub.txt')

# Evaluate on a deterministic, one-image-per-batch version of the chosen loader.
maindataloader = unshuffle_dataloader(maindataloader, batch_size=1)
print(maindataloader.batch_size)

# Calculate part purity

In [None]:
# Collect the top-k results per node; one csv file is returned per internal node
# and is paired below with root.nodes_with_children() in iteration order.
list_csvfile_topk, list_node_wise_df, dict_node_wise_df = get_topk_cub_nodewise(
    net, root, maindataloader, TOPK, str(epoch), device, args
)

node_wise_purity = []
node_wise_purity_of_unmasked = []
node_wise_purity_of_masked = []
for csvfile_topk, node in zip(list_csvfile_topk, root.nodes_with_children()):

    # Only nodes listed among the descendants of the chosen subtree root are evaluated.
    if node.name in subtree_root.descendents:
        node_purity, max_presence_purity = eval_prototypes_cub_parts_csv_nodewise_maxmin(
            node,
            csvfile_topk,
            part_locs_file,
            parts_name_path,
            imgs_id_path,
            'projectloader_topk_'+str(epoch),
            args,
            desc_threshold=0.2,
        )
        node_wise_purity.append(node_purity)

        # Row p of the node's proto_presence holds two values for prototype p.
        # A prototype counts as unmasked when column 0 is smaller than column 1
        # and as masked when column 0 is larger; ties land in neither group.
        proto_presence = getattr(net.module, '_'+node.name+'_proto_presence')
        purity_of_unmasked = [
            max_presence_purity[p]
            for p in max_presence_purity
            if proto_presence[int(p), 0] < proto_presence[int(p), 1]
        ]
        purity_of_masked = [
            max_presence_purity[p]
            for p in max_presence_purity
            if proto_presence[int(p), 0] > proto_presence[int(p), 1]
        ]
        # np.mean of an empty list is nan; the report cell uses nan-aware statistics.
        node_wise_purity_of_unmasked.append(np.mean(purity_of_unmasked))
        node_wise_purity_of_masked.append(np.mean(purity_of_masked))
    else:
        print('Skipping node', node.name)

### Part purity

In [None]:
# Report mean and standard deviation of the node-wise purities, ignoring nan
# entries, for each of the three prototype groups.
for title, purities in (
    ('Part purity including over-specific prototypes', node_wise_purity),
    ('Part purity excluding over-specific prototypes', node_wise_purity_of_unmasked),
    ('Part purity of excluded prototypes', node_wise_purity_of_masked),
):
    print(title)
    print('Mean:', np.nanmean(purities), 'Std:', np.nanstd(purities))

# Ratio of good protos / Total protos

In [None]:

# Running totals over all internal nodes. They are floats so the printed
# totals keep the same format as before.
total_relevant_protos = 0.
total_good_protos = 0.

with torch.no_grad():
    for node in root.nodes_with_children():
        classification_weights = getattr(net.module, '_'+node.name+'_classification').weight
        proto_presence = getattr(net.module, '_'+node.name+'_proto_presence')

        # Decide deterministically which prototypes are kept: column 1 must be
        # larger than column 0. This is the same rule the part-purity cell uses
        # for "unmasked" prototypes. Sampling with F.gumbel_softmax instead adds
        # random noise, which makes the reported ratio change between runs.
        keep_mask = (proto_presence[:, 1] > proto_presence[:, 0]).to(classification_weights.dtype)
        masked_classification_weights = keep_mask.unsqueeze(0) * classification_weights

        # A prototype is relevant for a class when its weight exceeds 1e-3;
        # it is "good" when it is additionally kept by the mask. Counting over
        # the whole weight matrix equals summing the per-class counts.
        total_relevant_protos += (classification_weights > 1e-3).sum().item()
        total_good_protos += (masked_classification_weights > 1e-3).sum().item()

# Avoid a ZeroDivisionError when no weight exceeds the threshold.
if total_relevant_protos > 0:
    good_proto_ratio = total_good_protos / total_relevant_protos
else:
    good_proto_ratio = float('nan')

print('Total protos:', total_relevant_protos, 'Total good protos:', total_good_protos, 'Ratio:', good_proto_ratio)