In [None]:
import torch
import pandas as pd
import os
import sys
sys.path.append('protopnext_userstudygui/protopnext')
sys.path.append('protopnext_userstudygui/protopnext/protopnet')
sys.path.append('explain_dataset/protopnext/protopnet/utilities')
from pathlib import Path
import matplotlib.pyplot as plt
from general_utilities import find_high_activation_crop
from project_utilities import custom_unravel_index, hash_func
from visualization_utilities import indices_to_upsampled_boxes
import numpy as np
from protopnet.visualization import *

# NOTE: the original import order is kept on purpose -- names imported after the
# star import above take precedence over anything it exports.
# from tqdm import tqdm
from tqdm.auto import tqdm
import itertools
import json
import pathlib
import random
import multiprocessing as mp
import multiprocessing
from functools import partial
from torchvision import transforms

import time

In [None]:
def check_done(model_path):
    """Tell whether the box file for a model has already been written.

    Parameters:
    - model_path (str): '/'-separated path to a model file. The last two
      components are joined with '_' and the final four characters are dropped
      to form the model id.

    Returns:
    - bool: True when ``all_boxes.npy`` exists in that model's output folder.
    """
    path_parts = model_path.split('/')
    model_id = '_'.join(path_parts[-2:])[:-4]
    boxes_file = Path(f'protopnext_userstudygui/image_folderv5/{model_id}') / 'all_boxes.npy'
    return os.path.exists(str(boxes_file))

In [None]:
from protopnet.datasets import cub200, cub200_cropped
from protopnet.datasets.torch_extensions import ImageFolderDict
from protopnet.preprocess import mean, std

batch_size = 20
img_size = 224

# Fail early with a readable message instead of the `None + str` TypeError the
# path construction below would otherwise raise when the variable is unset.
cub200_dir = os.environ.get("CUB200_DIR")
if cub200_dir is None:
    raise EnvironmentError(
        "The CUB200_DIR environment variable must point to the CUB-200 dataset root."
    )

split_dataloader = cub200_cropped.train_dataloaders(
    batch_sizes={"train": batch_size, "project": batch_size, "val": batch_size},
    data_path=cub200_dir,
)

# Kept for convenience; normalisation is intentionally NOT part of the
# transform below (it was commented out in the original pipeline).
normalize = transforms.Normalize(mean=mean, std=std)

# Resize + ToTensor only, shared by the push and validation image folders.
no_aug_transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        # normalize,
    ]
)

loader_config = {
    "batch_size": batch_size,
    "shuffle": False,
    "num_workers": 29,
    "pin_memory": False,
    "prefetch_factor": 8,
}

# `Path` is what the notebook imports (`from pathlib import Path`); the bare
# `pathlib` module name is not guaranteed to be in scope.
push_dataset_path = Path(cub200_dir + "/train_cropped")
push_dataset = ImageFolderDict(push_dataset_path, no_aug_transform)
s_dataset = push_dataset  # Subset(push_dataset, range(10))

val_dataset_path = Path(cub200_dir + "/val_cropped")
val_dataset = ImageFolderDict(val_dataset_path, no_aug_transform)
val_no_aug_loader = torch.utils.data.DataLoader(val_dataset, **loader_config)

# The loaders used for evaluation and prototype projection come from the split
# dataloader. (A DataLoader over `s_dataset` used to be built here and was
# immediately overwritten by the assignment below, so it is no longer created.)
val_loader = split_dataloader.val_loader
push_loader = split_dataloader.project_loader

In [None]:
# Table of model pairs; later cells read its 'backbone', 'best_model_cos',
# 'best_model_l2' and the two 'best[prototypes_embedded]/eval/accuracy_*' columns.
model_pairs_csv = "model-pairs.csv"
cropped_models_infos = pd.read_csv(model_pairs_csv)

In [None]:
# Print the cosine / L2 accuracy of every row, in row order.
# (Positional iteration; equivalent to label lookup on the default RangeIndex.)
acc_cos_values = cropped_models_infos['best[prototypes_embedded]/eval/accuracy_cos'].values
acc_l2_values = cropped_models_infos['best[prototypes_embedded]/eval/accuracy_l2'].values
for idx in range(len(cropped_models_infos)):
    print(acc_cos_values[idx], acc_l2_values[idx])

In [None]:
# Re-order the table so rows are sorted by their within-backbone rank:
# 1) sort by backbone, then by cosine accuracy (best first);
# 2) number the rows inside each backbone (0 = best);
# 3) sort by that number, drop it, and renumber the index.
df = cropped_models_infos.sort_values(
    by=["backbone", "best[prototypes_embedded]/eval/accuracy_cos"],
    ascending=[True, False],
)
df = df.assign(group=df.groupby('backbone').cumcount())
cropped_models_infos = (
    df.sort_values(by="group")
    .drop(columns=["group"])
    .reset_index(drop=True)
)

In [None]:
# Queue both checkpoints of every row as [path, row index, backbone, metric].
all_model_paths = []
backbone_values = cropped_models_infos['backbone'].values
acc_cos_values = cropped_models_infos['best[prototypes_embedded]/eval/accuracy_cos'].values
acc_l2_values = cropped_models_infos['best[prototypes_embedded]/eval/accuracy_l2'].values
metric_path_values = (
    ('cos', cropped_models_infos['best_model_cos'].values),
    ('l2', cropped_models_infos['best_model_l2'].values),
)
for idx in range(len(cropped_models_infos)):
    print(backbone_values[idx], acc_cos_values[idx], acc_l2_values[idx])
    for metric_name, path_values in metric_path_values:
        all_model_paths.append([path_values[idx], idx, backbone_values[idx], metric_name])

In [None]:
# Number of queued checkpoints (displayed as the cell output).
len(all_model_paths)

In [None]:
def sample_indices_by_percentile(array, num_samples):
    """Sample up to ``num_samples`` indices from each decile of ``array``.

    The values are split into ten bins delimited by the 0th, 10th, ..., 100th
    percentiles. Bins are half-open ``[low, high)`` except the last one, which
    is closed so that the maximum value can be sampled as well.

    Sampling uses a private ``RandomState(42)``. It produces the same draws as
    seeding the legacy global generator with 42, but does not disturb the
    caller's global NumPy random state.

    Parameters:
    - array (array-like): values to bin.
    - num_samples (int): maximum number of indices drawn per bin (without replacement).

    Returns:
    - dict: bin number (0-9) -> 1D integer array of sampled indices (empty when
      the bin contains no element).
    """
    array = np.asarray(array)
    num_bins = 10
    rng = np.random.RandomState(42)

    if array.size == 0:
        # np.percentile is undefined on empty input; every bin is empty.
        return {i: np.array([], dtype=int) for i in range(num_bins)}

    ret = dict()
    for i in range(num_bins):
        low = np.percentile(array, i * 10)
        high = np.percentile(array, (i + 1) * 10)
        if i == num_bins - 1:
            # Closed upper edge: `high` is the maximum here.
            in_bin = (array >= low) & (array <= high)
        else:
            in_bin = (array >= low) & (array < high)
        indices = np.where(in_bin)[0]
        if len(indices) > 0:
            ret[i] = rng.choice(indices, min(num_samples, len(indices)), replace=False)
        else:
            ret[i] = np.array([], dtype=int)

    return ret

def create_lookup_table(index_dict):
    """Invert a ``key -> iterable of indices`` mapping into ``index -> key``.

    When an index appears under several keys, the key visited last wins.
    """
    return {
        member: group_key
        for group_key, members in index_dict.items()
        for member in members
    }

def percentile(tensor: torch.Tensor, value: float):
    """Return the percentage of elements of ``tensor`` that are <= ``value``.

    Counting the elements directly gives the same rank as sorting followed by
    ``searchsorted(..., right=True)``, without the sort. It also avoids
    wrapping ``value`` in ``torch.tensor``, which warns when ``value`` is
    already a tensor and creates a float32 CPU tensor that may not match the
    dtype or device of ``tensor``.

    Parameters:
    - tensor (torch.Tensor): 1D tensor of reference values.
    - value (float or 0-dim torch.Tensor): value whose percentile rank is wanted.

    Returns:
    - float: rank in percent (0-100).
    """
    rank = (tensor <= value).sum()  # number of elements not greater than value
    percentile = (rank / len(tensor)) * 100  # Convert rank to percentile
    return percentile.item()


In [None]:
def get_bounding_box(heatmap, percentile=95):
    """
    Finds the bounding box (min_x, min_y, max_x, max_y) of pixels with values >= the given percentile.

    Parameters:
    - heatmap (np.ndarray): 2D numpy array representing the heatmap.
    - percentile (float): Percentile threshold (default: 95).

    Returns:
    - tuple: (min_x, min_y, max_x, max_y) coordinates of the bounding box
             (inclusive pixel indices). Returns the sentinel (-1, -1, -1, -1)
             when no pixel meets the threshold, e.g. when the heatmap contains
             NaN so that every comparison is False.
    """

    # Compute the threshold value for the given percentile
    threshold = np.percentile(heatmap, percentile)

    # Get coordinates of pixels that meet the threshold
    y_indices, x_indices = np.where(heatmap >= threshold)

    # Both index arrays always have the same length, so one check is enough.
    if len(x_indices) == 0:
        print('Issue found here', flush=True)
        # Same sentinel values as before, now a tuple like the regular return.
        return (-1, -1, -1, -1)

    # Bounding box coordinates
    min_x, max_x = np.min(x_indices), np.max(x_indices)
    min_y, max_y = np.min(y_indices), np.max(y_indices)

    return min_x, min_y, max_x, max_y

def get_all_img_boxes(
    model,
    img_idx,
    ori_path,
    target_proto_idx,
    eval_dataloader,
    act_percentile,
    act_val,
    eval_dataloader_no_aug,
    img_size,
    device,
    all_latent_space_size_activation_maps,
    proto_rank=-1
):
    """Build one box record for a single (image, prototype) pair.

    The prototype's activation map for the image is min-max normalised,
    bilinearly upsampled to ``img_size`` x ``img_size`` and reduced to the
    bounding box of its 95th-percentile region via ``get_bounding_box``.

    ``model``, ``eval_dataloader``, ``eval_dataloader_no_aug`` and ``device``
    are accepted but not used by this function.

    Returns:
    - list: [img_idx, ori_path, target_proto_idx, proto_rank, act_percentile,
      act_val, left_x, upper_y, right_x, lower_y].
    """
    # Select the map of the requested prototype and add batch/channel axes.
    image_maps = all_latent_space_size_activation_maps[img_idx]
    act_map = image_maps[target_proto_idx, :, :][None, None]

    # Min-max normalisation (a constant map divides by zero here).
    lowest = act_map.min()
    act_map_normed = (act_map - lowest) / (act_map.max() - lowest)

    upsampled = torch.nn.functional.interpolate(
        act_map_normed,
        size=(img_size, img_size),
        mode="bilinear",
        align_corners=False,
    )
    heatmap = upsampled.squeeze(1).clone().detach().cpu().numpy().squeeze(0)

    left_x, upper_y, right_x, lower_y = get_bounding_box(heatmap, percentile=95)

    return [
        img_idx,
        ori_path,
        target_proto_idx,
        proto_rank,
        act_percentile,
        act_val,
        left_x,
        upper_y,
        right_x,
        lower_y,
    ]

def process_model(model_path):
    """Compute and cache per-image prototype bounding boxes for one saved model.

    The model id is built from the last two '/'-separated components of
    ``model_path`` (joined with '_', final four characters dropped); all
    artefacts go to ``protopnext_userstudygui/image_folderv5/<model_id>``.
    If ``all_boxes.npy`` already exists there, the model is skipped.

    Otherwise the model is loaded, passed through ``reproject_prototypes`` and
    ``prune_prototypes``, prototype images are written when
    ``prototypes/patch_info_dict.json`` is missing, the model is run over
    ``val_loader`` and, for every image, one box row is recorded per prototype
    in descending similarity order, stopping after the first prototype whose
    activation percentile is below 90.

    Relies on notebook globals: ``push_loader``, ``split_dataloader``,
    ``img_size``, ``val_loader`` and ``val_no_aug_loader``. The device is
    hard-coded to 'cuda'.

    Parameters:
    - model_path (str): '/'-separated path to a checkpoint readable by ``torch.load``.

    Returns:
    - None. The result is written to ``all_boxes.npy`` in the save folder.
    """
    
    model_id = '_'.join(model_path.split('/')[-2:])[:-4]
    save_folder = Path(f'protopnext_userstudygui/image_folderv5/{model_id}')
    print(f"Processing {model_id}, at {save_folder}", flush=True)
    os.makedirs(str(save_folder), exist_ok=True)
    
    # Cached result present: report the file via `stat` (os.system returns the
    # command's exit status, which is what gets printed) and stop.
    if os.path.exists(str(save_folder / 'all_boxes.npy')):
        print("Skipping")
        print(os.system(f"stat {str(save_folder / 'all_boxes.npy')}"))
        return
    
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    model = torch.load(model_path)
    model = model.to('cuda')
    
    # NOTE(review): reproject_prototypes / prune_prototypes are defined outside
    # this notebook; their exact semantics are not visible here -- confirm.
    model = reproject_prototypes(model, save_folder, push_loader, device='cuda')
    model.prune_prototypes()
    
    # Only render prototype images if a previous run has not produced them.
    if not os.path.exists(str(save_folder / 'prototypes/patch_info_dict.json')):
    
        save_prototype_images_to_file(
            model=model,
            std=split_dataloader.std,
            mean=split_dataloader.mean,
            push_dataloader=push_loader,
            save_loc=save_folder,
            img_size=(img_size, img_size),
            device="cuda"
        )
    
    device = 'cuda'
    model = model.to(device)
    
    model.eval()
    
    # Forward pass over the validation loader, collecting per-batch outputs on CPU.
    with torch.no_grad():
        similarity_score_to_each_prototypes = []
        all_latent_space_size_activation_maps = []
        all_paths = []
        for batch_data_dict in tqdm(val_loader, total=len(val_loader), leave=False, desc="Extract feats"):
            batch_images = batch_data_dict["img"].to(device)
            batch_paths = batch_data_dict["path"]
            model_outputs = model(
                batch_images,
                return_prototype_layer_output_dict=True,
                return_similarity_score_to_each_prototype=True,
            )
            # batch size, num_protos, latent_height, latent_width
            latent_space_size_activation_maps = model_outputs["prototype_activations"]

            # batch size, num_protos
            similarity_score_to_each_prototype = model_outputs[
                "similarity_score_to_each_prototype"
            ]

            for p in batch_paths:
                all_paths.append(p)

            all_latent_space_size_activation_maps.append(latent_space_size_activation_maps.detach().cpu())
            similarity_score_to_each_prototypes.append(similarity_score_to_each_prototype.detach().cpu())
            
            # Drop references to the GPU tensors of this batch before the next one.
            del batch_images, model_outputs

    # Stack the per-batch results along the image axis.
    similarity_score_to_each_prototypes = torch.cat(similarity_score_to_each_prototypes, dim=0)
    all_latent_space_size_activation_maps = torch.cat(all_latent_space_size_activation_maps, dim=0)
    print(similarity_score_to_each_prototypes.shape, all_latent_space_size_activation_maps.shape)
    
    # For every image: prototype indices ordered from most to least similar.
    proto_sorted_idxs = torch.argsort(similarity_score_to_each_prototypes, dim=1, descending=True)
    all_boxes = []
    for img_idx, proto_idxs in tqdm(enumerate(proto_sorted_idxs), total=len(proto_sorted_idxs)):

        found = False
        for qc, proto_idx in enumerate(proto_idxs):
            # `found` is set one iteration before the break, so the first
            # prototype below the 90th percentile is still recorded.
            if found:
                break
            act_val = similarity_score_to_each_prototypes[img_idx][proto_idx]
            # Rank of this activation among all images' scores for the same prototype.
            act_perc = percentile(similarity_score_to_each_prototypes[:, proto_idx], act_val)

            if act_perc < 90:
                found = True


            all_boxes.append(get_all_img_boxes(
                model,
                img_idx=img_idx,
                ori_path=all_paths[img_idx],
                proto_rank=qc,
                target_proto_idx=proto_idx,
                eval_dataloader=val_loader,
                act_percentile=act_perc,
                act_val=act_val.item(),
                eval_dataloader_no_aug=val_no_aug_loader,
                img_size=224,
                device='cuda',
                all_latent_space_size_activation_maps=all_latent_space_size_activation_maps
            ))
    
    # NOTE(review): rows mix paths, tensors and numbers; downstream code parses
    # the saved fields as strings -- confirm the on-disk format before changing.
    np.save(save_folder / 'all_boxes.npy', all_boxes)
    
    # Release the model and the large tensors before the next model is processed.
    del model, proto_sorted_idxs, all_boxes, similarity_score_to_each_prototypes, all_latent_space_size_activation_maps
    torch.cuda.empty_cache()
    
def _scan_finished_pairs(models_info):
    """Classify every row of ``models_info`` by which of its two models are done.

    Returns three lists of ``[row index, backbone, cos path, l2 path]`` entries:
    rows with both models done, rows with only the cosine model done, and rows
    with only the L2 model done.
    """
    both_done, cos_only, l2_only = [], [], []
    for pair_idx in range(len(models_info)):
        pair_backbone = models_info['backbone'].values[pair_idx]
        path_cos = models_info['best_model_cos'].values[pair_idx]
        path_l2 = models_info['best_model_l2'].values[pair_idx]

        cos_ready = check_done(path_cos)
        l2_ready = check_done(path_l2)

        entry = [pair_idx, pair_backbone, path_cos, path_l2]
        if cos_ready and l2_ready:
            both_done.append(entry)
        elif cos_ready:
            cos_only.append(entry)
        elif l2_ready:
            l2_only.append(entry)
    return both_done, cos_only, l2_only


# Scan once up front so the three lists exist even if there is nothing to process.
finished_rows_info, finished_cos_only, finished_l2_only = _scan_finished_pairs(cropped_models_infos)

for model_path, row_idx, backbone, dist_metric in tqdm(all_model_paths):
    print(model_path, row_idx, backbone, dist_metric)
    process_model(model_path)

    # Progress report after every model. The scan lives in a helper so it no
    # longer rebinds this loop's own `row_idx` / `backbone` variables.
    finished_rows_info, finished_cos_only, finished_l2_only = _scan_finished_pairs(cropped_models_infos)
    print(len(finished_rows_info), len(finished_cos_only), len(finished_l2_only))

In [None]:
# Collect, per backbone, the set of row indices whose two models are both done.
backbone_counts = {}
for row_idx, backbone, model_path_cos, model_path_l2 in finished_rows_info:
    backbone_counts.setdefault(backbone, set()).add(row_idx)

In [None]:
# Finished-pair count per backbone. `.get(name, ())` reports 0 for a backbone
# with no finished pair yet instead of raising KeyError.
tuple(
    len(backbone_counts.get(name, ()))
    for name in ('densenet161', 'vgg19', 'resnet50[pretraining=inaturalist]')
)

# process boxes

In [None]:
def _as_box_array(boxes):
    """Coerce array-like boxes to a 2D float array; empty input becomes shape (0, 4)."""
    arr = np.asarray(boxes, dtype=float)
    if arr.size == 0:
        return arr.reshape(0, 4)
    if arr.ndim == 1:
        # A single box given as a flat sequence.
        return arr[None, :]
    return arr


def compute_iou_vectorized(boxes1, boxes2, eps=1e-6):
    """
    Vectorized IoU computation for two sets of boxes in format: [x1, y1, x2, y2].
    boxes1: shape (N, 4), boxes2: shape (M, 4). Any array-like is accepted
    (nested lists, a single flat box, or an empty sequence).
    eps is added to the union so that zero-area pairs yield 0 instead of a
    division by zero.
    Returns: IoU matrix of shape (N, M).
    """
    boxes1 = _as_box_array(boxes1)
    boxes2 = _as_box_array(boxes2)

    # Corners of the pairwise intersection rectangles.
    x1 = np.maximum(boxes1[:, 0][:, None], boxes2[:, 0][None, :])
    y1 = np.maximum(boxes1[:, 1][:, None], boxes2[:, 1][None, :])
    x2 = np.minimum(boxes1[:, 2][:, None], boxes2[:, 2][None, :])
    y2 = np.minimum(boxes1[:, 3][:, None], boxes2[:, 3][None, :])

    # Negative extents mean the boxes do not overlap along that axis.
    inter_width = np.maximum(0, x2 - x1)
    inter_height = np.maximum(0, y2 - y1)
    intersection = inter_width * inter_height

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    union = area1[:, None] + area2[None, :] - intersection
    iou = intersection / (union + eps)
    return iou

def scale_bounding_box(avg_left_x, avg_upper_y, avg_right_x, avg_lower_y, H, W, base_size=224.0):
    """Rescale a box from a ``base_size`` x ``base_size`` grid to an H x W image.

    Parameters:
    - avg_left_x, avg_upper_y, avg_right_x, avg_lower_y (float): box corners in
      the ``base_size`` coordinate system.
    - H, W (int): height and width of the target image.
    - base_size (float): side length of the square source grid (default: 224.0).

    Returns:
    - tuple: (left_x, upper_y, right_x, lower_y) in target pixels, truncated to int.
    """
    scale_x = W / base_size
    scale_y = H / base_size

    left_x = avg_left_x * scale_x
    upper_y = avg_upper_y * scale_y
    right_x = avg_right_x * scale_x
    lower_y = avg_lower_y * scale_y

    return int(left_x), int(upper_y), int(right_x), int(lower_y)


def plot_scaled_bounding_boxes(image, avg_left_x, avg_upper_y, avg_right_x, avg_lower_y):
    """Show a box on the full-size image and on its 224x224 downscaled copy.

    The box is given in 224x224 coordinates. It is rescaled with
    ``scale_bounding_box`` and drawn on a copy of ``image``; that copy is then
    resized to 224x224 and the unscaled box is drawn on top of it. Both
    results are displayed side by side.

    ``image`` must be a 3-dimensional array (H, W, channels); it is converted
    with ``cv2.COLOR_BGR2RGB`` for display.
    NOTE(review): ``cv2`` must already be in scope; this notebook has no
    explicit import for it.
    """
    H, W, _ = image.shape

    scaled_corners = scale_bounding_box(
        avg_left_x, avg_upper_y, avg_right_x, avg_lower_y, H, W
    )

    # Full-resolution panel: draw on a copy so the caller's image is untouched.
    image_upscaled = image.copy()
    cv2.rectangle(image_upscaled, scaled_corners[:2], scaled_corners[2:], (255, 0, 0), 2)

    # Downscaled panel: resized from the annotated copy, then the raw box on top.
    image_downscaled = cv2.resize(image_upscaled, (224, 224))
    raw_corners = (int(avg_left_x), int(avg_upper_y), int(avg_right_x), int(avg_lower_y))
    cv2.rectangle(image_downscaled, raw_corners[:2], raw_corners[2:], (0, 255, 0), 2)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (image_upscaled, f'Upscaled Image ({H}x{W})'),
        (image_downscaled, 'Downscaled Image (224x224)'),
    )
    for ax, (panel_image, panel_title) in zip(axes, panels):
        ax.imshow(cv2.cvtColor(panel_image, cv2.COLOR_BGR2RGB))
        ax.set_title(panel_title)

    for ax in axes:
        ax.axis('off')

    plt.show()

In [None]:
print('Start')

def process_packet(all_boxes, patch_loc):
    """Group box rows by image and pair each with its prototype's patch box.

    Parameters:
    - all_boxes (iterable): rows of ten fields
      (img_idx, ref image path, prototype index, prototype rank, activation
      percentile, activation value, left_x, upper_y, right_x, lower_y).
      The prototype index may be a plain number or wrapped in parentheses,
      e.g. 'tensor(5)'; only the text between the last '(' and the following
      ')' is used.
    - patch_loc (dict): prototype index (str) -> 5-item sequence unpacked as
      (lower_y, upper_y, right_x, left_x, prototype image path).

    Returns:
    - dict: img_idx -> list of packets with keys 'img_idx', 'ref_path',
      'proto_idx', 'proto_path', 'ref_box', 'proto_box', 'proto_rank',
      'act_percentile' and 'act_val'. Boxes are [left_x, upper_y, right_x,
      lower_y] with each axis sorted so that min <= max.
    """
    box_info_dict = dict()
    for row in tqdm(all_boxes):

        # get this image's box info
        (
            img_idx,
            ref_img_path,
            target_proto_idx,
            proto_rank,
            act_percentile,
            act_val,
            ref_left_x,
            ref_upper_y,
            ref_right_x,
            ref_lower_y,
        ) = row

        # get proto info
        proto_idx = str(target_proto_idx).split('(')[-1].split(')')[0]
        proto_lower_y, proto_upper_y, proto_right_x, proto_left_x, proto_img_path = patch_loc[proto_idx]

        # Sort each axis so the boxes are well-formed regardless of the order
        # in which the two coordinates of an axis were stored.
        part_left_x, part_right_x = sorted([int(ref_left_x), int(ref_right_x)])
        part_upper_y, part_lower_y = sorted([int(ref_upper_y), int(ref_lower_y)])

        proto_left_x, proto_right_x = sorted([proto_left_x, proto_right_x])
        proto_upper_y, proto_lower_y = sorted([proto_upper_y, proto_lower_y])

        # assemble packet
        packet = {
            'img_idx': img_idx,
            # Fixed: this used to reference an undefined name (`ori_path`).
            'ref_path': ref_img_path,
            'proto_idx': proto_idx,
            'proto_path': proto_img_path,
            'ref_box': [part_left_x, part_upper_y, part_right_x, part_lower_y],
            'proto_box': [proto_left_x, proto_upper_y, proto_right_x, proto_lower_y],
            'proto_rank': proto_rank,
            'act_percentile': act_percentile,
            'act_val': act_val,
        }

        box_info_dict.setdefault(img_idx, []).append(packet)

    return box_info_dict


# Number of finished model pairs to assemble before stopping. The loop used to
# end with an unconditional `break` after the first finished pair; 1 keeps
# that behaviour. Set to None to process every finished pair.
max_pairs = 1

all_info_dict = dict()
for row_idx in tqdm(range(len(cropped_models_infos))):

    backbone = cropped_models_infos['backbone'].values[row_idx]
    model_path1, model_path2 = cropped_models_infos['best_model_cos'][row_idx], cropped_models_infos['best_model_l2'][row_idx]
    model_id1 = '_'.join(model_path1.split('/')[-2:])[:-4]
    save_folder1 = Path(f'protopnext_userstudygui/image_folderv5/{model_id1}')
    model_id2 = '_'.join(model_path2.split('/')[-2:])[:-4]
    save_folder2 = Path(f'protopnext_userstudygui/image_folderv5/{model_id2}')
    print(save_folder1, save_folder2)

    done_cos = check_done(model_path1)
    done_l2 = check_done(model_path2)

    # Only pairs for which both the cosine and the L2 model have boxes are usable.
    if not (done_cos and done_l2):
        continue

    all_boxes1 = np.load(save_folder1 / 'all_boxes.npy')
    all_boxes2 = np.load(save_folder2 / 'all_boxes.npy')

    # The `with` block closes the file; no explicit close() is needed.
    with open(f"{save_folder1}/prototypes/patch_info_dict.json", 'r') as f:
        patch_loc1 = json.load(f)

    with open(f"{save_folder2}/prototypes/patch_info_dict.json", 'r') as f:
        patch_loc2 = json.load(f)

    # One entry per model pair. The key used to interpolate `img_idx`, which is
    # not defined at this level (NameError) and has no per-pair meaning, so
    # that segment is dropped.
    ref_id = f"row{row_idx}_{model_id1}_{model_id2}_{backbone}"
    print(ref_id)
    cos_packet = process_packet(all_boxes1, patch_loc1)

    l2_packet = process_packet(all_boxes2, patch_loc2)

    all_info_dict[ref_id] = dict()
    all_info_dict[ref_id]['cos'] = cos_packet
    all_info_dict[ref_id]['l2'] = l2_packet

    if max_pairs is not None and len(all_info_dict) >= max_pairs:
        break
        

In [None]:
# Inspect the 'cos' packets stored under key '0' for the first assembled entry.
all_info_dict[list(all_info_dict.keys())[0]]['cos']['0']

In [None]:
# Inspect the 'cos' packets stored under key '1' for the first assembled entry.
all_info_dict[list(all_info_dict.keys())[0]]['cos']['1']

In [None]:
# Serialise before opening the file: opening with 'w' truncates it, so a
# serialisation error raised mid-dump would otherwise leave a partial (or
# empty) results file in place of a previous good one.
results_json = json.dumps(all_info_dict)
with open('results_v5.json', 'w') as f:
    f.write(results_json)