In [None]:
import os
import sys

# Project root containing the `src` package. Set the ULDM_WDIR environment
# variable to point elsewhere instead of editing this hard-coded default.
wdir = os.environ.get("ULDM_WDIR", "/projects/ovcare/users/cindy_shi/ldm/uncond-image-generation-ldm")
# Re-running this cell must not keep appending the same entry to sys.path.
if wdir not in sys.path:
    sys.path.append(wdir)

from src.pipeline import UncondLatentDiffusionPipeline
import torch
from torch.utils.data import DataLoader
from src.data import PathologyValidation, PathologyTest, PathologyLabels

In [2]:
# Checkpoint directory, resolved against the project root `wdir`.
model_id = os.path.join(wdir, "checkpoints/ddpm-model")
external_set = "external_set_1"

# Sampling / evaluation settings consumed by later cells.
num_sampling_steps = 50  # passed to the pipeline as num_inference_steps
noise_timesteps = 350    # passed to the pipeline as noise_timesteps
eval_batch_size = 16     # DataLoader batch size

: 

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
use_cuda = device == "cuda"
dataset = PathologyValidation(size=256, debug=True)
# Page-locked host memory only speeds up host-to-GPU copies; skip it without CUDA.
dataloader = DataLoader(dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4, pin_memory=use_cuda)
# float16 support on CPU is limited in PyTorch, so half precision is only used on CUDA.
pipe_dtype = torch.float16 if use_cuda else torch.float32
pipe = UncondLatentDiffusionPipeline.from_pretrained(model_id, torch_dtype=pipe_dtype, return_dict=True).to(device)

: 

In [None]:
# Inspect the pipeline's VAE component (its repr is shown as the cell output).
pipe.vae

In [13]:
# Number of samples in the validation dataset (built with debug=True in the setup cell).
len(dataset)

5

In [14]:
# some helper functions
from collections import defaultdict
from typing_extensions import Dict

def z_score(scores: torch.Tensor) -> torch.Tensor:
    """Standardise ``scores`` to zero mean and unit (sample) standard deviation.

    The scores are first min-max scaled to [0, 1] and then z-scored over all
    elements. Z-scoring is invariant to positive affine rescaling, so the
    min-max step only matters through the 1e-8 stabilisers. It is kept so
    results match earlier runs.

    Args:
        scores: tensor of scores; all elements are pooled together.

    Returns:
        Tensor with the same shape as ``scores``. With fewer than two elements
        the sample standard deviation is undefined, so zeros are returned
        instead of NaN (one element) or an error (empty tensor).
    """
    if scores.numel() < 2:
        out_dtype = scores.dtype if scores.is_floating_point() else torch.get_default_dtype()
        return torch.zeros_like(scores, dtype=out_dtype)
    scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    mean = scores.mean()
    std = scores.std()
    return (scores - mean) / (std + 1e-8)

def slide_z_scores(scores, paths, method="max", labels=None, sid_idx=None) -> tuple:
    """Aggregate patch-level scores into one score per slide.

    The slide id of a patch is the ``sid_idx``-th component of its
    "/"-separated path.

    Args:
        scores: iterable of per-patch scores (0-d tensors or numbers).
        paths: list of patch paths, aligned with ``scores``.
        method: "max" keeps the highest patch score of each slide. "avg99"
            averages the patch scores at or above the slide's 99th-percentile
            score, taken as the k-th smallest value with k = int(0.99 * n),
            clamped to at least 1.
        labels: optional per-patch labels aligned with ``scores``. For each
            slide, the label of the last patch seen is kept. ``None`` gives
            ``None`` labels.
        sid_idx: index of the path component that holds the slide id.
            Defaults to ``PathologyLabels.SID_IDX``.

    Returns:
        ``(slide_scores, slide_labels)``: dicts keyed by slide id, mapping to
        the aggregated float score and to the slide label respectively.

    Raises:
        ValueError: if ``method`` is not "max" or "avg99".
    """
    if method not in ("max", "avg99"):
        raise ValueError(f"unknown aggregation method {method!r}; expected 'max' or 'avg99'")
    if sid_idx is None:
        sid_idx = PathologyLabels.SID_IDX
    if labels is None:
        # zip() cannot iterate over None, so substitute one placeholder per patch.
        labels = [None] * len(paths)

    per_slide = defaultdict(list)
    slide_labels = {}
    for score, path, label in zip(scores, paths, labels):
        slide_id = path.split("/")[sid_idx]
        per_slide[slide_id].append(float(score))
        slide_labels[slide_id] = label

    if method == "max":
        return {sid: max(vals) for sid, vals in per_slide.items()}, slide_labels

    slide_scores = {}
    for sid, vals in per_slide.items():
        values = torch.tensor(vals)
        # kthvalue is 1-indexed: clamp so a slide with a single patch does not request k=0.
        k = max(1, int(0.99 * values.numel()))
        threshold = values.kthvalue(k).values.item()
        slide_scores[sid] = values[values >= threshold].mean().item()
    return slide_scores, slide_labels

def dict_to_device(batch):
    """Move the tensor entries of ``batch`` to the notebook-global ``device``.

    The dict is updated in place, and non-tensor values are left as they are.
    The same object is returned so the call can be used inline.
    """
    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(device)
    return batch

In [None]:
import lpips
import pytorch_msssim
import torch.nn.functional as F
lpips_model = lpips.LPIPS(net='alex').to(device)
ssim_model = pytorch_msssim.MS_SSIM(data_range=1.0, size_average=True, channel=3).to(device)

# Per-sample anomaly scores collected over the whole dataloader.
# The LPIPS list must not be named `lpips`: that would shadow the module
# imported above, and `lpips.LPIPS` would fail when this cell is re-run.
om_ssims = []      # 1 - MS-SSIM between reconstructed and clean latents
lpips_scores = []  # LPIPS between pipeline images and input images
mses = []          # MSE between reconstructed and clean latents
paths = []
gts = []

for batch in dataloader:
    batch = dict_to_device(batch)
    output = pipe(num_inference_steps=num_sampling_steps, noise_timesteps=noise_timesteps, batch=batch)
    images = output["images"]
    latents = output["latents"]
    clean_latents = output["clean_latents"]

    # Metrics are evaluation-only: don't keep an autograd graph alive in the stored scores.
    with torch.no_grad():
        # The metric models take batched (N, C, H, W) input, so each sample
        # gets a leading batch dim. Inputs are cast to float32 to match the
        # metric weights, because the pipeline may run in float16.
        # NOTE(review): MS-SSIM is configured with channel=3 and data_range=1.0,
        # and needs a spatial size > 160 px with the default window. Confirm the
        # latents satisfy this. LPIPS expects inputs in [-1, 1]; confirm the
        # range of `images` and `batch["image"]`.
        ssim = torch.stack([
            ssim_model(lat.unsqueeze(0).float(), ref.unsqueeze(0).float()).squeeze()
            for lat, ref in zip(latents, clean_latents)
        ])
        lpip = torch.stack([
            lpips_model(img.unsqueeze(0).float(), ref.unsqueeze(0).float()).squeeze()
            for img, ref in zip(images, batch["image"])
        ])
        mse = F.mse_loss(latents.float(), clean_latents.float(), reduction='none').mean(dim=[1, 2, 3])

    om_ssims.append(1.0 - ssim.cpu())
    lpips_scores.append(lpip.cpu())
    mses.append(mse.cpu())
    paths.extend(batch["path"])
    gts.append(batch["label"].cpu())

om_ssims = torch.cat(om_ssims)
lpips_scores = torch.cat(lpips_scores)
mses = torch.cat(mses)
gts = torch.cat(gts)  # concatenate the per-batch label tensors

# z-score normalise each metric so the three contribute on a comparable scale.
om_ssims = z_score(om_ssims)
lpips_scores = z_score(lpips_scores)
mses = z_score(mses)
# Combined per-patch anomaly score: mean of the three normalised metrics.
combined = (om_ssims + lpips_scores + mses) / 3.0

# Dicts of slide id -> aggregated score and slide id -> label.
slide_combined, labels = slide_z_scores(combined, paths, method="max", labels=gts)


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /projects/ovcare/users/cindy_shi/miniconda3/envs/uldm/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth


 28%|██▊       | 14/50 [22:58<59:04, 98.46s/it]  


KeyboardInterrupt: 

In [None]:
# Evaluation: patch-level and slide-level detection metrics.
from sklearn.metrics import roc_auc_score, average_precision_score, balanced_accuracy_score, roc_curve

# TODO: the balanced-accuracy threshold is fixed at the 70th percentile of the scores for now.
# torch.as_tensor accepts both the tensors produced by the scoring cell and the
# ndarrays left by a previous run of this cell. Calling .numpy() directly would
# fail on re-run.
gts = torch.as_tensor(gts).numpy()
combined = torch.as_tensor(combined).numpy()
slide_combined_scores = list(slide_combined.values())
slide_gt = list(labels.values())

def evaluate(gt, scores, level="Image", percentile=0.7):
    """Compute, print and return detection metrics for anomaly ``scores``.

    Args:
        gt: ground-truth labels aligned with ``scores``, as any array-like
            scikit-learn accepts.
        scores: anomaly scores, where higher means more anomalous. May be a
            list, a NumPy array or a tensor.
        level: name used in the printed summary, e.g. "Image" or "Slide".
        percentile: the balanced-accuracy threshold is the k-th smallest score,
            with k = int(percentile * len(scores)) clamped to at least 1.

    Returns:
        ``(roc_auc, pr_auc, fpr95, bacc)``. ``fpr95`` is the false-positive
        rate at the first ROC point whose TPR is >= 0.95.
    """
    # A plain list (e.g. the slide-level scores) has no element-wise `>=`,
    # so normalise every input to a single float64 tensor first.
    scores_t = torch.as_tensor(scores, dtype=torch.float64).flatten()
    scores_np = scores_t.numpy()

    roc_auc = roc_auc_score(gt, scores_np)
    pr_auc = average_precision_score(gt, scores_np)
    # kthvalue is 1-indexed, so clamp k for very short score lists.
    k = max(1, int(percentile * scores_t.numel()))
    thresh = scores_t.kthvalue(k).values.item()
    bacc = balanced_accuracy_score(gt, (scores_t >= thresh).long().numpy())
    fpr, tpr, _ = roc_curve(gt, scores_np)
    fpr95 = fpr[tpr >= 0.95][0]
    print(f"{level}-level - ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}, FPR95: {fpr95:.4f}, BACC ({percentile * 100:.0f}th percentile): {bacc:.4f}")
    return roc_auc, pr_auc, fpr95, bacc

evaluate(gts, combined)
evaluate(slide_gt, slide_combined_scores, level="Slide")