In [None]:
# conda create -n trident python=3.10
# cd trident
# pip install -e .
# pip install ipynbname
# pip install scikit-learn

In [1]:
import ipynbname
import os

# Name of this notebook; used to build the default HDF5 feature-file name.
this_notebook_name = ipynbname.name()
# Directory that collects extracted features.
feats_save_dir = "feats_h5"
# exist_ok makes this idempotent on re-runs, with no separate (racy) existence check.
os.makedirs(feats_save_dir, exist_ok=True)

import json
import h5py
import logging

from PIL import Image, ImageFile, PngImagePlugin
# Very large images are expected, so PIL's decompression-bomb pixel limit is disabled.
# This is only safe because the inputs are local, trusted files.
Image.MAX_IMAGE_PIXELS = None
# Raise the PNG text-chunk limits (per chunk and in total) to 100 MiB each.
PngImagePlugin.MAX_TEXT_CHUNK = PngImagePlugin.MAX_TEXT_MEMORY = 100 * 1024 * 1024
# Decode truncated files instead of raising. Corrupt images load silently, partially blank.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import random
from pathlib import Path

from os.path import join as pjoin

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import torch.multiprocessing

from tqdm import tqdm

import cv2
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

import timm
from trident.patch_encoder_models import encoder_factory

# Root of the dataset split. The layout is {train,test}/{positive,negative}/<images>,
# and ImageFolder below uses the class sub-folders as labels.
# Set HPACG_DATA_DIR to point elsewhere instead of editing this cell.
# NOTE(review): "HPACG_dataHPACG_split" looks like it may be missing a "/" — confirm on disk.
dataDir = Path(
    os.environ.get("HPACG_DATA_DIR", "/media/toom/New Volume/cuhk_data/HPACG_dataHPACG_split")
)
train_positive_dir = dataDir / "train/positive"
train_negative_dir = dataDir / "train/negative"

test_positive_dir = dataDir / "test/positive"
test_negative_dir = dataDir / "test/negative"

def get_random_img_path(directory=None):
    """Return the path of a randomly chosen file inside `directory`.

    Args:
        directory: folder to sample from (str or Path). Defaults to
            `train_positive_dir`, the folder this helper has always sampled.

    Returns:
        pathlib.Path of the chosen file.

    Raises:
        FileNotFoundError: if `directory` contains no regular files.
    """
    directory = train_positive_dir if directory is None else Path(directory)
    # Only regular files are candidates: os.listdir would also return sub-directories,
    # which Image.open cannot read. Sorting makes the pick reproducible under a fixed
    # random seed, because directory listing order is arbitrary.
    candidates = sorted(entry.name for entry in directory.iterdir() if entry.is_file())
    if not candidates:
        raise FileNotFoundError(f"no files to sample in {directory}")
    return directory / random.choice(candidates)

In [2]:
def _conch_v15_eval_transform(img_resize: int = 448):
    """Build the evaluation preprocessing used with the CONCH v1.5 encoder.

    Steps:
    - bicubic resize to `img_resize`;
    - center crop to `img_resize` x `img_resize`;
    - force RGB;
    - convert to tensor;
    - normalize with the (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) mean/std.
    """
    def _ensure_rgb(img):
        return img if img.mode == "RGB" else img.convert("RGB")

    bicubic = transforms.InterpolationMode.BICUBIC
    steps = [
        transforms.Resize(img_resize, interpolation=bicubic),
        transforms.CenterCrop(img_resize),
        transforms.Lambda(_ensure_rgb),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)


def get_model():
    """Load the trident `conch_v15` patch encoder together with its eval transform.

    Returns:
        (model, transform): the encoder from `encoder_factory(model_name='conch_v15')`
        and the torchvision Compose built by `_conch_v15_eval_transform()` (448 px).
    """
    encoder = encoder_factory(model_name='conch_v15')
    return encoder, _conch_v15_eval_transform()

def _model_view(pil_img, transform):
    """Apply only the Resize/CenterCrop steps of `transform` to `pil_img`.

    The heatmap covers exactly the region the model saw. The image it is overlaid
    on must therefore go through the same geometric preprocessing, without
    ToTensor/Normalize. A transform without a `.transforms` list, or without such
    steps, leaves the image unchanged.
    """
    geometric = [
        step for step in getattr(transform, "transforms", [])
        if isinstance(step, (transforms.Resize, transforms.CenterCrop))
    ]
    return transforms.Compose(geometric)(pil_img)


def _token_norm_heatmap(token_feats, gamma=4, sigma=1.5):
    """Reduce a (1, 1 + side*side, C) token tensor to a smoothed (side, side) map.

    The first token is dropped (presumably the CLS token). Each remaining token is
    scored by its L2 norm. Scores are min-max normalized to [0, 1], sharpened with
    `** gamma` and Gaussian-smoothed with `sigma`.

    Raises:
        ValueError: if the remaining tokens do not form a square grid.
    """
    patch_tokens = token_feats[0, 1:, :]
    n_tokens = patch_tokens.shape[0]
    side = int(round(n_tokens ** 0.5))
    if side * side != n_tokens:
        raise ValueError(f"expected a square grid of patch tokens, got {n_tokens}")
    scores = torch.norm(patch_tokens, dim=-1).reshape(side, side).numpy()
    lo, hi = scores.min(), scores.max()
    # Guard the normalization against a constant map (hi == lo would give 0/0 = NaN).
    scores = (scores - lo) / (hi - lo) if hi > lo else np.zeros_like(scores)
    return scipy.ndimage.gaussian_filter(scores ** gamma, sigma=sigma)


def run_model_and_generate_attention_mask(img_path, model=None, transform=None):
    """Show an image, a red token-norm heatmap from the encoder's last block, and both overlaid.

    Args:
        img_path: path of the image to visualize.
        model: optional preloaded encoder from `get_model()`. When omitted, one is
            loaded and placed on CPU. Passing it avoids reloading the weights on
            every call. The model is switched to eval mode.
        transform: optional preprocessing matching `model`. Defaults to the
            transform returned by `get_model()`.

    NOTE(review): the hook captures the *output* of `blocks[-1].attn`. For a timm
    `Attention` module this is presumably the projected token features, not the
    softmax attention matrix, so the map shows per-token output norms. Confirm this
    before reading it as attention weights.
    """
    if model is None or transform is None:
        default_model, default_transform = get_model()
        if model is None:
            model = default_model.to('cpu')
        if transform is None:
            transform = default_transform
    model.eval()
    device = next(model.parameters()).device

    # Context manager closes the file handle; convert() returns an independent, loaded image.
    with Image.open(img_path) as src:
        pil_img = src.convert("RGB")
    batch = transform(pil_img).unsqueeze(dim=0).to(device)

    # Only the last block's output is used, so only that block is hooked.
    captured = []
    hook = model.model.trunk.blocks[-1].attn.register_forward_hook(
        lambda _module, _inputs, output: captured.append(output.detach().cpu())
    )
    try:
        # No gradients are needed; inference_mode avoids building the autograd graph.
        with torch.inference_mode():
            model(batch)
    finally:
        # Always detach the hook, even if the forward pass fails, so a reused model stays clean.
        hook.remove()

    heatmap = _token_norm_heatmap(captured[-1])

    # Overlay on the exact region the model saw (resized + center-cropped), at that
    # region's size, so the mask lines up with the pixels it describes.
    view = np.array(_model_view(pil_img, transform))
    H, W = view.shape[:2]
    alpha = cv2.resize(heatmap, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
    overlay = np.zeros((H, W, 4))
    overlay[..., 0] = 1.0  # red; green and blue stay 0
    # Bicubic resizing can overshoot [0, 1]; clip so the alpha channel is valid for imshow.
    overlay[..., 3] = np.clip(alpha, 0.0, 1.0)

    # 1 row x 3 columns: full original input, model view + mask, mask alone.
    panels = [
        ("Input", [np.array(pil_img)]),
        ("Overlay", [view, overlay]),
        ("Attention Mask", [overlay]),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    for ax, (title, layers) in zip(axes, panels):
        for layer in layers:
            ax.imshow(layer)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()  # keep panels from overlapping
    plt.show()

In [None]:
import torchvision.datasets as datasets

class CustomDataset(datasets.ImageFolder):
    """ImageFolder whose items also carry the sample's file name.

    Each item is `(image, label, file_name)`. `file_name` is the basename (no
    directories) of the path ImageFolder recorded for that index.
    """

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        sample_path, _class_idx = self.samples[index]
        return image, label, os.path.basename(sample_path)


@torch.no_grad()
def custom_extract_patch_features_from_dataloader(model, dataloader, save_dir=None):
    """Embed every image from `dataloader` with `model` and append the results to an HDF5 file.

    Modified from uni.downstream.extract_patch_features.extract_patch_features_from_dataloader.
    Instead of returning the embeddings, each image gets its own HDF5 dataset keyed by
    its file name, with the class label in the dataset's "label" attribute. Images
    whose key already exists are skipped, so an interrupted run can be resumed.

    Args:
        model (torch.nn.Module): encoder applied to each image batch. Its output is
            sliced as 2-D (batch, dim) below.
        dataloader (torch.utils.data.DataLoader): yields `(images, labels, file_names)`
            batches, e.g. built on `CustomDataset`.
        save_dir (str | None): path of the HDF5 *file* to write (despite the name).
            Defaults to `<feats_save_dir>/<this_notebook_name>.h5`.

    Returns:
        None. Results are only written to disk.
    """
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Used to pad the final short batch back to full size, mirroring the UNI
    # implementation. A DataLoader built from a batch_sampler reports None here;
    # no padding is done in that case.
    batch_size = dataloader.batch_size
    device = next(model.parameters()).device

    if save_dir is None:
        h5_file_path = os.path.join(feats_save_dir, this_notebook_name + '.h5')
    else:
        h5_file_path = save_dir

    with h5py.File(h5_file_path, 'a') as hf:
        for batch, target, filenames in tqdm(dataloader, total=len(dataloader)):
            # Cheap resume check: skip the batch when its first and last images are stored.
            if filenames[0] in hf and filenames[-1] in hf:
                continue

            remaining = batch.shape[0]
            if batch_size is not None and remaining < batch_size:
                padding = batch.new_zeros((batch_size - remaining,) + tuple(batch.shape[1:]))
                batch = torch.cat([batch, padding])

            with torch.inference_mode():
                embeddings = model(batch.to(device)).detach().cpu()[:remaining, :]
            labels = target.numpy()[:remaining]
            assert not torch.isnan(embeddings).any(), "NaN in extracted embeddings"

            for name, embedding, label in zip(filenames, embeddings, labels):
                # Keys are bare file names. Two images sharing a name (e.g. in different
                # class folders) collide, and only the first one written is kept.
                if name in hf:
                    continue
                dset = hf.create_dataset(name, data=embedding.numpy())
                dset.attrs["label"] = label

In [None]:
from os.path import join as j_
from UNI.uni.downstream.extract_patch_features import extract_patch_features_from_dataloader
from UNI.uni.downstream.eval_patch_features.linear_probe import eval_linear_probe

# Frozen encoder on CPU, in eval mode, for feature extraction.
model, trnsfrms_val = get_model()
model.to('cpu')
model.eval()

# UNI's extract_patch_features_from_dataloader unpacks every batch into exactly
# (images, labels), so it needs plain ImageFolder datasets. CustomDataset also yields
# file names and is meant for custom_extract_patch_features_from_dataloader instead.
train_dataset = torchvision.datasets.ImageFolder(dataDir / 'train', transform=trnsfrms_val)
test_dataset = torchvision.datasets.ImageFolder(dataDir / 'test', transform=trnsfrms_val)

# Do not ask for more worker processes than the machine has cores.
NUM_WORKERS = min(16, os.cpu_count() or 1)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=False, num_workers=NUM_WORKERS)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=NUM_WORKERS)



In [7]:
# custom_extract_patch_features_from_dataloader(model, train_dataloader)

In [None]:
# Embed both splits with the frozen encoder (train first, then test).
train_features, test_features = (
    extract_patch_features_from_dataloader(model, loader)
    for loader in (train_dataloader, test_dataloader)
)

In [None]:
# Convert the extractor's arrays to tensors for the linear probe.
# torch.tensor copies the data and sets the dtype directly. This avoids the float32
# round-trip that torch.Tensor(...).type(torch.long) performs on integer labels.
train_feats = torch.tensor(train_features['embeddings'], dtype=torch.float32)
train_labels = torch.tensor(train_features['labels'], dtype=torch.long)
test_feats = torch.tensor(test_features['embeddings'], dtype=torch.float32)
test_labels = torch.tensor(test_features['labels'], dtype=torch.long)

In [None]:
from UNI.uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics

# Fit UNI's linear probe on the frozen train features and score it on the test split.
# No validation split is used.
probe_inputs = {
    "train_feats": train_feats,
    "train_labels": train_labels,
    "valid_feats": None,
    "valid_labels": None,
    "test_feats": test_feats,
    "test_labels": test_labels,
}
linprobe_eval_metrics, linprobe_dump = eval_linear_probe(**probe_inputs, max_iter=1000, verbose=True)

print_metrics(linprobe_eval_metrics)

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# sklearn expects (y_true, y_pred): rows of `cm` are true classes, columns are predictions.
y_true = linprobe_dump['targets_all']
y_pred = linprobe_dump['preds_all']
# ImageFolder indexes the class folders alphabetically: negative -> 0, positive -> 1.
# labels=[0, 1] keeps `cm` 2x2 even when one class is absent, so the 4-way unpack holds.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

labels = ['non-hpacg', 'hpacg']
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(cmap='Blues')

# For a 2x2 matrix, ravel() yields (tn, fp, fn, tp). Unpack from `cm` directly so the
# sklearn `confusion_matrix` function is not rebound to an array.
TN, FP, FN, TP = cm.ravel()

tpr = TP / (TP + FN) if (TP + FN) > 0 else 0
fpr = FP / (FP + TN) if (FP + TN) > 0 else 0
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
Balanced_Accuracy = (specificity + tpr) / 2
Precision = TP / (TP + FP) if (TP + FP) > 0 else 0
# F1 of the positive class (harmonic mean of precision and recall), not sklearn's
# "weighted" F1. Guarded against 0/0 when precision and recall are both zero.
Weight_F1 = 2 * Precision * tpr / (Precision + tpr) if (Precision + tpr) > 0 else 0

print(f"True Positive Rate (Sensitivity,recall): [{tpr:.3f}]")
print(f"False Positive Rate : [{fpr:.3f}]")
print(f"Specificity : [{specificity:.3f}]")
print(f"Balanced Accuracy : [{Balanced_Accuracy:.3f}]")
print(f"Precision: [{Precision:.3f}]")
print(f"F1 (positive class): [{Weight_F1:.3f}]")

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Legend name for the encoder built by get_model() (encoder_factory('conch_v15')).
MODEL_LABEL = 'CONCH v1.5'

true_labels = linprobe_dump['targets_all']
# Assumed: one positive-class score per sample (1-D) for this binary task — TODO confirm.
pred_probs = linprobe_dump['probs_all']

# Separate names so the scalar `fpr`/`tpr` from the confusion-matrix cell are not overwritten.
roc_fpr, roc_tpr, thresholds = roc_curve(true_labels, pred_probs)
roc_auc = auc(roc_fpr, roc_tpr)

fig, ax = plt.subplots()
ax.plot(roc_fpr, roc_tpr, color='blue', lw=2, label=f'{MODEL_LABEL} (AUC={roc_auc:.3f})')
ax.plot([0, 1], [0, 1], color='red', linestyle='--', label='Chance')
ax.set(
    xlim=(0.0, 1.0),
    ylim=(0.0, 1.05),
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='ROC curve, linear probe on test split',
)
ax.legend(loc="lower right")
ax.grid()
plt.show()