In [None]:
import sys
import os
import torch.nn as nn
import torch
import sys, os
import random
import numpy as np
from shutil import copy
import matplotlib.pyplot as plt
from copy import deepcopy
from omegaconf import OmegaConf
import shutil
import pickle
import random
from tqdm import tqdm
import numpy as np
from torchvision.datasets.folder import ImageFolder
from torch.utils.data import DataLoader
import torch.nn.functional as F
# from skimage.filters import threshold_local, gaussian
import ntpath
from util.data import ModifiedLabelLoader
from collections import defaultdict
import heapq
from util.vis_pipnet import get_img_coordinates
import torchvision.transforms as transforms
from PIL import ImageFont, Image, ImageDraw as D
import torchvision
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
import math

from hcompnet.model import HComPNet, get_network
from util.log import Log
from util.args import get_args, save_args, get_optimizer_nn
from util.data import get_dataloaders
from util.func import init_weights_xavier
# from pipnet.test import eval_pipnet, get_thresholds, eval_ood
from util.node import Node
from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree
from util.func import get_patch_size
from util.data import ModifiedLabelLoader

In [None]:
# Path to the training run to inspect. The cells below read
# <run_path>/metadata/args.pickle and <run_path>/checkpoints/<checkpoint name>,
# and write the top-k visualizations into a folder under <run_path>.
run_path = 'runs/hcompnet_cub190_cnext26'

In [None]:

def minmaxscale(tensor):
    """Linearly rescale ``tensor`` so that its minimum maps to 0 and its maximum to 1.

    Works for torch tensors and numpy arrays alike (only ``.min()``, ``.max()``
    and arithmetic are used). A constant input has no range to stretch; instead
    of dividing 0 by 0 (NaN everywhere) it returns an all-zero float result.
    """
    lo = tensor.min()
    span = tensor.max() - lo
    if span == 0:
        # Multiplying by 0.0 keeps the container type/shape and yields floats,
        # matching the dtype the division below would produce.
        return (tensor - lo) * 0.0
    return (tensor - lo) / span

from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data import DataLoader

def unshuffle_dataloader(dataloader):
    """Return a non-shuffling copy of ``dataloader`` built on its underlying ImageFolder.

    Loaders may wrap the ImageFolder in one or more dataset wrappers that each
    expose the wrapped dataset as ``.dataset``. This walks that chain down to the
    first ``ImageFolder`` (subclasses included). If the chain contains none, the
    innermost dataset reached is used. The new loader copies every other setting
    of ``dataloader`` but uses ``shuffle=False`` and no custom sampler, so images
    come out in the dataset's own index order.
    """
    dataset = dataloader.dataset
    while not isinstance(dataset, ImageFolder) and hasattr(dataset, 'dataset'):
        dataset = dataset.dataset
    return DataLoader(
        dataset=dataset,
        batch_size=dataloader.batch_size,
        shuffle=False,
        num_workers=dataloader.num_workers,
        pin_memory=dataloader.pin_memory,
        drop_last=dataloader.drop_last,
        timeout=dataloader.timeout,
        worker_init_fn=dataloader.worker_init_fn,
        multiprocessing_context=dataloader.multiprocessing_context,
        generator=dataloader.generator,
        prefetch_factor=dataloader.prefetch_factor,
        persistent_workers=dataloader.persistent_workers
    )


def get_heatmap(latent_activation, input_image, constant_color_scale=False):
    """Blend a 'jet'-colored activation heatmap over an image.

    Args:
        latent_activation: tensor whose first slice ``latent_activation[0]`` is a
            2-D activation map (e.g. a prototype's map ``softmaxes[:, p, :, :]``
            for a single image). It is multiplied by 255 and cast to uint8, so
            values are assumed to lie in [0, 1].
        input_image: (3, H, W) image tensor with values in [0, 1].
        constant_color_scale: if True, colors use the fixed range [0, 255]
            (absolute activation) instead of being stretched between the map's
            own min and max, so heatmaps of different images are comparable.

    Returns:
        (H, W, 3) uint8 numpy array: 70% image + 30% heatmap.
    """
    height, width = input_image.shape[-2], input_image.shape[-1]
    activation_np = latent_activation.cpu().numpy()
    image_np = input_image.permute(1, 2, 0).cpu().numpy()

    # Upsample the latent map to the image size; PIL's resize takes (width, height).
    heatmap = np.array(
        Image.fromarray((activation_np[0] * 255).astype('uint8')).resize((width, height))
    ).astype(np.float64)
    if constant_color_scale:
        # Append one column of 0s and one of 255s (one value per image row) so the
        # normalisation below spans exactly [0, 255]; both columns are dropped
        # again after coloring.
        heatmap = np.concatenate(
            (heatmap, np.zeros((height, 1)), np.full((height, 1), 255.0)), axis=1)

    value_range = heatmap.max() - heatmap.min()
    if value_range > 0:
        normalized_heatmap = (heatmap - heatmap.min()) / value_range
    else:
        # A constant map would otherwise produce 0/0 = NaN everywhere.
        normalized_heatmap = np.zeros_like(heatmap)

    heatmap_colored = plt.get_cmap('jet')(normalized_heatmap)
    if constant_color_scale:
        heatmap_colored = heatmap_colored[:, :-2]
    heatmap_pillow = Image.fromarray((heatmap_colored[:, :, :3] * 255).astype(np.uint8))
    image_pillow = Image.fromarray((image_np * 255).astype('uint8'))
    # Image.blend computes image * (1 - alpha) + heatmap * alpha.
    return np.array(Image.blend(image_pillow, heatmap_pillow, alpha=0.3))


def get_heap():
    """Create a new, empty min-heap.

    A heap here is just a plain list manipulated with ``heapq``; an empty list
    already satisfies the heap invariant. Each call returns a distinct list, which
    makes this suitable as a ``defaultdict`` factory.
    """
    return []

## Load Model

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_ids = [torch.cuda.current_device()]
else:
    device = torch.device('cpu')
    device_ids = []

# Arguments saved with the training run. NOTE: unpickling can execute arbitrary
# code, so only open runs you produced or otherwise trust.
with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:
    args = pickle.load(args_file)

# ------------ Load the phylogeny tree ------------
phylo_config = OmegaConf.load(args.phylo_config)
if phylo_config.phyloDistances_string == 'None':
    root = construct_phylo_tree(phylo_config.phylogeny_path)
    print('-'*25 + ' No discretization ' + '-'*25)
else:
    root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)
    print('-'*25 + ' Discretized ' + '-'*25)
root.assign_all_descendents()
for node in root.nodes_with_children():
    node.set_num_protos(num_protos_per_descendant=args.num_protos_per_descendant,
                        num_protos_per_child=args.num_protos_per_child,
                        min_protos_per_child=args.min_protos_per_child)

# ------------ Load the train and test datasets ------------
# One image per batch: the visualization cell reads labels with `.item()`.
args.batch_size = 1
trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device)

# ------------ Load the model checkpoint ------------
ckpt_file_name = 'net_trained_last'
epoch = ckpt_file_name.split('_')[-1]
ckpt_path = os.path.join(run_path, 'checkpoints', ckpt_file_name)
checkpoint = torch.load(ckpt_path, map_location=device)
feature_net, add_on_layers, pool_layer, classification_layers = get_network(args, root=root)
net = HComPNet(feature_net = feature_net,
                args = args,
                add_on_layers = add_on_layers,
                pool_layer = pool_layer,
                classification_layers = classification_layers,
                num_parent_nodes = len(root.nodes_with_children()),
                root = root)
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)
net.load_state_dict(checkpoint['model_state_dict'], strict=True)
# This notebook only runs inference: put every layer with train/eval-specific
# behavior (e.g. dropout, normalization statistics) into evaluation mode.
net.eval()

# Forward one batch through the backbone to get the latent output size
with torch.no_grad():
    xs1, _, _ = next(iter(trainloader))
    xs1 = xs1.to(device)
    _, proto_features, _, _ = net(xs1)
    wshape = proto_features['root'].shape[-1]
    args.wshape = wshape #needed for calculating image patch size
    print("Output shape: ", proto_features['root'].shape, flush=True)
print(args.wshape)

# Find subtree root (lookup only: this cell does not affect the run — copy the node name it prints into the visualization cell below)

In [None]:
# To visualize only part of the tree, list a few leaf classes of the subtree of
# interest. This cell finds the internal node with the fewest leaf descendants
# that still covers all of them, and prints its name.

leaf_descendents = {'cub_052_Pied_billed_Grebe', 'cub_004_Groove_billed_Ani'}
covering_nodes = [candidate for candidate in root.nodes_with_children()
                  if leaf_descendents.issubset(candidate.leaf_descendents)]
# `root` is listed first so it is the fallback and wins ties; min() keeps the
# first of several equally small candidates.
subtree_root = min([root] + covering_nodes,
                   key=lambda candidate: len(candidate.leaf_descendents))
print(subtree_root.name)

# Save TopK Visualizations

In [None]:
# Proto activations on leaf descendents - topk images
vizloader_name = 'testloader' # projectloader
find_non_descendants = False # if True gets topk images from contrasting set (images that are not descendants of a given node)
topk = 3 # images kept per (prototype, leaf class); 0 or None keeps all of them
save_images = True # True, False
font_size = 40
fnt = ImageFont.truetype("assets/arial.ttf", font_size)
# Plotting for a subtree since the whole tree is too large; the "Find subtree
# root" cell above prints a suitable node name.
subtree_root = root.get_node('024+051')

patchsize, skip = get_patch_size(args)
vizloader_dict = {'trainloader': trainloader,
                  'projectloader': projectloader,
                  'testloader': testloader,
                  'test_projectloader': test_projectloader}
vizloader_dict[vizloader_name] = unshuffle_dataloader(vizloader_dict[vizloader_name])
vizloader = vizloader_dict[vizloader_name]

suffix = 'contrasting_set' if find_non_descendants else ''
viz_save_dir = os.path.join(run_path, f'topk_viz_{suffix}_root={subtree_root.name}')

# unshuffle_dataloader builds the loader directly on the unwrapped dataset, so
# its class mapping can be read without further unwrapping.
name2label = vizloader.dataset.class_to_idx
label2name = {label: name for name, label in name2label.items()}


def _species_name(leaf_name):
    """Display name of a leaf class: its name without the first two '_'-separated fields.

    For BUT datasets only the next two fields are kept.
    """
    words = leaf_name.split('_')[2:]
    if 'BUT' in args.dataset:
        words = words[:2]
    return ' '.join(words)


def _zoom_with_frame(patch, scale_factor, color):
    """Upsample a (3, h, w) patch with values in [0, 1] and draw a 4-px frame around it."""
    zoomed = F.interpolate(patch.unsqueeze(0), scale_factor=scale_factor,
                           mode='bilinear', align_corners=False).squeeze(0)
    frame = torch.tensor([[0, 0, zoomed.shape[2], zoomed.shape[1]]])
    framed = torchvision.utils.draw_bounding_boxes((zoomed * 255).to(torch.uint8), frame,
                                                   width=4, colors=color)
    return framed.float() / 255.


def _render_topk_cell(img_path, latent_activation, patch_size=64, scale_factor=1.7):
    """Render one top-k hit as [heatmap overlay | image | zoomed image + heatmap patch].

    The zoom box is ``patch_size`` pixels wide, centred on the maximum of the
    activation map upsampled to the image size; it is also drawn on both
    full-size images.
    """
    with Image.open(img_path) as pil_image:
        # torchvision's default ImageFolder loader also converts to RGB; without it
        # grayscale files would yield a 1-channel tensor and break the blending.
        image = transforms.Resize(size=(args.image_size, args.image_size))(pil_image.convert('RGB'))
    img_tensor = transforms.ToTensor()(image)
    _, height, width = img_tensor.shape
    overlayed_image = torch.tensor(
        get_heatmap(latent_activation, img_tensor, constant_color_scale=True)).permute(2, 0, 1).float() / 255.

    upsampled = np.array(Image.fromarray(
        (latent_activation.cpu().numpy()[0] * 255).astype('uint8')).resize((width, height)))
    center_h, center_w = np.unravel_index(np.argmax(upsampled), upsampled.shape)
    half = patch_size / 2.
    h_min, h_max = int(max(0, center_h - half)), int(min(height, center_h + half))
    w_min, w_max = int(max(0, center_w - half)), int(min(width, center_w + half))

    # Crop the patches before the boxes are drawn onto the full-size images below.
    zoomed_img = _zoom_with_frame(img_tensor[:, h_min:h_max, w_min:w_max], scale_factor, (255, 255, 0))
    zoomed_heatmap = _zoom_with_frame(overlayed_image[:, h_min:h_max, w_min:w_max], scale_factor, (255, 0, 0))
    zoomed = torchvision.utils.make_grid([zoomed_img, zoomed_heatmap], nrow=1, padding=1, pad_value=1.)

    # White canvas of the image size holding the zoomed pair, vertically centred, 10 px from the left.
    zoom_panel = torch.ones(3, height, width)
    y_start = (height - zoomed.shape[1]) // 2
    x_start = 10
    zoom_panel[:, y_start:y_start + zoomed.shape[1], x_start:x_start + zoomed.shape[2]] = zoomed

    box = torch.tensor([[w_min, h_min, w_max, h_max]])
    img_boxed = torchvision.utils.draw_bounding_boxes(
        (img_tensor * 255).to(torch.uint8), box, width=2, colors=(255, 255, 0)).float() / 255.
    overlay_boxed = torchvision.utils.draw_bounding_boxes(
        (overlayed_image * 255).to(torch.uint8), box, width=2, colors=(255, 0, 0)).float() / 255.
    return torchvision.utils.make_grid([overlay_boxed, img_boxed, zoom_panel], nrow=3, padding=5, pad_value=1.)


def _text_column(species_name, width, height, font):
    """White (3, height, width) tensor with ``species_name`` centred, one word per line."""
    text = '\n'.join(species_name.split(' '))
    txtimage = Image.new("RGB", (width, height), (255, 255, 255))
    ImageDraw.Draw(txtimage).multiline_text((width / 2, height / 2), text, font=font,
                                            fill=(0, 0, 0), align="center", anchor="mm")
    return transforms.ToTensor()(txtimage)


# Inference only: run layers with train/eval-specific behavior in eval mode.
net.eval()

for node in root.nodes_with_children():
    if (node.name not in subtree_root.descendents) and (node.name != subtree_root.name):
        print('Skipping node', node.name)
        continue

    modifiedLabelLoader = ModifiedLabelLoader(vizloader, node)
    coarse_label2name = modifiedLabelLoader.modifiedlabel2name
    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}
    imgs = modifiedLabelLoader.filtered_imgs
    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight

    # A prototype is associated with every child whose classification weight for
    # it exceeds 1e-3. The weights do not change in this cell, so this is worked
    # out once per node rather than once per image.
    #   relevant_classes_per_proto: prototype index -> relevant child names
    #   class_and_prototypes: ', '-joined child names -> prototype indices
    relevant_classes_per_proto = {}
    class_and_prototypes = defaultdict(set)
    with torch.no_grad():
        for p in range(classification_weights.shape[1]):
            class_idxs = torch.nonzero(classification_weights[:, p] > 1e-3)
            names = [node_label_to_children[class_idx.item()] for class_idx in class_idxs]
            if names:
                relevant_classes_per_proto[p] = names
                class_and_prototypes[', '.join(names)].add(p)

    # maps proto_number -> leaf descendant name -> min-heap of the top-k
    # (pooled activation, image path, patch coordinates, activation map) entries
    proto_mean_activations = defaultdict(lambda: defaultdict(get_heap))

    img_iter = tqdm(enumerate(modifiedLabelLoader),
                    total=len(modifiedLabelLoader),
                    mininterval=50.,
                    desc='Collecting topk',
                    ncols=0)
    for i, (xs, orig_y, ys) in img_iter:
        xs, ys = xs.to(device), ys.to(device)
        # Assumes one image per batch (args.batch_size = 1 in the loading cell).
        child_name = coarse_label2name[ys.item()]
        leaf_descendent = label2name[orig_y.item()]
        img_to_open = imgs[i][0]  # it is a tuple of (path to image, label)

        with torch.no_grad():
            _, softmaxes, pooled, _ = net(xs, inference=False)
            pooled = pooled[node.name].squeeze(0)
            softmaxes = softmaxes[node.name]

            # Spatial argmax of every prototype map; it does not depend on p,
            # so it is computed once per image.
            max_over_batch, _ = torch.max(softmaxes, dim=0)
            max_over_h, argmax_h = torch.max(max_over_batch, dim=1)
            _, argmax_w = torch.max(max_over_h, dim=1)

            for p, relevant_names in relevant_classes_per_proto.items():
                # Keep descendant images, or only non-descendant ones for the contrasting set.
                if (child_name in relevant_names) == find_non_descendants:
                    continue
                w_idx = argmax_w[p]
                h_idx = argmax_h[p, w_idx]
                coords = tuple(get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx))
                # Store a compact CPU copy: a view would keep the whole softmaxes
                # tensor (possibly on the GPU) alive for as long as it is in a heap.
                entry = (pooled[p].item(), img_to_open, coords, softmaxes[:, p, :, :].clone().cpu())
                heap = proto_mean_activations[p][leaf_descendent]
                if topk and len(heap) >= topk:
                    heapq.heappushpop(heap, entry)
                else:
                    heapq.heappush(heap, entry)

    print('Node', node.name)
    for child_classname, prototypes in class_and_prototypes.items():
        print('\t'*1, 'Child:', child_classname)
        for p in prototypes:
            leaf_heaps = proto_mean_activations[p]
            logstr = '\t'*2 + f'Proto:{p} '
            for leaf_descendent, heap in leaf_heaps.items():
                mean_activation = round(np.mean([activation for activation, *_ in heap]), 4)
                logstr += f'{leaf_descendent}:({mean_activation}) '
            print(logstr)

            if len(leaf_heaps) == 0 or not save_images:
                continue

            # Text column width: the widest word of any species name in this figure.
            measure = ImageDraw.Draw(Image.new("RGB", (100, 100), (255, 0, 0)))
            max_width = measure.textlength('-', font=fnt)
            for leaf_descendent in leaf_heaps:
                for word in _species_name(leaf_descendent).split(' '):
                    max_width = max(max_width, measure.textlength(word, font=fnt))
            text_width = math.ceil(max_width) + 10

            grid_rows = []
            for leaf_descendent, heap in leaf_heaps.items():
                # Strongest activation first.
                cells = [_render_topk_cell(img_path, latent_activation)
                         for _, img_path, _, latent_activation in sorted(heap, reverse=True)]
                row = torchvision.utils.make_grid(cells, nrow=len(cells), padding=0)
                text = _text_column(_species_name(leaf_descendent), text_width, row.shape[1], fnt)
                grid_rows.append(torch.cat([text, row], dim=-1))

            # Leaves with fewer than topk images give shorter rows; pad them with
            # white on the right so all rows can be stacked into one grid.
            full_width = max(row.shape[-1] for row in grid_rows)
            grid_rows = [F.pad(row, (0, full_width - row.shape[-1]), value=1.) for row in grid_rows]
            grid = torchvision.utils.make_grid(grid_rows, nrow=1, padding=5, pad_value=1.)

            os.makedirs(os.path.join(viz_save_dir, node.name), exist_ok=True)
            torchvision.utils.save_image(grid, os.path.join(viz_save_dir, node.name, f'{child_classname}-p{p}.png'))

print('Done !!!')