In [None]:
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch

from importlib import import_module
from skimage.morphology import remove_small_objects
from tqdm import tqdm

from skp.toolbox.classes import Ensemble
from skp.toolbox.functions import load_kfold_ensemble_as_list, plot_3d_image_side_by_side

In [9]:
def create_ich_mask(probas, bleeds_present, init_thresh, min_thresh, increment=0.1, min_size=50, verbose=False):
    """Build a per-pixel hemorrhage-subtype label mask from model probabilities.

    Channel 0 of ``probas`` is used as the "any bleed" probability and the
    remaining channels as per-subtype probabilities. Every pixel is labelled
    with the most probable subtype among those flagged in ``bleeds_present``
    (label = subtype index + 1). Pixels whose any-bleed probability is below
    the current threshold are zeroed, and connected foreground components
    smaller than ``min_size`` pixels are removed. While the mask is empty, the
    threshold is lowered by ``increment`` (rounded to 2 decimals) and the
    procedure repeated, until the threshold falls below ``min_thresh``.

    Args:
        probas: ndarray of shape (1 + n_subtypes, H, W); the caller in this
            notebook passes 6 channels.
        bleeds_present: 1-D array-like of length n_subtypes; truthy entries
            mark the subtypes that are present on this slice.
        init_thresh: first any-bleed threshold to try.
        min_thresh: lowest any-bleed threshold that may be tried.
        increment: positive amount by which the threshold is reduced per retry.
        min_size: passed to ``remove_small_objects`` as its ``min_size``.
        verbose: print a message on every threshold reduction and when an
            empty mask is returned.

    Returns:
        ndarray of shape (H, W) with the dtype of ``probas``. 0 is background,
        ``k + 1`` marks subtype ``k``. All zeros if no subtype is flagged or no
        threshold in range produced a non-empty mask.

    Raises:
        ValueError: if ``increment`` is not positive (the retry loop could
            otherwise never terminate).
    """
    if increment <= 0:
        raise ValueError(f"increment must be positive, got {increment}")

    all_bleed, subtype_probas = probas[0], probas[1:]
    mask = np.zeros_like(all_bleed)
    present_idx = np.where(bleeds_present)[0]
    if present_idx.size == 0:
        # argmax over zero subtype channels would raise; there is nothing to label.
        if verbose:
            print("No bleed subtypes flagged as present. Returning empty mask ...")
        return mask

    # If there is a bleed at a given pixel, which of the subtypes present on this
    # slice is the most likely one? Map the winning position back to its label.
    winner = np.argmax(subtype_probas[present_idx], axis=0)
    label_map = (present_idx[winner] + 1).astype(mask.dtype)

    all_bleed_thresh = init_thresh
    while mask.sum() == 0 and all_bleed_thresh >= min_thresh:
        if verbose and all_bleed_thresh < init_thresh:
            print(f"Empty bleed mask. Reducing bleed threshold to {all_bleed_thresh} ...")
        mask = label_map.copy()
        mask[all_bleed < all_bleed_thresh] = 0
        if mask.sum() > 0:
            # remove small objects (judged on the binary foreground, not per subtype)
            keep = remove_small_objects(mask > 0, min_size=min_size)
            mask[~keep] = 0
        # Round to 2 decimals so float drift does not skip the min_thresh step.
        all_bleed_thresh = round(all_bleed_thresh - increment, 2)

    if verbose and mask.sum() == 0:
        print(f"Reached minimum bleed threshold of {min_thresh}. Returning empty mask ...")
    return mask

In [None]:
# Config module for the slice-level segmentation model.
cfg_name = "ich.cfg_slice_segment_2dc_pos_only_sigmoid"
cfg = import_module(f"skp.configs.{cfg_name}").cfg

# One `last.ckpt` checkpoint per fold (folds 0-4) under the 4e607791 directory.
weights_paths = []
for fold in range(5):
    weights_paths.append(cfg.save_dir + f"/{cfg_name}/4e607791/fold{fold}/checkpoints/last.ckpt")

# Load the fold models on the GPU in eval mode and wrap them in a single ensemble
# configured to read the "logits" output and apply a sigmoid.
# NOTE(review): how Ensemble combines the fold outputs is defined in skp.toolbox — confirm.
model_list = load_kfold_ensemble_as_list(
    cfg,
    weights_paths,
    device="cuda",
    eval_mode=True,
)
model = Ensemble(
    model_list,
    output_name="logits",
    activation_fn="sigmoid",
)

In [None]:
# Slice-level table for the RSNA ICH training set; keep only rows whose `any` column is 1.
df = pd.read_csv(
    "/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/train_slices_with_2dc_kfold.csv"
)
pos_df = df[df["any"].eq(1)]
pos_df.head()

In [5]:
# Drop every positive slice whose PatientID also appears in the BHSD table,
# so the generated masks exclude patients shared between the two datasets.
bhsd_df = pd.read_csv("/mnt/stor/datasets/BHSD/train_positive_slices_png_kfold.csv")
overlap_ids = set(bhsd_df.PatientID).intersection(pos_df.PatientID)
pos_df = pos_df[~pos_df.PatientID.isin(list(overlap_ids))]

In [6]:
# Output root for the generated masks; the mask-generation loop joins each mask's
# relative file path onto this directory before writing.
save_dir = "/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/generated_segmentation_masks_exclude_bhsd/"

In [None]:
# Generate a pseudo-label segmentation mask for every positive slice in `pos_df`
# and write it under `save_dir`, mirroring the relative path of the middle input slice.
image_dir = "/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train_png/"
subtype_cols = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]
resizer = A.Resize(512, 512, p=1)

for row_idx, row in tqdm(pos_df.iterrows(), total=len(pos_df)):
    # `filepath_2dc` is a comma-separated list of slice files that become the input channels.
    files = row.filepath_2dc.split(",")
    channels = []
    for fname in files:
        img_path = os.path.join(image_dir, fname)
        channel = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if channel is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        channels.append(channel)
    img = np.stack(channels, axis=-1)

    # The model is fed 512x512 inputs; remember the native size so the output
    # can be mapped back onto the original pixel grid.
    h, w = img.shape[:2]
    resized = False
    if h != 512 or w != 512:
        resized = True
        img = resizer(image=img)["image"]

    # HWC -> CHW, then add a batch dimension.
    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).unsqueeze(0).float().cuda()
    with torch.inference_mode():
        out = model({"x": img})
        if resized:
            # align_corners=False is the default; spelled out to make the choice explicit.
            out = torch.nn.functional.interpolate(out, size=(h, w), mode="bilinear", align_corners=False)
        out = out[0].cpu().numpy()

    # Ground-truth subtype flags restrict which subtypes may be assigned to pixels.
    bleed_types = row[subtype_cols].values
    mask = create_ich_mask(out, bleed_types, init_thresh=0.5, min_thresh=0.1, min_size=50)
    if mask.sum() == 0:
        # Nothing segmented even at the lowest threshold: no mask file is written.
        continue

    # save mask as middle file
    save_path = os.path.join(save_dir, files[len(files) // 2])
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if not cv2.imwrite(save_path, mask.astype("uint8")):
        # cv2.imwrite reports failure through its return value only.
        raise OSError(f"Failed to write mask: {save_path}")

In [None]:
# Debug check of the input tensor shape, the native image height/width, and the mask shape.
# NOTE(review): these names are whatever is currently in the kernel namespace — presumably
# left over from the last iteration of the mask-generation loop; depends on execution order.
print(img.shape, h, w, mask.shape)

In [None]:
# NOTE(review): `tmp_df` is not defined in any visible cell of this notebook; it must already
# exist in the kernel (with a `filepath` column) for this cell to run — confirm its origin.
png_dir = "/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train_png/"
stack0 = np.stack(
    [cv2.imread(os.path.join(png_dir, fname), cv2.IMREAD_GRAYSCALE) for fname in tmp_df.filepath],
    axis=0,
)
stack = torch.from_numpy(stack0)

# make 2Dc: pad the volume by repeating its first and last slice, then give every
# slice three channels (previous, current, next) -> shape (N, 3, H, W)
padded = torch.cat([stack[:1], stack, stack[-1:]])
n_slices = len(padded) - 2
stack_2dc = torch.stack([padded[offset:offset + n_slices] for offset in range(3)], dim=1)
print(stack_2dc.shape)

In [41]:
# Run the ensemble on the whole 2Dc stack in one batch, without autograd tracking.
with torch.inference_mode():
    batch_2dc = stack_2dc.float().cuda()
    out = model({"x": batch_2dc})

In [42]:
# Per-pixel index of the highest-scoring output channel, as a NumPy array.
# NOTE(review): channel 0 is used as the "any bleed" map elsewhere in this notebook, so an
# argmax across all channels may mostly return 0 — confirm this is the intended visualisation.
y = torch.argmax(out, dim=1).cpu().numpy()

In [None]:
# Plot the input slices (`stack0`) next to the predicted label maps (`y`) along axis 0,
# with num_images set to the number of slices in `y`.
# NOTE(review): exact layout is determined by plot_3d_image_side_by_side in skp.toolbox — confirm.
plot_3d_image_side_by_side(stack0, y, num_images=len(y), axis=0)