In [1]:
from typing import Optional, Union
from glob import glob
from tqdm.auto import tqdm
import tifffile as tiff
import cv2
import pandas as pd

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class SweepCube(Dataset):
    """Sliding-window dataset over the last three axes of a 4-D tensor.

    Every item is a patch that keeps the whole first axis of ``data`` and
    spans ``patch_size`` along axes 1-3, together with the patch's starting
    offset along those axes. Windows that would run past the end of an axis
    are not generated.
    """

    def __init__(
            self,
            data: torch.Tensor,
            patch_size: tuple[int, int, int],
            stride: Optional[tuple[int, int, int]] = None
        ):
        """
        data: 4-D tensor to cut patches from
        patch_size: window extent along axes 1, 2 and 3 of ``data``
        stride: window step along the same axes; defaults to ``patch_size``
            (non-overlapping windows)
        """
        assert data.ndim == 4, "Data must be 4-dimensional"

        self.data = data
        self.patch_size = patch_size
        self.stride = patch_size if stride is None else stride

        # Number of window positions along each of the three swept axes.
        self.patches = tuple(
            (self.data.shape[axis + 1] - self.patch_size[axis]) // self.stride[axis] + 1
            for axis in range(3)
        )

    def __len__(self):
        n_x, n_y, n_z = self.patches
        return n_x * n_y * n_z

    def __getitem__(self, i):
        # Flat index -> grid coordinates -> starting offsets inside the volume.
        grid = np.unravel_index(i, self.patches)
        x, y, z = (g * step for g, step in zip(grid, self.stride))

        patch = self.data[
            :,
            x:x + self.patch_size[0],
            y:y + self.patch_size[1],
            z:z + self.patch_size[2],
        ]
        return patch, torch.tensor([x, y, z])

class Bayes(nn.Module):
    """Base class for modules that take part in the penalize / decay / rebase protocol.

    A ``Bayes`` module keeps a plain Python list of the ``Bayes`` sub-modules
    it has registered and fans ``penalize``, ``decay_var`` and ``rebase`` out
    to them. With nothing registered these calls are no-ops (``penalize``
    returns ``0``).
    """

    def __init__(self):
        super().__init__()
        self._registered_bayesian_modules: list["Bayes"] = []

    def register_bayes(self, module: Union[list[nn.Module], nn.Module]):
        """
        Register ``module`` if it is a ``Bayes`` instance; otherwise, if it is
        iterable (list, ``nn.Sequential``, ``nn.ModuleList``, ...), recurse
        into its elements. Anything else is silently ignored.
        """
        # isinstance also catches indirect subclasses, which a name check on
        # the direct bases would silently drop.
        if isinstance(module, Bayes):
            self._registered_bayesian_modules.append(module)
            return

        # Only the iterability probe is guarded, so a TypeError raised deeper
        # in the recursion is not swallowed by accident.
        try:
            children = iter(module)
        except TypeError:
            return
        for child in children:
            self.register_bayes(child)

    def penalty(self):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def penalize(self, alpha: float = 1e-3):
        """
        alpha: the penalty coefficient

        Returns the sum of ``penalize(alpha)`` over all registered modules.
        """
        return sum(u.penalize(alpha) for u in self._registered_bayesian_modules)

    def decay_var(self, gamma: float = 0.5):
        """
        gamma: variance decay rate
        """
        for u in self._registered_bayesian_modules:
            u.decay_var(gamma)

    def rebase(self):
        """
        Rebase the parameters to the current mean
        """
        for u in self._registered_bayesian_modules:
            u.rebase()

class Conv2DNormed(nn.Conv2d):
    """``nn.Conv2d`` whose output is divided by sqrt(max(in_channels, out_channels))."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        widest = np.max([self.in_channels, self.out_channels])
        # Registered as a buffer, so the divisor is part of the state dict
        # and moves with the module across devices.
        self.register_buffer('norm_term', torch.tensor(np.sqrt(widest)))

    def forward(self, x: torch.Tensor):
        conv_out = super().forward(x)
        return conv_out / self.norm_term

class ConvBlock(Bayes):
    """Stack of ``block_depth`` conv stages with residual and jump connections.

    Each stage is ``[dropout] -> conv -> [norm] -> activation``. The first
    stage maps ``in_f`` to ``out_f`` channels and carries the requested
    ``stride``; the remaining stages keep ``out_f`` channels at stride 1.
    """

    def __init__(
            self,
            in_f: int,
            out_f: int,
            kernel_size: int = 3,

            conv: nn.Module = nn.Conv2d,
            activation: nn.Module = nn.GELU,
            norm_fn: Optional[nn.Module] = None,
            dropout: Optional[tuple[nn.Module, float]] = None,

            padding: int = 1,
            stride: int = 1,
            block_depth: int = 4,
            **kwargs
        ):
        """
        in_f / out_f: input and output channel counts
        kernel_size, padding, stride: forwarded to ``conv``
        conv, activation, norm_fn: layer factories (classes), not instances
        dropout: optional ``(dropout_class, probability)`` pair
        block_depth: number of stages
        kwargs: extra keyword arguments forwarded to every ``conv`` call
        """
        super().__init__()
        self.out_f = out_f

        def make_stage(stage_in, stage_stride):
            # Order matters: it fixes the sub-module indices in the state dict.
            parts = [] if dropout is None else [dropout[0](dropout[1])]
            parts.append(
                conv(stage_in, out_f, kernel_size, padding=padding, stride=stage_stride, **kwargs)
            )
            if norm_fn is not None:
                parts.append(norm_fn(out_f))
            parts.append(activation())
            return nn.Sequential(*parts)

        stages = []
        for depth in range(block_depth):
            is_first = depth == 0
            stages.append(
                make_stage(in_f if is_first else out_f, stride if is_first else 1)
            )
        self.layers = nn.ModuleList(stages)

        for stage in self.layers:
            self.register_bayes(stage)

    def forward(self, x: torch.Tensor):
        first_out = self.layers[0](x)
        out = first_out
        for stage in self.layers[1:]:
            # Residual connection around every stage after the first.
            out = stage(out) + out
        # Jump connection from the first stage's output to the block output.
        return out + first_out

class UNetCrossBlock(Bayes):
    """Fuse one feature map per stage into a single map at ``into_stage``'s resolution.

    Each incoming map gets its own resize block: maps from shallower stages
    (index < ``into_stage``) are pooled, maps from deeper stages
    (index > ``into_stage``) are upsampled, and every map then goes through
    a ``ConvBlock`` producing ``connect_depth`` channels. The results are
    concatenated on the channel axis and passed through one more ``ConvBlock``.
    """
    def __init__(
            self,
            layers: list[int],
            into_stage: int,

            block_depth: int = 1,
            connect_depth: int = 64,

            conv: nn.Module = nn.Conv3d,
            activation: nn.Module = nn.GELU,
            pool_fn: nn.Module = nn.MaxPool3d,
            resize_kernel: tuple = (2, 2, 2),
            upsample_mode: str = 'trilinear',
            norm_fn: Optional[nn.Module] = None,
            dropout: Optional[tuple[nn.Module, float]] = None,

            **kwargs
    ):
        """
        layers: channel count of each encoder stage
        into_stage: index of the stage whose resolution the output has
        block_depth: depth of every internal ``ConvBlock``
        connect_depth: channels produced per incoming map
        conv, activation, pool_fn, norm_fn: layer factories (classes)
        resize_kernel: per-axis resize factor between two adjacent stages
        upsample_mode: ``mode`` passed to ``nn.Upsample``
        dropout: optional ``(dropout_class, probability)`` pair
        kwargs: forwarded to every ``ConvBlock``
        """
        super().__init__()
        # NOTE(review): the message says "into_layer" while the parameter is named into_stage.
        assert 0 <= into_stage < len(layers), f"into_layer must be in [0, {len(layers)})"
        self.num_stages = len(layers)

        self.resize_blocks = nn.ModuleList([])
        for i in range(len(layers)):
            block = nn.Sequential()                
            if i < into_stage:
                # Shallower stage: pool by the per-stage factor raised to the stage distance.
                block.append(pool_fn((*(k ** (into_stage - i) for k in resize_kernel),), ceil_mode=True))
            if i > into_stage:
                # Deeper stage: upsample by the per-stage factor raised to the stage distance.
                block.append(nn.Upsample(scale_factor=(*(k ** (i - into_stage) for k in resize_kernel),), mode=upsample_mode))
            block.append(
                ConvBlock(
                    # Deeper inputs other than the last stage are expected to carry
                    # connect_depth * num_stages channels (the width full_conv_block emits);
                    # all other inputs carry the raw stage width layers[i].
                    connect_depth * self.num_stages if i > into_stage and i != self.num_stages - 1 else layers[i],
                    connect_depth,
                    conv=conv,
                    activation=activation,
                    block_depth=block_depth,
                    dropout=dropout,
                    norm_fn=norm_fn,
                    **kwargs
                ),
            )
            self.resize_blocks.append(block)
        # Mixes the concatenated per-stage maps; input and output widths are equal.
        self.full_conv_block = ConvBlock(
            connect_depth * len(layers),
            connect_depth * len(layers),
            conv=conv,
            activation=activation,
            block_depth=block_depth,
            dropout=dropout,
            norm_fn=norm_fn,
            **kwargs
        )

        for l in self.resize_blocks: self.register_bayes(l)
        self.register_bayes(self.full_conv_block)


    def forward(self, xs: list[torch.Tensor]):
        """Resize ``xs[i]`` with resize block ``i``, concatenate on dim 1 and fuse.

        xs: one tensor per stage, ``len(xs) == num_stages``
        """
        assert len(xs) == self.num_stages, f"expected input to be of length {self.num_stages}, but got {len(xs)}"
        zs = []
        for i in range(len(xs)):
            zs.append(self.resize_blocks[i](xs[i]))
        return self.full_conv_block(torch.cat(zs, dim=1))

# Architecture based on : https://arxiv.org/abs/2004.08790
class UNet3P(Bayes):
    """Encoder plus full-scale cross-connected decoder with one output head per decoder stage.

    The encoder (``down_blocks``) produces one feature map per entry of
    ``layers``. Every decoder stage is a ``UNetCrossBlock`` fed with the
    encoder maps at or above its resolution, the decoder maps already
    computed below it, and the deepest encoder map. ``forward`` returns one
    prediction per evaluated decoder stage, deepest first.
    """

    def __init__(
            self,
            in_f: int = 1,
            layers: list[int] = [64, 128, 256, 512, 1024],
            out_f: int = 1,

            block_depth: int = 4,
            connect_depth: int = 64,

            conv: nn.Module = nn.Conv3d,
            activation: nn.Module = nn.GELU,
            pool_fn: nn.Module = nn.MaxPool3d,
            resize_kernel: tuple = (2, 2, 2),
            upsample_mode: str = 'trilinear',
            norm_fn: Optional[nn.Module] = None,
            dropout: Optional[tuple[nn.Module, float]] = None,

            input_noise: Optional[float] = 0.1,
            **kwargs
        ):
        """
        in_f / out_f: input and output channel counts
        layers: channel count per encoder stage (the default list is only read, never mutated)
        block_depth: depth of every ``ConvBlock``
        connect_depth: channels per incoming map inside each ``UNetCrossBlock``
        conv, activation, pool_fn, norm_fn: layer factories (classes)
        resize_kernel: per-axis pooling factor between adjacent stages
        upsample_mode: ``mode`` passed to ``nn.Upsample`` in the cross blocks
        dropout: optional ``(dropout_class, probability)`` pair
        input_noise: std of Gaussian noise added to the normalized input while
            training; ``None`` disables it
        kwargs: forwarded to every ``ConvBlock`` / ``conv`` call
        """
        super().__init__()
        if input_noise is not None:
            self.register_buffer('input_noise', torch.tensor(input_noise))
        c = in_f

        # The default norm_fn=None used to crash here; fall back to a
        # parameter-free identity so the documented default is usable.
        self.input_norm = norm_fn(in_f) if norm_fn is not None else nn.Identity()

        self.down_blocks = nn.ModuleList([])
        for i, l in enumerate(layers):
            block = nn.Sequential()
            if i != 0:
                # Every stage after the first starts by downsampling.
                block.append(pool_fn(resize_kernel))
            block.append(
                ConvBlock(
                    c,
                    l,
                    conv=conv,
                    activation=activation,
                    block_depth=block_depth,
                    dropout=dropout,
                    norm_fn=norm_fn,
                    **kwargs
                )
            )
            self.down_blocks.append(block)
            c = l

        # One cross block per stage except the deepest one.
        self.cross_blocks = nn.ModuleList([])
        for i in range(len(layers) - 1):
            self.cross_blocks.append(
                UNetCrossBlock(
                    layers,
                    i,
                    block_depth=block_depth,
                    connect_depth=connect_depth,
                    conv=conv,
                    activation=activation,
                    pool_fn=pool_fn,
                    resize_kernel=resize_kernel,
                    upsample_mode=upsample_mode,
                    norm_fn=norm_fn,
                    dropout=dropout,
                    **kwargs
                )
            )

        # Output heads, stored deepest stage first so they line up with the
        # order in which forward() evaluates the cross blocks.
        self.out_blocks = nn.ModuleList([])
        for i in reversed(range(len(layers) - 1)):
            L = []
            if i == 0:
                # Only the full-resolution head gets an extra 1x1 ConvBlock.
                L.append(
                    ConvBlock(
                        connect_depth * len(layers),
                        connect_depth * len(layers),
                        1,
                        conv=conv,
                        padding=0,
                        activation=activation,
                        block_depth=block_depth,
                        dropout=dropout,
                        norm_fn=norm_fn,
                        **kwargs
                    )
                )
            L.append(
                conv(
                    connect_depth * len(layers),
                    out_f,
                    1,
                    **kwargs
                )
            )
            self.out_blocks.append(nn.Sequential(*L))

        # Poolings that shrink a full-resolution target to the resolution of
        # each deeper head, coarsest first.
        self.mask_blocks = nn.ModuleList([
            pool_fn((*(k ** (i + 1) for k in resize_kernel),))
            for i in reversed(range(len(layers) - 2))
        ])

        for l in self.down_blocks: self.register_bayes(l)
        for l in self.cross_blocks: self.register_bayes(l)
        for l in self.out_blocks: self.register_bayes(l)

    def _prep_input(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the input and, in training mode, add Gaussian noise."""
        z = self.input_norm(x)
        if hasattr(self, 'input_noise') and self.training:
            # Out-of-place add: with an identity input_norm, z is the caller's
            # tensor and must not be modified in place.
            z = z + torch.randn_like(z) * self.input_noise
        return z

    def _down_fwd(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run the encoder and return the output of every stage, shallowest first."""
        down_agg = []
        for block in self.down_blocks:
            x = block(x)
            down_agg.append(x)
        return down_agg

    def forward(self, x: torch.Tensor, up_depth: Optional[int] = None) -> list[torch.Tensor]:
        """
        x: input batch
        up_depth: number of decoder stages to evaluate, counted from the
            deepest; ``None`` evaluates all of them

        Returns one head output per evaluated decoder stage, deepest first.
        """
        up_depth = len(self.cross_blocks) if up_depth is None else up_depth
        assert up_depth <= len(self.cross_blocks), f"up_depth must be in [0, {len(self.cross_blocks)}]"

        z = self._prep_input(x)
        down_agg = self._down_fwd(z)

        up_agg = []
        stage_order = list(reversed(range(len(self.cross_blocks))))[:up_depth]
        for i in stage_order:
            # Encoder maps for stages 0..i, then decoder maps for stages
            # i+1.. in stage order, then the deepest encoder map.
            cross_input = down_agg[:i + 1] + up_agg[::-1] + [down_agg[-1]]
            up_agg.append(self.cross_blocks[i](cross_input))

        return [
            head(feat)
            for head, feat in zip(self.out_blocks[:up_depth], up_agg)
        ]

    def deep_masks(self, y: torch.Tensor) -> list[torch.Tensor]:
        """Return ``y`` pooled by each entry of ``mask_blocks``, followed by ``y`` itself."""
        masks = [pool(y) for pool in self.mask_blocks]
        return masks + [y]

class Patch2p5D(nn.Module):
    """Adapter that feeds a 5-D patch batch to a 2-D model by merging dims 1 and 2 into channels."""

    def __init__(self, model: UNet3P):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dims = x.shape
        # (d0, d1, d2, d3, d4) -> (d0, d1 * d2, d3, d4)
        merged = (dims[0], dims[1] * dims[2], dims[3], dims[4])
        outputs = self.model(x.view(*merged))
        # The wrapped model returns a list; only its last entry is used.
        return outputs[-1]


class ScanInference2p5D(nn.Module):
    """Whole-scan inference by sweeping a 2.5D patch model over several axis orderings.

    For every axis permutation in ``self.perms`` the scan is permuted, swept
    with ``(8, 256, 256)`` patches (stride 1 along the first swept axis), and
    the sigmoid prediction of each patch is written to the patch's central
    slice. Each pass contributes at most ``255 / len(perms)``, so the summed
    result stays within the uint8 range.
    """

    def __init__(self, patch_fn, batch_size: int, quick: bool = False):
        """
        patch_fn: callable mapping a float patch batch to logits
        batch_size: number of patches per forward pass
        quick: if True, use only the first axis ordering
        """
        super().__init__()
        self.patch_fn = patch_fn
        self.batch_size = batch_size

        self.patch_size = (8, 256, 256)
        self.perms = torch.Tensor([
            (0, 1, 2),
            (0, 2, 1),
            (1, 0, 2),
            (1, 2, 0),
            (2, 0, 1),
            #(2, 1, 0), 5 perms to make self.pass_max a nice number
        ]).int()
        if quick:
            self.perms = self.perms[:1]

        # Largest value a single pass may add to the accumulator.
        self.register_buffer('pass_max', torch.tensor(255 / len(self.perms)))

    def _forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        """Run one sweep over ``scan`` (already permuted) and return the quantized prediction."""
        agg_pred = torch.zeros_like(scan)
        scan_loader = DataLoader(
            SweepCube(scan, self.patch_size, stride=(1, self.patch_size[1], self.patch_size[2])),
            batch_size=self.batch_size,
            # Patches are written back by position, so visiting order does not
            # affect the result; sequential order keeps the sweep deterministic.
            shuffle=False
        )
        # Slice inside each patch that receives the prediction.
        center = self.patch_size[0] // 2

        # The output is quantized to uint8, so gradients can never flow
        # through it; skip autograd bookkeeping entirely.
        with torch.no_grad():
            for x, positions in scan_loader:
                x = x.to(device).float()
                pred = torch.sigmoid(self.patch_fn(x)).cpu()
                for p, pos in zip(pred, positions):
                    agg_pred[
                        :,
                        pos[0] + center,
                        pos[1]:pos[1] + self.patch_size[1],
                        pos[2]:pos[2] + self.patch_size[2],
                    ] += ((p.squeeze(1) * self.pass_max).round()).byte()

        return agg_pred

    def forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        """
        scan: 4-D tensor whose last three axes are swept
        device: device the patch batches are moved to before calling ``patch_fn``

        Returns the sum over all axis orderings of the per-pass predictions,
        in the axis order and dtype of ``scan``.
        """
        agg_pred = torch.zeros_like(scan)

        for perm in self.perms:
            # +1 skips the leading axis; argsort gives the inverse permutation
            # that restores the original axis order.
            out = self._forward(scan.permute(0, *perm+1), device).permute(0, *torch.argsort(perm)+1)
            agg_pred += out
            del out
        return agg_pred

def rle_encode(mask):
    """Run-length encode a binary mask as a ``"start length start length ..."`` string.

    Starts are 1-based positions in the flattened mask. An all-zero mask is
    encoded as ``'1 0'``.
    """
    flat = mask.flatten()
    # Zero sentinels on both ends so every run has a detectable start and end.
    padded = np.concatenate(([0], flat, [0]))
    run = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts; turn the odd slots from run ends into lengths.
    run[1::2] -= run[::2]
    encoded = ' '.join(map(str, run))
    return encoded if encoded != '' else '1 0'

def id_from_pth(pth: str):
    parts = pth.split("/")[-3:]
    parts.pop(1)
    return "_".join(parts)[:-4]

def proprocess(_scan: np.ndarray):
    """Min-max normalize a slice to uint8, invert it and apply CLAHE.

    _scan: 2-D image array of any numeric dtype

    Returns the uint8 image produced by ``cv2`` CLAHE (clip limit 40,
    8x8 tile grid) on the inverted, normalized slice. A constant-intensity
    slice has no range to stretch and is treated as all zeros before the
    inversion instead of dividing by zero.
    """
    scan = _scan.astype(np.float32)
    smin, smax = np.min(scan), np.max(scan)
    span = smax - smin
    if span > 0:
        scan = (255 * (scan - smin) / span).astype(np.uint8)
    else:
        # smax == smin: avoid 0/0 -> NaN -> undefined uint8 cast.
        scan = np.zeros(scan.shape, dtype=np.uint8)
    scan = 255 - scan
    clahe = cv2.createCLAHE(clipLimit=40.0, tileGridSize=(8, 8))
    return clahe.apply(scan)

def load_slice(pth, preprocess_fn):
    """Read one TIFF slice, run it through ``preprocess_fn`` and return it as a tensor.

    pth: path of the ``.tif`` file
    preprocess_fn: callable applied to the array returned by ``tiff.imread``
    """
    # Use the callable that was passed in rather than a hard-wired function.
    return torch.tensor(
        preprocess_fn(tiff.imread(pth))
    )

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Kaggle locations, kept for switching back when submitting:
# data_dir = '/kaggle/input/blood-vessel-segmentation'
# model_pth = "/kaggle/input/model-weights/model.pt"
data_dir = '/root/data'
model_pth = './model.pt'

# A test folder holding exactly 3 kidney_5 slices is treated as "not a real
# submission run"; anything else counts as a submission.
is_submit = True
try:
    is_submit = len(glob(f"{data_dir}/test/kidney_5/images/*.tif")) != 3
except OSError:
    # Best effort: keep the default when the file system cannot be queried.
    pass
scan_folders = glob(data_dir + "/*/")
if not is_submit:
    scan_folders = [
        f'{data_dir}/train/kidney_2/',
        # The path separator after data_dir was missing here.
        f'{data_dir}/train/kidney_3_sparse/'
    ]


# TODO: DELETE THIS FOR REAL SUBMISSION
is_submit = True
scan_folders = [
    f'{data_dir}/train/kidney_2/',
]

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# Build the 2-D UNet3P wrapped as a 2.5D patch model and move it to the
# selected device. The hyper-parameters must match the checkpoint loaded below.
patcher = Patch2p5D(
    UNet3P(
        in_f=8,  # equals the patch depth (first entry of ScanInference2p5D.patch_size)
        layers=[16, 32, 32, 32, 64, 64, 64],
        block_depth=4,
        connect_depth=6,
        conv=Conv2DNormed,
        pool_fn=nn.MaxPool2d,
        resize_kernel=(2,2),
        upsample_mode='bilinear',
        norm_fn=nn.BatchNorm2d,
        dropout=(nn.Dropout2d, 0.1)
    )
).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
patcher.model.load_state_dict(torch.load(model_pth, map_location=device))
# Inference only: freeze the parameters and switch to eval mode.
patcher.requires_grad_(False)
patcher.eval()

# quick=False runs every axis ordering defined in ScanInference2p5D.
inference = ScanInference2p5D(
    patcher,
    batch_size=64,
    quick=False
)

In [3]:
# One {'id', 'rle'} record per slice, across all scans.
submission_list = []
for scan_fn in scan_folders:
    slices, ids = [], []
    print(f"loading scan {scan_fn}")
    # Sorted so the slice order (and therefore the ids) is deterministic.
    for pth in sorted(glob(scan_fn + "images/*.tif")):
        slices.append(load_slice(pth, proprocess))
        ids.append(id_from_pth(pth))
    if not slices:
        continue
    # Stack the slices into one volume and add a leading axis of size 1.
    scan = torch.stack(slices).unsqueeze(0)

    print("doing inference...")
    # The accumulated prediction is on a 0..255 scale; threshold at the midpoint.
    pmask = inference(scan, device).squeeze(0) > (255 / 2)

    print("aggregating slice rle encodings...")
    for slice_id, smask in zip(ids, pmask):
        submission_list.append({
            'id'  : slice_id,
            'rle' : rle_encode(smask.numpy()),
        })
    # Release the large per-scan objects before loading the next scan.
    del scan
    del pmask
    del slices
    del ids

print("saving aggregate dataframe of rle encodings...")
# Building the frame once from records also yields a valid, header-only CSV
# when no slices were found (concatenating an empty list would raise).
submission_df = pd.DataFrame(submission_list, columns=['id', 'rle'])
submission_df.to_csv('submission.csv', index=False)

loading scan /root/data/train/kidney_2/
doing inference...
aggregating slice rle encodings...
loading scan /root/datatrain/kidney_3_sparse/
saving aggregate dataframe of rle encodings...


In [2]:
# Reload the predictions written by the inference cell so they can be scored.
submission_df = pd.read_csv('submission.csv')

def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: run-length string formatted as "start length start length ..."
              with 1-based starts into the flattened mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array, 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1
    return flat.reshape(shape)

# Decode every submission row back into a 2-D mask and stack them into a
# (1, n_slices, height, width) boolean volume.
# NOTE(review): the slice size (1041, 1511) is hard-coded; it matches the
# trailing dimensions of the kidney_2 label tensor inspected later in this notebook.
pmask2 = torch.stack(
    [torch.tensor(rle_decode(row.rle, (1041, 1511)))
    for row in submission_df.itertuples()]
).unsqueeze(0).bool()

In [3]:
import sys
# Make the parent directory importable so the project's util module can be loaded.
sys.path.append('..')
import util as otil
# Cached ground-truth labels for kidney_2.
# NOTE(review): torch.load unpickles the file; only load trusted caches.
true_mask = torch.load('/root/data/cache/train/kidney_2/labels.pt')


In [15]:
# Inspect the label volume's shape.
true_mask.shape

torch.Size([1, 2217, 1041, 1511])

In [4]:
# Score the decoded predictions against the labels with the project's DiceScore.
# pmask2 was already converted with .bool() when it was built, so the extra cast is a no-op.
otil.DiceScore()(pmask2.bool(), true_mask)

tensor(0.8748)

In [4]:
# Stack three masks along dim 0 for error analysis:
# channel 0 = true positives, 1 = false positives, 2 = false negatives.
det_mask = torch.cat([ #collect TP, FP, FN
    true_mask * pmask2, # true positive
    (~true_mask) * pmask2, # false positive
    true_mask * (~pmask2) # false negative
], dim=0)

In [6]:
# Voxel counts of true positives, false positives and false negatives.
det_mask[0].sum(), det_mask[1].sum(), det_mask[2].sum()

(tensor(11605204), tensor(639534), tensor(2683305))

In [11]:
# Check the label value range.
true_mask.max()

tensor(True)

In [5]:
# Check the dtype of the stacked TP/FP/FN masks.
det_mask.dtype

torch.bool

In [9]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ipywidgets as widgets

class Display:
    """Interactive slice viewer for a scan volume and/or a mask volume.

    A mask with a single leading channel is overlaid in blue; otherwise
    channels 0, 1 and 2 are overlaid in green, red and yellow.
    """

    def __init__(self, scan: torch.Tensor = None, mask: torch.Tensor = None):
        self.scan = scan
        self.mask = mask

    def _view_slice(self, i: int, slice_dim: int, ax: plt.Axes = None):
        """Draw slice ``i`` taken along ``slice_dim`` onto ``ax``."""
        ax.set_facecolor('black')
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_visible(False)

        # Full slices everywhere except a fixed index along slice_dim.
        index = [slice(None)] * 3
        index[slice_dim] = i
        index = tuple(index)

        if self.scan is not None:
            ax.imshow(self.scan[0][index], cmap='gray')
        if self.mask is None:
            return

        colors = ['blue'] if self.mask.shape[0] == 1 else ['green', 'red', 'yellow']
        for channel, color in enumerate(colors):
            ax.imshow(
                self.mask[channel][index],
                cmap=mcolors.ListedColormap(['none', color]),
                alpha=0.5,
            )

    @staticmethod
    def _view_slices(i: int, displays: list['Display'], slice_dim: int):
        """Draw slice ``i`` of every display side by side in one figure."""
        _, axs = plt.subplots(1, len(displays), figsize=(15, 15))
        # plt.subplots returns a bare Axes rather than an array for one column.
        axes = [axs] if len(displays) == 1 else axs
        for ax, display in zip(axes, displays):
            display._view_slice(i, slice_dim, ax)

    @staticmethod
    def view(displays: list['Display'], slice_dim: int = 0):
        """Show ``displays`` with a slider that selects the slice along ``slice_dim``."""
        first = displays[0]
        # The slider range comes from the first display's scan, or its mask
        # when it has no scan.
        volume = first.scan if first.scan is not None else first.mask
        slider_max = volume.shape[slice_dim + 1] - 1
        slider = widgets.IntSlider(min=0, max=slider_max, step=1, value=0)
        widgets.interact(
            Display._view_slices,
            i=slider,
            displays=widgets.fixed(displays),
            slice_dim=widgets.fixed(slice_dim),
        )

In [6]:
# Number of mask channels (decides which colour scheme Display uses).
det_mask.shape[0]

3

In [10]:
# Browse the TP / FP / FN overlay slice by slice.
Display.view([Display(mask=det_mask)])

interactive(children=(IntSlider(value=0, description='i', max=2216), Output()), _dom_classes=('widget-interact…

In [None]:
tr

In [None]:
1

In [17]:
# Sanity check: does glob happen to return the slice paths already sorted?
# NOTE(review): a True result here is specific to this run; glob's ordering is
# presumably file-system dependent, so the explicit sort should stay.
slices = glob(scan_folders[0] + "images/*.tif")
slices.sort()
slices == glob(scan_folders[0] + "images/*.tif")

True

In [20]:
# Confirm that list.sort() puts zero-padded file names in numeric order.
x = ["0001.tif", "0000.tif"]
x.sort()
x

['0000.tif', '0001.tif']