In [1]:
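# Environment setup -- run these once in a shell, not inside the notebook:
# conda create -n trident python=3.10
# conda activate trident
# cd trident
# pip install -e .
# pip install ipynbname
# pip install scikit-learn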

In [2]:
import ipynbname
import os

# Name of the currently running notebook, as reported by ipynbname.
this_notebook_name = ipynbname.name()

# Output directory for extracted features, relative to the working directory.
feats_save_dir = "feats_h5"

# exist_ok=True makes the cell idempotent on re-run and removes the
# check-then-create race of `if not exists: mkdir`.
os.makedirs(feats_save_dir, exist_ok=True)

import json
import h5py
import logging

from PIL import Image, ImageFile, PngImagePlugin
# Relax Pillow's built-in safety limits for the images opened by this notebook.
Image.MAX_IMAGE_PIXELS = None  # no upper bound on pixel count
# Same 100MB ceiling for a single PNG text chunk and for the total text memory.
PngImagePlugin.MAX_TEXT_CHUNK = PngImagePlugin.MAX_TEXT_MEMORY = 100 * 1024 * 1024
ImageFile.LOAD_TRUNCATED_IMAGES = True  # accept truncated image files

import random
from pathlib import Path

from os.path import join as pjoin

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import torch.multiprocessing

from tqdm import tqdm

import cv2
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

import timm
from trident.patch_encoder_models import encoder_factory

from plot import draw_attention_mask
from helpers import CustomDataset, custom_extract_patch_features_from_dataloader

# Root of the dataset split; each split folder holds a positive/ and negative/ subfolder.
dataDir = Path("/Z/cuhk_data/HPACG_dataHPACG_split")

# Per-class folders of the training split, then of the test split.
train_positive_dir, train_negative_dir = dataDir / "train/positive", dataDir / "train/negative"
test_positive_dir, test_negative_dir = dataDir / "test/positive", dataDir / "test/negative"

def get_random_img_path(img_dir=None):
    """Pick one random image file and return its full path.

    Args:
        img_dir: Directory to sample from (``str`` or ``Path``). When omitted,
            ``train_positive_dir`` is used, which is the folder this helper
            has always sampled from.

    Returns:
        pathlib.Path: Full path of one randomly chosen regular file in
        ``img_dir``.

    Raises:
        IndexError: If ``img_dir`` contains no regular files. This is the same
            exception type ``random.choice`` raises on an empty sequence, but
            with a message that names the directory.
    """
    img_dir = Path(train_positive_dir if img_dir is None else img_dir)

    # os.listdir order is arbitrary, so sort to make the draw reproducible
    # under a fixed `random` seed; keep regular files only so that a
    # sub-directory is never returned as an "image".
    candidates = sorted(
        name for name in os.listdir(img_dir) if (img_dir / name).is_file()
    )
    if not candidates:
        raise IndexError(f"no files found in {img_dir}")

    return img_dir / random.choice(candidates)

# Fail fast when the dataset root is missing. FileNotFoundError is a subclass
# of Exception, so any existing `except Exception` handler still matches, and
# the message now names the path that was looked up.
if not os.path.exists(dataDir):
    raise FileNotFoundError(f"dataDir does not exist: {dataDir}")

In [3]:
def get_model():
    """Create the 'conch_v15' patch encoder and its evaluation preprocessing.

    Returns:
        tuple: ``(model, transform)``. ``model`` is the object returned by
        ``encoder_factory(model_name='conch_v15')``. ``transform`` is a
        torchvision ``Compose`` that applies, in order: bicubic ``Resize(448)``,
        ``CenterCrop(448)``, conversion to RGB when the image is in another
        mode, ``ToTensor`` and ``Normalize`` with the mean/std listed below.
    """
    encoder = encoder_factory(model_name='conch_v15')

    side = 448  # size passed to both Resize and CenterCrop

    def _ensure_rgb(img):
        # Only convert when the image is not already RGB.
        if img.mode != "RGB":
            return img.convert("RGB")
        return img

    preprocessing = transforms.Compose(
        [
            transforms.Resize(side, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(side),
            transforms.Lambda(_ensure_rgb),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    return encoder, preprocessing

In [None]:
from os.path import join as j_
from UNI.uni.downstream.extract_patch_features import extract_patch_features_from_dataloader
from UNI.uni.downstream.eval_patch_features.linear_probe import eval_linear_probe

# Encoder plus the preprocessing that both datasets apply to each image.
model, trnsfrms_val = get_model()

# Multi-GPU wrapping is left disabled:
# if torch.cuda.device_count() > 1:
#   print("Let's use", torch.cuda.device_count(), "GPUs!")
#   model = nn.DataParallel(model)

# Keep the encoder on the CPU and switch it to evaluation mode.
model.to('cpu')
model.eval()

# One dataset per split folder, both using the same transform.
train_dataset, test_dataset = (
    CustomDataset(dataDir / split, transform=trnsfrms_val)
    for split in ('train', 'test')
)

# Identical loader settings for both splits; shuffle=False keeps sample order fixed.
train_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=8)
    for dataset in (train_dataset, test_dataset)
)



In [None]:
# Run the encoder over every batch of each split. The results are indexed with
# the 'embeddings' and 'labels' keys by the tensor-conversion cell that follows.
# The two calls are kept as separate statements so that train_features stays
# bound even if the test-split extraction fails part-way.
train_features = custom_extract_patch_features_from_dataloader(model, train_dataloader)
test_features = custom_extract_patch_features_from_dataloader(model, test_dataloader)

In [None]:
# Embeddings: torch.Tensor(...) builds a tensor of the default tensor type.
train_feats = torch.Tensor(train_features['embeddings'])
test_feats = torch.Tensor(test_features['embeddings'])

# Labels: same constructor, then cast to int64 class indices.
train_labels = torch.Tensor(train_features['labels']).long()
test_labels = torch.Tensor(test_features['labels']).long()

In [None]:
from UNI.uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics

# Linear probe on the extracted features: fit on train, evaluate on test.
linprobe_eval_metrics, linprobe_dump = eval_linear_probe(
    train_feats=train_feats, train_labels=train_labels,
    valid_feats=None, valid_labels=None,  # no validation split is passed
    test_feats=test_feats, test_labels=test_labels,
    max_iter=1000, verbose=True,
)

print_metrics(linprobe_eval_metrics)

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# sklearn's confusion_matrix takes (y_true, y_pred): rows are the true class,
# columns the predicted class. Ground truth is 'targets_all' and the model
# output is 'preds_all'; swapping them transposes the matrix and silently
# exchanges FP with FN in every metric below.
y_true = linprobe_dump['targets_all']
y_pred = linprobe_dump['preds_all']
cm = confusion_matrix(y_true, y_pred)

# The 4-way unpacking below only makes sense for a binary problem.
if cm.shape != (2, 2):
    raise ValueError(f"expected a 2x2 confusion matrix, got shape {cm.shape}")

# NOTE(review): assumes the lower label value is 'non-hpacg' and the higher one
# is 'hpacg' -- confirm against the label mapping used by CustomDataset.
labels = ['non-hpacg', 'hpacg']
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(cmap='Blues')

# Row-major layout of the 2x2 matrix is [[TN, FP], [FN, TP]].
# Unpack into plain names only: the previous chained assignment also rebound
# the imported `confusion_matrix` function to this array, which broke the cell
# on any re-run.
TN, FP, FN, TP = cm.ravel()

# Every ratio falls back to 0 when its denominator is 0.
tpr = TP / (TP + FN) if (TP + FN) > 0 else 0
fpr = FP / (FP + TN) if (FP + TN) > 0 else 0
specificity = TN / (TN + FP) if (FP + TN) > 0 else 0
Balanced_Accuracy = (specificity + tpr) / 2
Precision = TP / (TP + FP) if (TP + FP) > 0 else 0
# F1 of the positive class (harmonic mean of precision and recall). The
# variable name is kept so later cells that reference it keep working.
Weight_F1 = 2 * Precision * tpr / (Precision + tpr) if (Precision + tpr) > 0 else 0

print(f"True Positive Rate (Sensitivity,recall): [{tpr:.3f}]")
print(f"False Positive Rate : [{fpr:.3f}]")
print(f"Specificity : [{specificity:.3f}]")
print(f"Balanced Accuracy : [{Balanced_Accuracy:.3f}]")
print(f"Precision: [{Precision:.3f}]")
print(f"F1 (positive class): [{Weight_F1:.3f}]")

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Ground-truth labels and the scores produced by the linear probe.
true_labels = linprobe_dump['targets_all']
pred_probs = linprobe_dump['probs_all']

# From here on fpr/tpr are the per-threshold arrays returned by roc_curve; they
# replace the scalar fpr/tpr computed in the confusion-matrix cell.
fpr, tpr, thresholds = roc_curve(true_labels, pred_probs)
roc_auc = auc(fpr, tpr)

plt.figure()
# The legend names the encoder actually used in this notebook (conch_v15);
# it previously read 'UNI'.
plt.plot(fpr, tpr, color='blue', lw=2, label=f'CONCH v1.5 (AUC={roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='red', linestyle='--')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Linear-probe ROC curve (test set)')
plt.legend(loc="lower right")
plt.grid()
plt.show()