In [1]:
import albumentations as A 
import cv2
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd 
import torch 

from einops import rearrange
from importlib import import_module
from skp.toolbox.classes import Ensemble 
from skp.toolbox.functions import load_kfold_ensemble_as_list
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
class ImageDataset(Dataset):
    """Dataset that stacks the PNG slices listed in ``filepath_2dc`` as channels.

    Each row of ``df`` must provide ``filepath_2dc`` (comma-separated PNG paths,
    relative to ``data_dir/image_subdir``), ``SOPInstanceUID``, and one 0/1 column
    for every name in ``self.bleed_types``.

    Args:
        df: slice-level DataFrame, accessed positionally via ``iloc``.
        data_dir: dataset root directory.
        resizer: albumentations-style transform called as ``resizer(image=arr)["image"]``.
        image_subdir: folder under ``data_dir`` that holds the PNG files.
    """

    def __init__(self, df, data_dir, resizer, image_subdir="stage_2_train_png"):
        self.df = df
        self.data_dir = data_dir
        self.resizer = resizer
        self.image_subdir = image_subdir
        self.bleed_types = ["any", "epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]

    def __len__(self):
        return len(self.df)

    def _read_slice(self, fname):
        """Read one PNG as a 2-D grayscale array; raise if it cannot be read."""
        path = os.path.join(self.data_dir, self.image_subdir, fname)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports missing/corrupt files by returning None rather than raising;
            # without this check np.stack fails later with an unrelated shape error.
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def __getitem__(self, idx):
        """Return the resized (C, H, W) float tensor, labels, and pre-resize (h, w) for row ``idx``."""
        row = self.df.iloc[idx]
        img = np.stack([self._read_slice(f) for f in row.filepath_2dc.split(",")], axis=-1)
        # Keep the native size so predictions can be mapped back to it after inference.
        h, w = img.shape[:2]
        img = self.resizer(image=img)["image"]
        img = rearrange(img, "h w c -> c h w")
        img = torch.from_numpy(img).float()
        return {
            "x": img,
            "SOPInstanceUID": row.SOPInstanceUID,
            "bleed_types": [row[bleed] for bleed in self.bleed_types],
            "h": h,
            "w": w,
        }

In [3]:
# Dataset root. Set the ICH_DATA_DIR environment variable to run on another machine;
# the default keeps the original location.
DATA_DIR = os.environ.get("ICH_DATA_DIR", "/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/")

In [4]:
# Build a k-fold ensemble from one training run. The checkpoint directory is derived
# from cfg_name so the config and the loaded weights cannot drift apart.
cfg_name = "ich.cfg_BHSD_segment_pos_only_2dc_focal"
EXPERIMENT_ROOT = "/home/ian/projects/SKP/experiments/ich"
RUN_ID = "f5ba27ea"
N_FOLDS = 5

cfg = import_module(f"skp.configs.{cfg_name}").cfg
# NOTE(review): these flags presumably disable training-time setup (pretrained init,
# encoder freezing, gradient checkpointing) since weights come from the checkpoints — confirm.
cfg.pretrained = False
cfg.freeze_encoder = False
cfg.load_pretrained_encoder = False
cfg.enable_gradient_checkpointing = False

weights_paths = [
    os.path.join(EXPERIMENT_ROOT, cfg_name, RUN_ID, f"fold{i}", "checkpoints", "last.ckpt")
    for i in range(N_FOLDS)
]
# Fail fast, before loading any model, if a fold's checkpoint is missing.
missing_ckpts = [p for p in weights_paths if not os.path.isfile(p)]
if missing_ckpts:
    raise FileNotFoundError(f"Missing checkpoints: {missing_ckpts}")

model_list = load_kfold_ensemble_as_list(cfg, weights_paths=weights_paths, device="cuda", eval_mode=True)
# activation_fn="sigmoid": the ensemble's outputs are sigmoid-activated, not raw logits.
model = Ensemble(model_list, output_name="logits", activation_fn="sigmoid")

Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold0/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold1/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold2/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold3/checkpoints/last.ckpt ...
Loading weights from /home/ian/projects/SKP/experiments/ich/ich.cfg_BHSD_segment_pos_only_2dc_focal/f5ba27ea/fold4/checkpoints/last.ckpt ...


In [5]:
# Read the slice table and keep only slices labelled with any hemorrhage.
csv_path = os.path.join(DATA_DIR, "train_slices_with_2dc_kfold.csv")
df = pd.read_csv(csv_path)
is_positive = df["any"] == 1
pos_df = df.loc[is_positive]
pos_df.head()

Unnamed: 0,SOPInstanceUID,epidural,intraparenchymal,intraventricular,subarachnoid,subdural,any,filepath,PatientID,StudyInstanceUID,SeriesInstanceUID,filepath_2dc,outer,inner0,inner1,inner2,inner3,inner4,fold
374,ID_47dea86cc,0,0,0,1,1,1,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0005...,ID_4c16e232,ID_c174374b07,ID_002c9733b7,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0004...,4,0,3,2,0,-1,4
375,ID_939425a6b,0,0,0,1,1,1,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0006...,ID_4c16e232,ID_c174374b07,ID_002c9733b7,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0005...,4,0,3,2,0,-1,4
376,ID_bb696a05c,0,0,0,1,1,1,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0007...,ID_4c16e232,ID_c174374b07,ID_002c9733b7,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0006...,4,0,3,2,0,-1,4
377,ID_53be93586,0,0,0,1,1,1,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0008...,ID_4c16e232,ID_c174374b07,ID_002c9733b7,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0007...,4,0,3,2,0,-1,4
378,ID_5cc8ccfb5,0,0,0,1,1,1,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0009...,ID_4c16e232,ID_c174374b07,ID_002c9733b7,ID_4c16e232/ID_c174374b07/ID_002c9733b7/IM0008...,4,0,3,2,0,-1,4


In [6]:
# Output directory for the soft pseudo-label mask PNGs (created if absent).
save_dir = os.path.join(
    DATA_DIR,
    "segmentation_masks_soft_pseudolabels2",
)
os.makedirs(name=save_dir, exist_ok=True)

In [7]:
# Model input resolution; each image is resized to IMAGE_SIZE x IMAGE_SIZE before inference.
IMAGE_SIZE = 512
BATCH_SIZE = 32
# Cap workers at the CPU count so smaller machines are not oversubscribed.
NUM_WORKERS = min(16, mp.cpu_count())

dataset = ImageDataset(
    pos_df.reset_index(drop=True),
    data_dir=DATA_DIR,
    resizer=A.Resize(IMAGE_SIZE, IMAGE_SIZE, p=1),
)
# pin_memory speeds up the host-to-GPU copy done by .cuda() in the inference loop.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [10]:
# Short codes used as filename suffixes for each bleed type.
abbreviate = {
    "any": "any", "epidural": "edh", "intraparenchymal": "iph", "intraventricular": "ivh", "subarachnoid": "sah", "subdural": "sdh"
}

# Same order as ImageDataset.bleed_types, so label column i corresponds to bleed_names[i].
# NOTE(review): also assumes model output channel i is bleed_names[i] — confirm against the training config.
bleed_names = ["any", "epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]


def _to_uint8_mask(prob_map):
    """Quantize a probability map to uint8 in [0, 255], rounding to nearest rather than truncating."""
    return np.clip(np.rint(prob_map * 255.0), 0, 255).astype(np.uint8)


def _save_sample_masks(probs, sop_uid, labels):
    """Write one PNG per bleed type whose label is 1 for a single slice.

    probs: numpy array where probs[channel] is a 2-D probability map at native resolution.
    labels: 1-D array of 0/1 labels in bleed_names order.
    Raises OSError if OpenCV fails to write a file (cv2.imwrite only returns False).
    """
    for bleed_idx in np.flatnonzero(labels == 1):
        code = abbreviate[bleed_names[bleed_idx]].upper()
        fp = os.path.join(save_dir, f"{sop_uid}_{code}.png")
        if not cv2.imwrite(fp, _to_uint8_mask(probs[bleed_idx])):
            raise OSError(f"cv2.imwrite failed for {fp}")


for batch in tqdm(loader, total=len(loader)):
    # default_collate yields one (B,) tensor per bleed type; stack into (B, num_bleed_types).
    bleed_types = torch.stack(batch["bleed_types"], dim=1).numpy()
    with torch.inference_mode():
        # The Ensemble applies a sigmoid, so these are probabilities, not raw logits.
        probs = model({"x": batch["x"].cuda()})
        for sample, sample_probs in enumerate(probs):
            # Collated h/w are 0-d tensors; interpolate's `size` needs plain ints.
            h, w = int(batch["h"][sample]), int(batch["w"][sample])
            # Compare against the prediction's actual size rather than assuming it equals the input size.
            if tuple(sample_probs.shape[-2:]) != (h, w):
                sample_probs = torch.nn.functional.interpolate(
                    sample_probs.unsqueeze(0), size=(h, w), mode="bilinear", align_corners=False
                ).squeeze(0)
            # One device->host transfer per slice instead of one per mask.
            _save_sample_masks(sample_probs.cpu().numpy(), batch["SOPInstanceUID"][sample], bleed_types[sample])

100%|██████████| 3091/3091 [59:50<00:00,  1.16s/it]
