In [12]:
import albumentations as A 
import cv2
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd 
import torch 

from einops import rearrange, reduce
from importlib import import_module
from skp.toolbox.classes import Ensemble 
from skp.toolbox.functions import load_kfold_ensemble_as_list
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [9]:
def calculate_dice_over_thresholds(p, t, thresholds=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
    """Compute the per-channel Dice score of `p` against `t` at several thresholds.

    Args:
        p: prediction scores of shape (c, h, w); binarised as ``p >= threshold``.
        t: target mask of shape (c, h, w). Floating-point, integer and boolean
            masks are all accepted.
        thresholds: iterable of n thresholds applied to `p`.

    Returns:
        A tuple ``(dice, labels)``:
        dice: tensor of shape (n, c) holding
            ``2 * sum(p_bin * t) / (sum(p_bin) + sum(t))`` for each threshold
            and channel. The entry is NaN (0/0) when both the binarised
            prediction and the target are empty for a channel.
        labels: tensor of shape (c,) with the sum of `t` over (h, w) for each
            channel, which lets callers drop channels with no positive target.

    Raises:
        AssertionError: if `p` and `t` do not have the same shape.
    """
    # p.shape = t.shape = (c, h, w)
    assert p.shape == t.shape
    # bool + bool is a logical OR in torch, which would undercount the
    # denominator for boolean masks, so do the arithmetic in a numeric dtype.
    work_dtype = torch.int64 if t.dtype == torch.bool else t.dtype
    # (1, c, h, w): broadcast against the thresholded predictions instead of
    # materialising n copies of the target.
    target = t.to(work_dtype).unsqueeze(0)
    # (n, c, h, w): one binarised prediction per threshold.
    binarised = torch.stack([p >= th for th in thresholds]).to(work_dtype)
    intersection = (binarised * target).sum(dim=(2, 3))
    denominator = (binarised + target).sum(dim=(2, 3))
    dice = (2 * intersection) / denominator
    return dice, t.sum((1, 2))

In [4]:
# Root directory under which the k-fold CSV and the "png" folder are looked up.
# Set the BHSD_DATA_DIR environment variable to point the notebook at another
# location; the original path is used when it is unset.
DATA_DIR = os.environ.get("BHSD_DATA_DIR", "/mnt/stor/datasets/BHSD/")

In [2]:
# Resolve the experiment config module by name and take its `cfg` object.
cfg_name = "ich.cfg_BHSD_segment_pos_only_2dc_focal"
cfg = import_module("skp.configs." + cfg_name).cfg

# Set these four flags to False before the models are built
# (presumably because the weights come from the checkpoints below — confirm).
cfg.pretrained = cfg.freeze_encoder = False
cfg.load_pretrained_encoder = cfg.enable_gradient_checkpointing = False

# One checkpoint path per fold, fold0 through fold4.
weights_paths = [
    f"/home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold{fold_idx}/checkpoints/last.ckpt"
    for fold_idx in range(5)
]

# model_list is indexed by fold number further down the notebook.
model_list = load_kfold_ensemble_as_list(
    cfg,
    weights_paths=weights_paths,
    device="cuda",
    eval_mode=True,
)

Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold0/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold1/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold2/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold3/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold4/checkpoints/last.ckpt ...


In [33]:
# Slice-level table; later cells read its "image", "label" and "fold" columns.
kfold_csv_path = os.path.join(DATA_DIR, "train_positive_slices_png_kfold.csv")
df = pd.read_csv(kfold_csv_path)
# Thresholds 0.05, 0.10, ..., 0.95. Built from an integer range and divided
# afterwards: np.arange with a float step accumulates rounding error and
# produces values such as 0.35000000000000003.
thresholds = np.arange(5, 100, 5) / 100
thresholds

array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,
       0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95])

In [34]:
# Score every slice in `df` with the model selected by its "fold" column and
# collect the per-threshold, per-channel Dice plus per-channel label sums.
MODEL_INPUT_SIZE = 512  # slices are resized to this square size before inference
NUM_MASK_CLASSES = 6    # number of one-hot channels built from the label PNG

dice_list, labels_list = [], []
for row_idx, row in tqdm(df.iterrows(), total=len(df)):
    image_path = os.path.join(DATA_DIR, "png", row["image"])
    mask_path = os.path.join(DATA_DIR, "png", row["label"])
    img = cv2.imread(image_path, cv2.IMREAD_COLOR) # img saved as RGB
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when a file is missing or
    # cannot be decoded; fail here with the offending path rather than with an
    # opaque error a few lines later.
    if img is None:
        raise OSError(f"cv2.imread failed for image: {image_path}")
    if mask is None:
        raise OSError(f"cv2.imread failed for mask: {mask_path}")
    # Remember the original size so predictions can be scored at mask resolution.
    h, w = img.shape[:2]
    img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    # Model is picked by the row's fold index
    # (presumably the fold in which this slice was held out — confirm).
    model = model_list[row.fold]
    img = rearrange(img, "h w c -> 1 c h w")
    img = torch.from_numpy(img).float().cuda()
    with torch.inference_mode():
        y = model({"x": img})
        y = y["logits"].sigmoid()
        if h != MODEL_INPUT_SIZE or w != MODEL_INPUT_SIZE:
            # Resize the probabilities back to the original slice size.
            y = torch.nn.functional.interpolate(y, size=(h, w), mode="bilinear", align_corners=False)
        y = y.cpu()[0]
    mask = torch.from_numpy(mask).long()
    mask = torch.nn.functional.one_hot(mask, NUM_MASK_CLASSES).float()
    mask = rearrange(mask, "h w c -> c h w")
    # Channel 0 of the one-hot marks pixels whose label value is 0; inverting
    # it turns channel 0 into "any non-zero label".
    mask[0] = 1 - mask[0]
    dice, labels = calculate_dice_over_thresholds(y, mask, thresholds)
    dice_list.append(dice)
    labels_list.append(labels)

100%|██████████| 2368/2368 [03:49<00:00, 10.32it/s]


In [35]:
# Per channel, stack the Dice-over-thresholds rows of only those slices whose
# label sum for that channel is positive (i.e. the label is present).
dice_dict = {
    channel: torch.stack(
        [
            slice_dice[:, channel]
            for slice_dice, slice_labels in zip(dice_list, labels_list)
            if slice_labels[channel] > 0
        ]
    )
    for channel in range(6)
}

In [36]:
# Number of slices contributing to each channel's Dice statistics.
for channel in dice_dict:
    n_slices = len(dice_dict[channel])
    print(channel, n_slices)

0 2368
1 181
2 888
3 713
4 976
5 765


In [37]:
# For each channel: the highest mean Dice over slices, and the threshold that
# achieves it.
for channel, per_slice_dice in dice_dict.items():
    mean_dice = per_slice_dice.mean(0)  # average over slices -> one value per threshold
    best_score = mean_dice.amax().item()
    best_index = mean_dice.argmax().item()
    print(channel, f"{best_score:0.4f}", thresholds[best_index])

0 0.5787 0.4
1 0.3848 0.1
2 0.6689 0.35000000000000003
3 0.5693 0.4
4 0.3330 0.35000000000000003
5 0.4135 0.3
