In [2]:
#!pip install monai lovely-numpy -q --no-index --find-links=../input/vesuvis-downloads
#!python -m pip install -q --no-index --find-links=/kaggle/input/pip-download-for-segmentation-models-pytorch segmentation-models-pytorch
#!python -m pip install -q /kaggle/input/omegaconf222py3/omegaconf-2.2.2-py3-none-any.whl --no-index --find-links=/kaggle/input/omegaconf222py3/
#!pip install -q /kaggle/input/tensordict/tensordict-0.2.1-cp310-cp310-manylinux1_x86_64.whl

In [3]:
#import sys
#sys.path.append("../input/ttach-kaggle/")

In [1]:
import os
import numpy as np
import pandas as pd

import cv2
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tensordict

import albumentations as A
import segmentation_models_pytorch as smp
import gc
import monai
import ttach as tta
from typing import Union, Dict, Tuple


import re

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
DATASET_FOLDER = "/hdd/yang/data/kaggle/test"

#is_test = not len(glob(os.path.join(DATASET_FOLDER, "test", "*", "*", "*.tif"))) == 6
is_test = not len(glob(os.path.join(DATASET_FOLDER, "*", "*", "*.tif"))) == 6
# if is_test:
#     datasets = sorted(glob(f"{DATASET_FOLDER}/test/*"))[::-1]
# else:
#     datasets = sorted(glob(f"{DATASET_FOLDER}/train/kidney_2"))
if is_test:
    datasets = sorted(glob(f"{DATASET_FOLDER}/*/"))[::-1]
else:
    datasets = sorted(glob(f"{DATASET_FOLDER}/train/kidney_2"))

print(len(datasets))
print(datasets)

1
['/hdd/yang/data/kaggle/test/kidney5/']


In [4]:
def rename_keys(original_dict, pattern):
    new_dict = {}
    
    for old_key, value in original_dict.items():
        new_key = re.sub(pattern, '', old_key)
        
        new_dict[new_key] = value
    
    return new_dict


def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = img.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    rle = ' '.join(str(x) for x in runs)
    if rle=='':
        rle = '1 0'
    return rle

def find_highest_score_filename(file_list):
    highest_score = float('-inf')
    highest_score_filename = None

    for filename in file_list:
        # Extract the score from the filename using regular expression
        match = re.search(r'dice_(\d+\.\d+)', filename)
        if match:
            current_score = float(match.group(1))
            if current_score > highest_score:
                highest_score = current_score
                highest_score_filename = filename

    return highest_score_filename

def to_device(x: torch.Tensor, cuda_id: int = 0) -> torch.Tensor:
    return x.cuda(cuda_id) if torch.cuda.is_available() else x


def load_jit_model(model_path: str, cuda_id: int = 0) -> torch.nn.Module:
    model = torch.jit.load(
        model_path,
        map_location=f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu",
    )
    return model

In [5]:
predict_on = [
    [
        "Unet3d",
        192,
        "baseline_3d_unet_192_bs4_d4_scaled_pseudo_0.1_random",
        1.0,
    ],
]


class BuildDataset:
    def __init__(self, dataset: str, is_test: bool = True):
        self.ids = []
        self.is_test = is_test

        self.xmin, self.xmax = 0, 0

        self.data_tensor = self.load_volume(dataset)
        self.shape_orig = self.data_tensor.shape
        
    def normilize(self, image: np.ndarray) -> np.ndarray:
        if image.dtype != np.half:
            image = image.astype(np.half, copy=False)
            
        image -= self.xmin
        image /= (self.xmax - self.xmin)
        
        np.clip(image, 0, 1, out=image)
        return image
    
    @staticmethod
    def norm_by_percentile(
        volume: np.ndarray, low: float = 10, high: float = 99.8
    ) -> Tuple:
        xmin = np.percentile(volume, low)
        print(xmin)
        xmax = np.max([np.percentile(volume, high), 1])
        print(xmax)
        return int(xmin), int(xmax)

    def load_volume(self, dataset: str) -> np.ndarray:
        path = os.path.join(dataset, "images", "*.tif")
        
        dataset = sorted(glob(path)) if self.is_test else sorted(glob(path))[:192]

        for p_img in tqdm(dataset):
            path_ = p_img.split(os.path.sep)
            slice_id, _ = os.path.splitext(path_[-1])
            self.ids.append(f"{path_[-3]}_{slice_id}")

        volume = None

        for z, path in enumerate(tqdm(dataset)):
            image = cv2.imread(path, cv2.IMREAD_ANYDEPTH).astype(np.half, copy=False)
            
            if volume is None:
                volume = np.zeros((len(dataset), *image.shape[-2:]), dtype=np.float16)
            volume[z, :, :] = image
            
        self.xmin, self.xmax = self.norm_by_percentile(volume)
        return volume
    
    
class ModelWrapper(torch.nn.Module):
    def __init__(self, base_model):
        super(ModelWrapper, self).__init__()
        self.base_model = base_model

    def forward(self, x):
        return torch.sigmoid(self.base_model(x)).half()

In [6]:
tta_models = []
weights = []
folds2predict = [0, 1]

for model_config in tqdm(predict_on):
    for fold in folds2predict:
        model_path = sorted(
            glob(
                f"logs/train/runs/{model_config[2]}/{fold}/checkpoints/epoch*.ckpt"
            )
        )[-1]
        print(model_path)
        state_dict = rename_keys(
            torch.load(model_path, map_location="cpu")["state_dict"], "net."
        )
        model_base = to_device(
            monai.networks.nets.DynUNet(spatial_dims=3, in_channels=1, out_channels=1, kernel_size=[ [ 3, 3, 3 ], [ 3, 3, 3 ], [ 3, 3, 3 ], [ 3, 3, 3 ], [ 3, 3, 3 ], [ 3, 3, 3 ] ], strides=[ [ 1, 1, 1 ], [ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ] ], upsample_kernel_size=[[ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ], [ 2, 2, 2 ]], dropout=0.2)
        )
        model_base.load_state_dict(state_dict)
        model = ModelWrapper(model_base)

        model.eval()

        model = torch.nn.DataParallel(model)

        tta_models.append(
            tta.SegmentationTTAWrapper(
                model.half(), tta.aliases.d4_transform(), merge_mode="mean"
            )
        )

        weights.append(model_config[-1])

  0%|          | 0/1 [00:00<?, ?it/s]

logs/train/runs/baseline_3d_unet_192_bs4_d4_scaled_pseudo_0.1_random/0/checkpoints/epoch_469_dice_0.9545.ckpt
logs/train/runs/baseline_3d_unet_192_bs4_d4_scaled_pseudo_0.1_random/1/checkpoints/epoch_377_dice_0.9649.ckpt


100%|██████████| 1/1 [00:01<00:00,  1.76s/it]


In [7]:
if not os.path.exists(f"preds_3d"):
    os.makedirs(f"preds_3d")

rles, ids = [], []
with torch.no_grad():
    for dataset in datasets:
        folder = dataset.split("/")[-1]
        print(folder)

        test_dataset = BuildDataset(dataset, is_test=is_test)
        print(test_dataset.shape_orig)

        ids += test_dataset.ids
        
        preds = 0
        input_tensor = tensordict.MemmapTensor.from_tensor(torch.from_numpy(test_dataset.normilize(test_dataset.data_tensor.astype(np.half))).unsqueeze(0).unsqueeze(0))
        for tta_model, weight in zip(tta_models, weights):
            preds += weight * monai.inferers.sliding_window_inference(
                inputs=input_tensor, # if is_test else torch.rand(1, 1, 512, 512, 512),
                predictor=tta_model,
                sw_batch_size=2,
                roi_size=(256, 256, 256),
                overlap=0.25,
                padding_mode="reflect",
                mode="gaussian",
                sw_device="cuda",
                device="cpu",
                progress=True,
            ).squeeze().cpu().numpy().astype(np.half) / sum(weights)

        for idx, pred in enumerate(preds):
            cv2.imwrite(f"preds_3d/{test_dataset.ids[idx]}.png", (255*pred).astype(np.uint8))

            
        if is_test:
            del input_tensor, test_dataset, preds
            gc.collect()
            torch.cuda.empty_cache()





100%|██████████| 2363/2363 [00:00<00:00, 822433.02it/s]
  image = cv2.imread(path, cv2.IMREAD_ANYDEPTH).astype(np.half, copy=False)
100%|██████████| 2363/2363 [00:25<00:00, 91.54it/s]


5280.0
10552.0
(2363, 1330, 1598)


100%|██████████| 336/336 [38:00<00:00,  6.79s/it]
100%|██████████| 336/336 [39:31<00:00,  7.06s/it]


In [8]:
if is_test:
    del tta_models
    gc.collect()
    torch.cuda.empty_cache()

In [9]:
predict_on = [
     ["UnetPlusPlus", "tu-tf_efficientnet_b5", 1, "UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2", 3., "scse", "senet-models", 800],  #0878 new sampling + cutmix
]


tta_models = []
weights = []

use_top_only = True #True
use_best = False # False
folds2predict = [0, 1]

use_tta = True

TH3d = 0.5 #()
TH2d = 0.05



# Dataset

In [10]:
class BuildDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, in_channels=3, is_test=False):
        self.window = in_channels // 2
        self.is_test = is_test
        self.ids = []

        self.data_tensor = self.load_volume(dataset)
        self.shape_orig = self.data_tensor.shape

        padding = (
            (self.window, self.window),
        ) * self.data_tensor.ndim

        self.padding = tuple(
            (max(0, before), max(0, after)) for (before, after) in padding
        )
        self.data_tensor = np.pad(
            self.data_tensor, padding, mode="constant", constant_values=0
        )

    def __len__(self):
        return sum(self.shape_orig) if self.is_test else self.shape_orig[0]

    def normilize(self, image):
        image = (image - self.xmin) / (
                self.xmax - self.xmin)
        image = np.clip(image, 0, 1)
        return image.astype(np.float32)
    
    @staticmethod
    def norm_by_percentile(volume, low=10, high=99.8):
        xmin = np.percentile(volume, low)
        print(xmin)
        xmax = np.max([np.percentile(volume, high), 1])
        print(xmax)
        return xmin, xmax

    def load_volume(self, dataset):
        path = os.path.join(dataset, "images", "*.tif")
        dataset = sorted(glob(path)) if self.is_test else sorted(glob(path))[:192]
        for p_img in tqdm(dataset):
            path_ = p_img.split(os.path.sep)
            slice_id, _ = os.path.splitext(path_[-1])
            self.ids.append(f"{path_[-3]}_{slice_id}")
            
        volume = None

        for z, path in enumerate(tqdm(dataset)):
            image = cv2.imread(path, cv2.IMREAD_ANYDEPTH)
            image = np.array(image, dtype=np.uint16)
            if volume is None:
                volume = np.zeros((len(dataset), *image.shape[-2:]), dtype=np.uint16)
            volume[z] = image
        self.xmin, self.xmax = self.norm_by_percentile(volume)
        return volume

    def __getitem__(self, idx):
        # Determine which axis to sample from based on the index
        if idx < self.shape_orig[0]:
            idx = idx + self.window
            slice_data = self.normilize(
                self.data_tensor[
                    idx - self.window : 1 + idx + self.window, :, :
                ].transpose(1, 2, 0)[self.window:-self.window, self.window:-self.window, :]
            )
            axis = "X"
            idx -= self.window
        elif idx < self.shape_orig[0] + self.shape_orig[1]:
            idx -= (self.shape_orig[0] - self.window)
            slice_data = self.normilize(
                self.data_tensor[
                    :, idx - self.window : 1 + idx + self.window, :
                ].transpose(0, 2, 1)[self.window:-self.window, self.window:-self.window, :]
            )
            axis = "Y"
            idx -= self.window

            
        else:
            idx -= (
                self.shape_orig[0]
                + self.shape_orig[1]
                - self.window
            ) 
            
            slice_data = self.normilize(
                self.data_tensor[
                    :, :, idx - self.window : 1 + idx + self.window
                ][self.window:-self.window, self.window:-self.window, :]
            )
            axis = "Z"
            idx -= self.window

        slice_data = torch.tensor(slice_data.transpose(2, 0, 1))

        return {
            "slice": slice_data.half(),
            "slice_index": idx,
            "axis": axis
        }

In [11]:
in_chans = []
resolutions = []

for model_config in tqdm(predict_on):
    print(model_config)
    for fold in folds2predict:
        print(model_config[6])
        print(model_config[3])
        model_path = sorted(glob(f"logs/train/runs/{model_config[3]}/{fold}/checkpoints/last*.ckpt"))[-1]
        print(f"use_top_only, loading: {model_path}")
        state_dict = rename_keys(torch.load(model_path, map_location="cpu")["state_dict"], "net.")
        model = to_device(smp.create_model(arch=model_config[0], encoder_name=model_config[1], in_channels=model_config[2], encoder_weights=None, decoder_attention_type=model_config[5]))
        model.load_state_dict(state_dict)
        model.eval()

        model = torch.nn.DataParallel(model)

        if use_tta:
            tta_models.append(tta.SegmentationTTAWrapper(model.half(), tta.aliases.flip_transform(), merge_mode='mean')) #flip_transform d4_transform
        else:
            tta_models.append(model)

        weights.append(model_config[4])
        in_chans.append(model_config[2])
        resolutions.append(model_config[7])


  0%|          | 0/1 [00:00<?, ?it/s]

['UnetPlusPlus', 'tu-tf_efficientnet_b5', 1, 'UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2', 3.0, 'scse', 'senet-models', 800]
senet-models
UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2
use_top_only, loading: logs/train/runs/UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2/0/checkpoints/last.ckpt
senet-models
UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2
use_top_only, loading: logs/train/runs/UnetPlusPlus_tu-tf_efficientnet_b5_BoundaryDoULoss_size_1_512_bs32_hard_pseudo_v2/1/checkpoints/last.ckpt


100%|██████████| 1/1 [00:09<00:00,  9.15s/it]


In [12]:
del state_dict

# Inference

In [13]:
def merge_preds(mask1, mask2):
    binary_mask = (255 * (mask1 > TH2d)).astype(np.uint8)
    edged = cv2.Canny(binary_mask, 12, 200, L2gradient=True)
    contours, hierarchy = cv2.findContours(edged,  
        cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
    interest_mask = np.zeros_like(binary_mask)

    if len(contours) > 0:
        all_contours = np.vstack([contours[i] for i in range(len(contours))])
        hull = cv2.convexHull(all_contours)
        cv2.drawContours(interest_mask, [hull], -1, (1), thickness=cv2.FILLED)

        interest_mask = cv2.dilate(interest_mask, np.ones((5, 5), np.uint8), iterations=5) 
        return ((interest_mask * mask2 + mask1) > TH2d + TH3d).astype(np.uint8)   
    else:
        return (mask1 > TH2d).astype(np.uint8)

In [14]:
rles, ids = [], []


with torch.no_grad():
    for dataset in datasets:
        test_dataset = BuildDataset(dataset, is_test=is_test, in_channels=3) # TODO: refactor this
        test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4, shuffle=False, pin_memory=False)

        y_preds = np.zeros(test_dataset.shape_orig, dtype=np.half)
        ids += test_dataset.ids

        pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc=f'Inference {dataset}')
        for step, batch in pbar:
            images = to_device(batch["slice"])
            
            axis = batch["axis"][0]
            idx = batch["slice_index"].numpy()[0]

            preds = 0
            for tta_model, weight, in_chan, resolution in zip(tta_models, weights, in_chans, resolutions):
                preds += weight * monai.inferers.sliding_window_inference(
                    inputs=images.half() if in_chan != 1 else images[:, 1,...].unsqueeze(0).half(), # TODO: Refactor this
                    predictor=tta_model.half(),
                    sw_batch_size=8,
                    roi_size=(resolution, resolution),
                    overlap=0.25,
                    padding_mode="reflect",
                    mode="gaussian",
                    sw_device="cuda",
                    device="cuda",
                    progress=False,
                )
            if axis == "X":
                y_preds[idx, :, :] += ((preds / sum(weights)).squeeze().sigmoid().cpu().numpy() / 3.).astype(np.half)
            elif axis == "Y":
                y_preds[:, idx, :] += ((preds / sum(weights)).squeeze().sigmoid().cpu().numpy() / 3.).astype(np.half)
            elif axis == "Z":
                y_preds[:, :, idx] += ((preds / sum(weights)).squeeze().sigmoid().cpu().numpy() / 3.).astype(np.half)
        
        for idx, pred_2d in enumerate(y_preds):     
            pred_3d = cv2.imread(f"preds_3d/{test_dataset.ids[idx]}.png", 0) / 255.
            merge_p = merge_preds(pred_2d, pred_3d)          
            rles.append(rle_encode(merge_preds(pred_2d, pred_3d)))

        del test_dataset, test_loader, y_preds
        gc.collect()
        torch.cuda.empty_cache()

100%|██████████| 2363/2363 [00:00<00:00, 835748.41it/s]
100%|██████████| 2363/2363 [00:08<00:00, 288.10it/s]


5279.0
10551.0


Inference /hdd/yang/data/kaggle/test/kidney5/: 100%|██████████| 5291/5291 [4:26:16<00:00,  3.02s/it]  


In [15]:
del tta_models, tta_model, batch, preds, images, model
gc.collect()
torch.cuda.empty_cache()


In [16]:
!rm -rf preds_3d

In [17]:
submission = pd.DataFrame.from_dict({
    "id": ids,
    "rle": rles
})
submission.to_csv("submission.csv", index=False)

In [None]:
submission