# 🛠 Install Libraries

In [1]:
# !pip install -q ../input/mmlablibsv2/einops-0.4.1-py3-none-any.whl
# !pip install -q ../input/mmlablibsv2/yapf-0.31.0-py2.py3-none-any.whl
# !pip install -q ../input/mmlablibsv2/addict-2.4.0-py3-none-any.whl
# !pip install -q ../input/mmlablibsv2/terminaltables-3.1.0-py3-none-any.whl
# !pip install -q ../input/mmlablibsv2/mmcv_full-1.3.17-cp37-cp37m-linux_x86_64.whl

In [2]:
!pip uninstall -qy transformers
!pip uninstall -qy tokenizers
!pip install -q ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ../input/segmentation-models-pytorch-030/timm-0.5.4-py3-none-any.whl
!pip install -q ../input/segmentation-models-pytorch-030/segmentation_models_pytorch-0.3.0.dev0-py3-none-any.whl
!pip uninstall -qy transformers
!pip uninstall -qy tokenizers
!pip install -q ../input/uwmgiseg/dependencies/tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
!pip install ../input/uwmgiseg/dependencies/huggingface_hub-0.5.1-py3-none-any.whl
!pip install -q ../input/uwmgiseg/dependencies/transformers-4.18.0-py3-none-any.whl
!pip install -q ../input/uwmgiseg/dependencies/monai-0.8.1-202202162213-py3-none-any.whl
!python ../input/uwmgiseg/setup.py develop

[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Processing /kaggle/input/uwmgiseg/dependencies/huggingface_hub-0.5.1-py3-none-any.whl
Installing collected packages: huggingface-hub
  Attempting uninstall: huggingface-hub
    Found 

# 📚 Import Libraries 

In [3]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd

# visualization
import cv2
import cupy as cp
import gc
import matplotlib.pyplot as plt

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
from torch.cuda.amp import autocast
import torch.nn.functional as F

import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

# Warnings
import warnings
warnings.filterwarnings("ignore")

from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True)

from segmentation_models_pytorch.base.modules import Activation
from segmentation_models_pytorch.base import modules as md
from segmentation_models_pytorch.decoders.deeplabv3.decoder import ASPP, SeparableConv2d

INFO: Pandarallel will run on 2 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [4]:
import sys
sys.path.append("../input/uwmgiseg")
import cvcore
from cvcore.config import get_cfg
from cvcore.modeling.meta_arch import build_model

# 🔨 Utility

In [5]:
def mask2rle(msk):
    '''
    Run-length encode a binary mask.

    msk: array-like, 1 - mask, 0 - background
    Returns the run-length encoding as a space-separated string of
    alternating "start length" values, where starts are 1-based positions
    in the flattened mask. A mask with no foreground yields an empty string.
    '''
    msk    = cp.array(msk)
    pixels = msk.flatten()
    pad    = cp.array([0])
    # Zero-pad both ends so runs touching the first/last pixel still produce a transition.
    pixels = cp.concatenate([pad, pixels, pad])
    # Positions where the value changes: even entries are run starts, odd entries run ends.
    runs   = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    # Convert each run end into a run length.
    runs[1::2] -= runs[::2]
    # Copy to the host once; iterating the device array would transfer every element separately.
    return ' '.join(str(x) for x in runs.tolist())


def read_image(name):
    """Load a scan and min-max rescale it to an 8-bit image.

    Args:
        name: path of the image file, read with ``cv2.IMREAD_ANYDEPTH``.

    Returns:
        ``np.uint8`` array with the same shape as the decoded image, whose
        intensities are stretched so the minimum maps to 0 and the maximum
        to 255. A constant image (max == min) is returned as all zeros.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    raw = cv2.imread(name, cv2.IMREAD_ANYDEPTH)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {name}")

    img = raw / 65535.0
    lowest, highest = img.min(), img.max()
    if highest == lowest:
        # Min-max scaling would be 0/0 here and produce NaNs before the uint8 cast.
        return np.zeros(img.shape, dtype=np.uint8)

    img = 255 * ((img - lowest) / (highest - lowest))
    return img.astype(np.uint8)

def df_preprocessing(df, globbed_file_list):
    """Derive case/day/slice metadata from ``id`` and attach the scan file paths.

    Args:
        df: frame with an ``id`` column shaped like ``<case>_<day>_<slice>``.
            The helper columns are assigned on this frame before merging, so
            the caller's frame gains them as well.
        globbed_file_list: non-empty list of scan paths; the dataset root is
            taken from the first entry and each path is matched to its row by
            everything before the last four ``_``-separated fields.

    Returns:
        A new frame with one row per matched id and the columns ``id``,
        ``f_path``, ``slice_h``, ``slice_w``, ``px_spacing_h``,
        ``px_spacing_w``, ``case_id_str``, ``case_id``, ``day_num_str``,
        ``day_num`` and ``slice_id`` (any other columns are dropped).
    """

    def _id_token(position):
        # maxsplit=2 keeps the trailing slice identifier ("slice_0001") in one piece
        return lambda id_: id_.split("_", 2)[position]

    def _id_number(position, prefix):
        return lambda id_: int(id_.split("_", 2)[position].replace(prefix, ""))

    # Case, day and slice identifiers parsed out of the id
    df["case_id_str"] = df["id"].apply(_id_token(0))
    df["case_id"] = df["id"].apply(_id_number(0, "case"))
    df["day_num_str"] = df["id"].apply(_id_token(1))
    df["day_num"] = df["id"].apply(_id_number(1, "day"))
    df["slice_id"] = df["id"].apply(_id_token(2))

    # Rebuild "<root>/<case>/<case>_<day>/scans/<slice>" as the join key
    root_dir = globbed_file_list[0].rsplit("/", 4)[0]
    case_dir = root_dir + "/" + df["case_id_str"]
    day_dir = case_dir + "/" + df["case_id_str"] + "_" + df["day_num_str"]
    df["_partial_ident"] = day_dir + "/scans/" + df["slice_id"]

    path_lookup = pd.DataFrame(
        {
            "_partial_ident": [path.rsplit("_", 4)[0] for path in globbed_file_list],
            "f_path": globbed_file_list,
        }
    )
    merged = df.merge(path_lookup, on="_partial_ident").drop(columns=["_partial_ident"])

    # The file name (minus its 4-character extension) ends in _<h>_<w>_<spacing_h>_<spacing_w>
    filename_fields = (
        ("slice_h", 1, int),
        ("slice_w", 2, int),
        ("px_spacing_h", 3, float),
        ("px_spacing_w", 4, float),
    )
    for column, position, caster in filename_fields:
        merged[column] = merged["f_path"].apply(
            lambda path, position=position, caster=caster: caster(path[:-4].rsplit("_", 4)[position])
        )

    # Fixed output ordering; class/segmentation style columns are intentionally left out
    wanted = (
        "id",
        "f_path",
        "slice_h",
        "slice_w",
        "px_spacing_h",
        "px_spacing_w",
        "case_id_str",
        "case_id",
        "day_num_str",
        "day_num",
        "slice_id",
    )
    return merged[[name for name in wanted if name in merged.columns]]


def get_nearby_slices(id_, case_id_str, day_num_str, case_length, num_slices=3, num_strides=1):
    """Return the ids of the slices surrounding ``id_`` within one case/day.

    The window holds ``num_slices // 2`` neighbours on each side of the centre
    slice, spaced ``num_strides`` apart. Indices falling outside the scan are
    clamped: the lower bound is 1 or 2 and the upper bound is ``case_length``
    or ``case_length - 1``, each chosen to have the same parity as the centre
    slice index.

    Args:
        id_: slice id whose last ``_``-separated field is the slice number.
        case_id_str: case part of the generated ids.
        day_num_str: day part of the generated ids.
        case_length: number of slices in this case/day.
        num_slices: window size.
        num_strides: step between neighbouring window entries.

    Returns:
        List of ids formatted as ``<case>_<day>_slice_<NNNN>``.
    """
    center = int(id_.split("_")[-1])
    reach = num_slices // 2 * num_strides
    window = np.arange(center - reach, center + reach + 1, num_strides)

    center_is_even = center % 2 == 0
    length_is_even = case_length % 2 == 0
    lowest = 2 if center_is_even else 1
    # Largest index not exceeding case_length that shares the centre's parity
    highest = case_length if center_is_even == length_is_even else case_length - 1

    neighbours = np.clip(window, lowest, highest)
    return [f"{case_id_str}_{day_num_str}_slice_{neighbour:04d}" for neighbour in neighbours]

## Test

In [6]:
# Root directory of the competition dataset (absolute Kaggle input path).
DATA_DIR = "/kaggle/input/uw-madison-gi-tract-image-segmentation/"
# "test" and "train" sub-directories of the dataset root.
TEST_DIR = os.path.join(DATA_DIR, "test")
TRAIN_DIR = os.path.join(DATA_DIR, "train")
# Path of sample_submission.csv inside the dataset root.
SUB_CSV = os.path.join(DATA_DIR, "sample_submission.csv")

In [7]:
sub_df = pd.read_csv(SUB_CSV)

# An empty sample submission switches the notebook into debug mode, where
# inference runs on training scans instead (presumably the situation outside
# the scoring environment -- TODO confirm).
debug = not len(sub_df)

if debug:
    # Infer on train cases
    sub_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
    sub_df = sub_df.drop(columns=['class','segmentation']).drop_duplicates()
    # Use the shared TRAIN_DIR constant rather than repeating the dataset path.
    paths = glob(os.path.join(TRAIN_DIR, "**", "*png"), recursive=True)
    sub_df = df_preprocessing(sub_df, paths)
    # Keep only the first case to make the debug run short.
    cases = sub_df["case_id_str"].unique()[:1]
    sub_df = sub_df[sub_df["case_id_str"].isin(cases)].reset_index(drop=True)
else:
    sub_df = sub_df.drop(columns=['class','predicted']).drop_duplicates()
    # Use the shared TEST_DIR constant rather than repeating the dataset path.
    paths = glob(os.path.join(TEST_DIR, "**", "*png"), recursive=True)
    sub_df = df_preprocessing(sub_df, paths)
    sub_df = sub_df.reset_index(drop=True)

## TH Model

In [8]:
# sys.path.append("../input/uwsegv2")
# from segmentation_models_pytorch.base import ClassificationHead as TClassificationHead
# from giseg.cvcore.modeling.backbone import TimmUniversalEncoder as TTimmUniversalEncoder
# from giseg.cvcore.modeling.backbone import mit_PLD_b4
# from giseg.cvcore.modeling.heads import UnetDecoder as TUnetDecoder
# from giseg.cvcore.modeling.heads import ASPPHead
# from giseg.cvcore.modeling.heads import SegmentationHead as TSegmentationHead

# class TimmUNetASPP(nn.Module):
#     def __init__(self, arch, img_size, num_slices, num_strides):
#         super(TimmUNetASPP, self).__init__()
#         num_slices = 31
#         self.encoder = TTimmUniversalEncoder(
#             arch,
#             pretrained=False,
#             in_channels=num_slices,
#             drop_path_rate=None,
#             img_size=img_size[0],
#         )
#         encoder_channels = list(self.encoder.out_channels)
#         if len(encoder_channels) == 5 or  len(encoder_channels) == 6: 
#             common_stride = 2
#         else:
#             common_stride = 1
#         decoder_channels = [2048, 1024, 512, 256]
#         n_blocks = len(decoder_channels)
#         num_classes = 3

#         self.decoder = TUnetDecoder(
#             encoder_channels,
#             decoder_channels,
#             n_blocks=n_blocks,
#             center=False,
#             attention_type='scse',
#             norm="BN",
#             act="relu",
#         )
#         self.segmentation_head = TSegmentationHead(
#             in_channels=decoder_channels[-1],
#             out_channels=num_classes,
#             upsampling=common_stride,
#         )
#         self.aux_decoder = ASPPHead(
#                             encoder_channels = encoder_channels,
#         )
        
#         self.segmentation_head_aux = TSegmentationHead(
#             in_channels=decoder_channels[-1],
#             out_channels=num_classes,
#             upsampling=4,
#         )
#         self.classification_head = TClassificationHead(
#             in_channels=self.encoder.out_channels[-1], classes=num_classes
#         )
#         self._add_hausdorff = False

#     @autocast()
#     def forward(self, images, gt_masks=None, image_sizes=None):
#         features = self.encoder(images)
#         decoder_output = self.decoder(*features)
#         masks = self.segmentation_head(decoder_output)
#         ## No need to run this auxiliary branch at inference
# #         aux_out = self.aux_decoder(*features) 
# #         aux_mask = self.segmentation_head_aux(aux_out)
#         if self.training:
#             losses = seg_criterion(masks, gt_masks, self._add_hausdorff)
#             losses.update(seg_criterion(aux_mask,gt_masks, self._add_hausdorff, weights=[1.,1.], aux=True))
#             gt_classes = (gt_masks.sum((2, 3)) > 0).float()
#             cls_logits = self.classification_head(features[-1])
#             cls_loss = cls_criterion(cls_logits, gt_classes)
#             losses.update({"bce_cls": cls_loss})
#             return losses
#         else:
#             # masks = masks.view(masks.shape[0], -1, 3, masks.shape[2], masks.shape[3])
#             # masks = masks[:, masks.shape[1] // 2, ...]
# #             return (masks + aux_mask)/2
#             return torch.sigmoid(masks)

# class TimmssFormerASPP(nn.Module):
#     def __init__(self, img_size, num_slices, num_strides):
#         super(TimmssFormerASPP, self).__init__()

#         if num_strides == 1:
#             num_slices = num_slices // num_strides + 1
#         else:
#             num_slices = num_slices
#         num_classes = 3

#         self.encoderdecoder = mit_PLD_b4(class_num=num_classes)
#         self.segmentation_head = TSegmentationHead(
#             in_channels=128,
#             out_channels=num_classes,
#             upsampling=4, 
#         )
#         self.classification_head = TClassificationHead(
#             in_channels=512, classes=num_classes
#         )
#         self._add_hausdorff = False

#     @autocast()
#     def forward(self, images, gt_masks=None, image_sizes=None):
#         features, masks = self.encoderdecoder(images)
#         if self.training:
#             losses = seg_criterion(masks, gt_masks, self._add_hausdorff)
#             gt_classes = (gt_masks.sum((2, 3)) > 0).float()
#             cls_logits = self.classification_head(features[-1])
#             cls_loss = cls_criterion(cls_logits, gt_classes)
#             losses.update({"bce_cls": cls_loss})
#             return losses
#         else:
#             # masks = masks.view(masks.shape[0], -1, 3, masks.shape[2], masks.shape[3])
#             # masks = masks[:, masks.shape[1] // 2, ...]
#             # return masks 
#             return torch.sigmoid(masks)

In [9]:
class TimmUniversalEncoder(nn.Module):
    """Feature-pyramid encoder backed by a timm model created with ``features_only=True``.

    ``forward`` returns the input tensor followed by the timm feature maps,
    and ``out_channels`` lists the matching channel counts (input channels
    first).

    For model names starting with ``convnext`` the module at
    ``stages_3.downsample[1]`` is replaced by a 1x1, stride-1 convolution whose
    weights are the spatial mean of the original kernel and whose bias is the
    original bias.

    ``img_size`` is accepted but not forwarded to ``timm.create_model``.
    """

    def __init__(
        self,
        name,
        pretrained=False,
        in_channels=3,
        drop_path_rate=0.0,
        depth=5,
        output_stride=32,
        img_size=224,
    ):
        super().__init__()
        self.model = timm.create_model(
            name,
            in_chans=in_channels,
            features_only=True,
            pretrained=pretrained,
            drop_path_rate=drop_path_rate,
        )

        if name.startswith('convnext'):
            downsample = self.model.stages_3.downsample
            strided_conv = downsample[1]
            unstrided_conv = nn.Conv2d(
                strided_conv.in_channels, strided_conv.out_channels, kernel_size=1, stride=1
            )
            downsample[1] = unstrided_conv
            # Collapse the original kernel to 1x1 by averaging over its spatial extent
            unstrided_conv.weight.data = strided_conv.weight.mean(dim=(2, 3), keepdim=True)
            unstrided_conv.bias.data = strided_conv.bias

        self._out_channels = [in_channels] + self.model.feature_info.channels()
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x):
        """Return ``[x, feature_1, ..., feature_k]``."""
        return [x] + self.model(x)

    @property
    def out_channels(self):
        """Channel counts of the tensors returned by ``forward``."""
        return self._out_channels

    @property
    def output_stride(self):
        """The configured output stride, capped at ``2 ** depth``."""
        return min(self._output_stride, 2 ** self._depth)

In [10]:
class SegmentationHead(nn.Sequential):
    """Convolution, then bilinear upsampling (when ``upsampling > 1``), then ``Activation(activation)``.

    The convolution uses ``kernel_size // 2`` padding. For ``upsampling <= 1``
    the upsampling stage is an identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        final_activation = Activation(activation)
        super().__init__(projection, resize, final_activation)

class SegmentationHeadDouble(nn.Sequential):
    """Two consecutive conv -> upsample -> activation stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Each stage upsamples bilinearly by ``upsampling_scale``
    when it is greater than 1, so the scale is applied twice.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, activation_func=None, upsampling_scale=1):
        layers = []
        # Stages are built in order so the Sequential indices (0..5) stay the same
        for stage_in in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
            )
            if upsampling_scale > 1:
                layers.append(nn.UpsamplingBilinear2d(scale_factor=upsampling_scale))
            else:
                layers.append(nn.Identity())
            layers.append(Activation(activation_func))
        super().__init__(*layers)


class ClassificationHead(nn.Sequential):
    """Global pooling -> flatten -> dropout -> linear -> ``Activation(activation)``.

    Args:
        in_channels: input features of the linear layer.
        classes: output features of the linear layer.
        pooling: ``"avg"`` or ``"max"`` global pooling.
        dropout: in-place dropout probability; a falsy value disables dropout.
        activation: forwarded to ``Activation``.

    Raises:
        ValueError: if ``pooling`` is neither ``"max"`` nor ``"avg"``.
    """

    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))

        if pooling == "avg":
            pool_layer = nn.AdaptiveAvgPool2d(1)
        else:
            pool_layer = nn.AdaptiveMaxPool2d(1)
        flatten_layer = nn.Flatten()
        if dropout:
            dropout_layer = nn.Dropout(p=dropout, inplace=True)
        else:
            dropout_layer = nn.Identity()
        linear_layer = nn.Linear(in_channels, classes, bias=True)
        activation_layer = Activation(activation)
        super().__init__(pool_layer, flatten_layer, dropout_layer, linear_layer, activation_layer)

In [11]:
class DeepLabV3PlusDecoder(nn.Module):
    """DeepLabV3+ style decoder fusing an ASPP branch with a high-resolution skip.

    The last encoder feature goes through ASPP and is upsampled (x2 for
    ``output_stride == 8``, otherwise x4); the fourth-from-last feature is
    projected to 48 channels; both are concatenated and fused by a separable
    convolution.

    Args:
        encoder_channels: channel counts of the encoder features.
        out_channels: channels of the ASPP branch and of the fused output.
        atrous_rates: dilation rates passed to ``ASPP``.
        output_stride: 8 or 16.

    Raises:
        ValueError: if ``output_stride`` is not 8 or 16.
    """

    def __init__(
        self,
        encoder_channels,
        out_channels=256,
        atrous_rates=(12, 24, 36),
        output_stride=16,
    ):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))

        self.out_channels = out_channels
        self.output_stride = output_stride

        # Context branch on the deepest feature map
        aspp_layers = [
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.aspp = nn.Sequential(*aspp_layers)

        if output_stride == 8:
            upsample_factor = 2
        else:
            upsample_factor = 4
        self.up = nn.UpsamplingBilinear2d(scale_factor=upsample_factor)

        # High-resolution skip branch; 48 channels as proposed by the paper's authors
        skip_in = encoder_channels[-4]
        skip_out = 48
        self.block1 = nn.Sequential(
            nn.Conv2d(skip_in, skip_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(skip_out),
            nn.ReLU(),
        )

        # Fusion of the concatenated branches
        fusion_conv = SeparableConv2d(
            skip_out + out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.block2 = nn.Sequential(fusion_conv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, *features):
        """Fuse ``features[-1]`` (via ASPP + upsampling) with ``features[-4]``."""
        context = self.up(self.aspp(features[-1]))
        skip = self.block1(features[-4])
        return self.block2(torch.cat([context, skip], dim=1))

    
class DeepLabV3PlusDecoderFix(DeepLabV3PlusDecoder):
    """DeepLabV3+ decoder with an extra 1x1 conv block and a second x2 upsampling on the ASPP branch.

    Built with the parent's default settings; after the parent's upsampling
    the ASPP output passes through ``block3`` and ``up_2`` before being
    concatenated with the projected ``features[-4]``.
    """

    def __init__(self, encoder_channels):
        super().__init__(encoder_channels,)
        aspp_channels = 256
        refine_layers = [
            nn.Conv2d(aspp_channels, aspp_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(aspp_channels),
            nn.ReLU(),
        ]
        self.block3 = nn.Sequential(*refine_layers)
        self.up_2 = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, *features):
        """Fuse the twice-upsampled ASPP branch with the projected ``features[-4]``."""
        context = self.aspp(features[-1])
        # upsample -> 1x1 refinement -> upsample again
        for stage in (self.up, self.block3, self.up_2):
            context = stage(context)
        skip = self.block1(features[-4])
        return self.block2(torch.cat([context, skip], dim=1))
    

class DecoderBlock(nn.Module):
    """Decoder step: optional x2 upsampling, optional skip concatenation, two conv blocks.

    ``attention1`` runs on the concatenated tensor and only when a skip is
    given; ``attention2`` always runs on the output of the second convolution.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = md.Conv2dReLU(merged_channels, out_channels, **conv_options)
        self.attention1 = md.Attention(attention_type, in_channels=merged_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_options)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None, scale=True):
        """Process ``x``; ``scale`` toggles the nearest-neighbour x2 upsampling."""
        if scale:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 ``Conv2dReLU`` blocks: ``in_channels -> out_channels -> out_channels``."""

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(
            md.Conv2dReLU(in_channels, out_channels, **conv_options),
            md.Conv2dReLU(out_channels, out_channels, **conv_options),
        )
        
        
class UnetPlusPlusDecoder(nn.Module):
    """Densely connected (UNet++ style) decoder built from ``DecoderBlock`` modules.

    Blocks are stored in ``self.blocks`` under keys ``x_{depth_idx}_{layer_idx}``.
    ``forward`` takes the encoder features (first one is discarded), fills a
    dictionary of intermediate outputs under the same keys and returns the
    output of the final block ``x_0_{depth}``.

    Args:
        encoder_channels: channel counts of the encoder features, in the order
            the encoder returns them.
        decoder_channels: output channels of the ``depth_idx == 0`` blocks.
        n_blocks: must equal ``len(decoder_channels)``; only used for that check.
        use_batchnorm: forwarded to every ``DecoderBlock``.
        attention_type: forwarded to every ``DecoderBlock``.
        center: if true, ``self.center`` is a ``CenterBlock``, otherwise an
            identity. NOTE(review): ``self.center`` is not called in ``forward``.

    Raises:
        ValueError: if ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels=(256, 128, 64, 32, 16),
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        # the trailing 0 means the last entry carries no skip channels
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    # depth 0: channels follow the in/out channel lists; the skip input
                    # is sized for (layer_idx + 1) tensors of skip_channels[layer_idx]
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    # deeper nodes: output keeps this layer's skip width, the input comes
                    # from the previous layer's skip width, and the skip input shrinks by
                    # one tensor per depth step
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        # final block: no skip connection, produces the last decoder_channels entry
        blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features):
        """Run the dense decoder on the encoder features and return the ``x_0_{depth}`` output."""
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    # first sweep: each x_i_i block combines encoder feature i with feature i + 1
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1])
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    # skip input: the already computed nodes x_{depth_idx+1..dense_l_i}_{dense_l_i}
                    # plus encoder feature dense_l_i + 1, concatenated along channels
                    cat_features = [dense_x[f"x_{idx}_{dense_l_i}"] for idx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    # main input: the node with the same depth index from the previous column
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"](
                        dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features
                    )
        # final block runs without a skip tensor
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth-1}"])
        return dense_x[f"x_{0}_{self.depth}"]


class UnetPlusPlusDecoderFix(nn.Module):
    """Densely connected (UNet++ style) decoder whose first block does not upsample.

    Blocks are stored in ``self.blocks`` under keys ``x_{depth_idx}_{layer_idx}``.
    In ``forward`` the block ``x_0_0`` is called with ``scale=False``, so the
    first (deepest) feature is concatenated with the next one without being
    upsampled; every other block upsamples as usual.
    NOTE(review): presumably meant for encoders whose two deepest feature maps
    share the same spatial size -- confirm against the encoder in use.

    Args:
        encoder_channels: channel counts of the encoder features, in the order
            the encoder returns them.
        decoder_channels: output channels of the ``depth_idx == 0`` blocks.
        n_blocks: must equal ``len(decoder_channels)``; only used for that check.
        use_batchnorm: forwarded to every ``DecoderBlock``.
        attention_type: forwarded to every ``DecoderBlock``.
        center: if true, ``self.center`` is a ``CenterBlock``, otherwise an
            identity. NOTE(review): ``self.center`` is not called in ``forward``.

    Raises:
        ValueError: if ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels=(256, 128, 64, 32),
        n_blocks=4,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        # the trailing 0 means the last entry carries no skip channels
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    # depth 0: channels follow the in/out channel lists; the skip input
                    # is sized for (layer_idx + 1) tensors of skip_channels[layer_idx]
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    # deeper nodes: output keeps this layer's skip width, the input comes
                    # from the previous layer's skip width, and the skip input shrinks by
                    # one tensor per depth step
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        # final block: no skip connection, produces the last decoder_channels entry
        blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features):
        """Run the dense decoder on the encoder features and return the ``x_0_{depth}`` output."""
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    # first sweep: each x_i_i block combines encoder feature i with feature i + 1;
                    # the third argument is DecoderBlock's `scale`, False only for depth_idx 0
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1], depth_idx >= 1)
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    # skip input: the already computed nodes x_{depth_idx+1..dense_l_i}_{dense_l_i}
                    # plus encoder feature dense_l_i + 1, concatenated along channels
                    cat_features = [dense_x[f"x_{idx}_{dense_l_i}"] for idx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    # main input: the node with the same depth index from the previous column
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"](
                        dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features
                    )
        # final block runs without a skip tensor
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth-1}"])
        return dense_x[f"x_{0}_{self.depth}"]

In [12]:
class BaselineSegTimm(nn.Module):
    """Segmentation model: timm encoder + decoder chosen by ``CFG.ARCH[index]`` + segmentation head.

    Supported architectures are ``"DeepLabV3Plus"``, ``"DeepLabV3PlusFix"``,
    ``"UnetPlusPlus"`` and ``"UnetPlusPlusFix"``. The decoder output width and
    the head's upsampling factor are derived from a dummy forward pass made
    during construction.

    Args:
        index: position of the model in the ``CFG`` lists (``ARCH``,
            ``ENCODER``, ``img_size``).
        num_stride: number of input channels given to the encoder.

    Raises:
        ValueError: if ``CFG.ARCH[index]`` is not one of the supported
            architectures. Previously such an entry left ``self.decoder``
            undefined and failed later with an unrelated attribute error.
    """

    def __init__(self, index, num_stride):
        super(BaselineSegTimm, self).__init__()

        arch = CFG.ARCH[index]
        encoder_name = CFG.ENCODER[index]
        decoder_classes = {
            "DeepLabV3Plus": DeepLabV3PlusDecoder,
            "DeepLabV3PlusFix": DeepLabV3PlusDecoderFix,
            "UnetPlusPlus": UnetPlusPlusDecoder,
            "UnetPlusPlusFix": UnetPlusPlusDecoderFix,
        }
        # Fail fast, before the (expensive) encoder is created.
        if arch not in decoder_classes:
            raise ValueError(
                "Unsupported architecture {!r} at CFG.ARCH[{}]; expected one of {}.".format(
                    arch, index, sorted(decoder_classes)
                )
            )

        self.encoder = TimmUniversalEncoder(
            encoder_name,
            in_channels=num_stride,
            img_size=CFG.img_size[index][0],
        )

        with torch.no_grad():
            # Dummy pass to measure how much the first encoder stage reduces the height.
            # NOTE(review): batch size 2 is presumably needed by batch-norm layers in
            # train mode -- confirm before changing it.
            dummy_inputs = torch.randn(2, num_stride, *CFG.img_size[index])
            out = self.encoder(dummy_inputs)
            common_stride = CFG.img_size[index][0] // out[1].shape[2]
        encoder_channels = self.encoder.out_channels
        num_classes = 3

        self.decoder = decoder_classes[arch](encoder_channels=encoder_channels)

        with torch.no_grad():
            # Second dummy pass: the decoder's channel count sizes the head.
            out = self.decoder(*out)
        head_in_channels = out.shape[1]

        if arch == "UnetPlusPlus":
            self.segmentation_head = SegmentationHead(
                in_channels=head_in_channels,
                out_channels=num_classes,
                activation=None,
                kernel_size=3,
                upsampling=1,
            )
        elif arch == "UnetPlusPlusFix":
            self.segmentation_head = SegmentationHead(
                in_channels=head_in_channels,
                out_channels=num_classes,
                activation=None,
                kernel_size=3,
                upsampling=common_stride // 2,
            )
        elif encoder_name.startswith('tf_efficientnet') or encoder_name.startswith('ecaresnet'):
            # DeepLab variants on these encoders use the two-stage head.
            self.segmentation_head = SegmentationHeadDouble(
                in_channels=head_in_channels,
                out_channels=num_classes,
                activation_func=None,
                kernel_size=3,
                upsampling_scale=common_stride,
            )
        else:
            self.segmentation_head = SegmentationHead(
                in_channels=head_in_channels,
                out_channels=num_classes,
                activation=None,
                kernel_size=3,
                upsampling=common_stride,
            )

    @autocast()
    def forward(self, images, gt_masks=None):
        """Return per-class sigmoid probabilities for ``images``; ``gt_masks`` is unused."""
        features = self.encoder(images)
        decoder_output = self.decoder(*features)
        masks = self.segmentation_head(decoder_output)
        return torch.sigmoid(masks)

In [13]:
class CFG:
    """Inference-time configuration of the model ensemble.

    ARCH, ENCODER, WEIGHTS, img_size, slices and strides are parallel lists:
    position ``i`` in each of them describes ensemble member ``i``.

    A further member is currently disabled (it used to sit at position 2):
    DeepLabV3PlusFix / ecaresnet269d,
    '../input/uwmgiweights/deeplabv3plus_ecaresnet269d_17_608_long_fold-1_e98.pth',
    img_size [640, 640], 17 slices, stride 3.
    """

    # Decoder architecture name per ensemble member.
    ARCH = [
        'DeepLabV3Plus',
        'DeepLabV3PlusFix',
        'Unet',
        'UnetPlusPlusFix',
        'UnetPlusPlus',
        'Unet',
        'UnetPlusPlus',
    ]

    # Encoder (backbone) name per ensemble member.
    ENCODER = [
        'convnext_xlarge_in22ft1k',
        'tf_efficientnetv2_l',
        'convnext_base',
        'convnext_xlarge_in22ft1k',
        'tf_efficientnetv2_l',
        'convnext_base',
        'ecaresnet269d',
    ]

    # Checkpoint path per ensemble member.
    WEIGHTS = [
        '../input/newgiseg/deeplabv3plus_convnext_xlarge_17_608_fold-1_e28.pth',
        '../input/uwmgiweights/deeplabv3plus_v2l_17_608_fold-1_e28.pth',
        '../input/uwmgiseg/weights/convnext_base_acs_unet_epoch30.pth',
        '../input/uwmgiweights/unetplus_convnext_xlarge_fold-1_e28.pth',
        '../input/uwmgiweights/unetplus_v2l_fold-1_e28.pth',
        '../input/uwmgiseg/weights/convnext_base_unet_epoch10.pth',
        '../input/uwmgiweights/unetplus_ecaresnet269d_fold-1_e98.pth',
    ]

    NUM_CLASSES = 3

    # [height, width] per ensemble member.
    img_size = [
        [640, 640],
        [640, 640],
        [512, 512],
        [640, 640],
        [640, 640],
        [512, 512],
        [640, 640],
    ]

    # Number of slices and slice stride per ensemble member.
    slices = [17, 17, 33, 9, 9, 5, 9]
    strides = [2, 2, 1, 2, 3, 1, 2]

    # One binarisation threshold per output channel.
    # Previously tried: [0.3, 0.3, 0.4].
    THRESHOLDS = [0.3, 0.3, 0.375]

In [14]:
class GISegDataset(Dataset):
    """Test-time dataset yielding several slice stacks around one centre slice.

    For every row of ``df`` the neighbouring slice ids are looked up with
    ``get_nearby_slices`` for four (num_slices, num_strides) settings. Each
    distinct image is read once, the images of a setting are stacked along the
    last axis, resized and scaled by 1/255.
    """

    def __init__(self, df, target_size):
        """
        Args:
            df: frame with at least the columns "id", "f_path", "case_id_str"
                and "day_num_str". NOTE: a "case_day_str" column is added to
                it in place.
            target_size: two-element (height, width); unpacked but not used.
        """
        super().__init__()
        resized_h, resized_w = target_size
        df["case_day_str"] = df["case_id_str"] + "_" + df["day_num_str"]
        self.df = df
        # Number of rows per case/day, passed on to get_nearby_slices.
        self.cases_length = df["case_day_str"].value_counts().to_dict()
        # slice id -> image file path
        self.images_dict = dict(zip(df["id"].values, df["f_path"].values))
        # One resize pipeline per square side length (608 is not used by __getitem__).
        self.aug512, self.aug608, self.aug640 = (
            A.Compose([A.Resize(side, side), ToTensorV2()]) for side in (512, 608, 640)
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (img172, img331, img92, img93, img51, center_id, h, w).

        The imgXY names encode the setting: X slices taken with stride Y
        (e.g. img172 = 17 slices, stride 2). ``h`` and ``w`` are the shape of
        the first image of the 17/2 stack as read from disk.
        """
        row = self.df.iloc[idx]
        center_id = row["id"]
        case_id = row["case_id_str"]
        day = row["day_num_str"]
        case_length = self.cases_length.get(row["case_day_str"])

        def nearby(num_slices, num_strides):
            # Same lookup for every setting; only the two numbers vary.
            return get_nearby_slices(center_id, case_id, day, case_length,
                                     num_slices=num_slices, num_strides=num_strides)

        ids172 = nearby(17, 2)
        ids331 = nearby(33, 1)
        ids92 = nearby(9, 2)
        ids93 = nearby(9, 3)

        # Read every distinct slice exactly once.
        imgs_dict = {
            id_: read_image(self.images_dict.get(id_))
            for id_ in set(ids172 + ids331 + ids92 + ids93)
        }
        h, w = imgs_dict.get(ids172[0]).shape

        def stack_and_resize(slice_ids, aug):
            # Stack along the last axis, then resize / convert and scale by 1/255.
            volume = np.stack([imgs_dict.get(id_) for id_ in slice_ids], axis=-1)
            return aug(image=volume)["image"].float() / 255.

        img172 = stack_and_resize(ids172, self.aug640)
        img331 = stack_and_resize(ids331, self.aug512)
        # Entries 14..18 along the first axis of the 33-slice stack
        # (presumably the five centre slices - confirm against get_nearby_slices).
        img51 = img331[14:19]

        # The 9-slice/stride-2 stack is picked out of the already resized
        # 17-slice/stride-2 stack rather than being resized again.
        img92 = img172[[ids172.index(i) for i in ids92]]

        img93 = stack_and_resize(ids93, self.aug640)

        return img172, img331, img92, img93, img51, center_id, h, w

# 🔭 Inference

In [15]:
# Build every ensemble member listed in CFG.WEIGHTS and load its checkpoint
def load_weight(wt, model, cls=False):
    """Load the checkpoint at ``wt`` into ``model``; return it in eval mode on the GPU.

    Args:
        wt: path of a checkpoint containing the keys "model" (state dict)
            and "best_metric".
        model: object the state dict is loaded into.
        cls: when False (default), every state-dict entry whose name contains
            "classification_head" is dropped before loading.

    Returns:
        The result of ``model.cuda()`` after ``model.eval()`` was called.
    """
    checkpoint = torch.load(wt, "cpu")
    state_dict = checkpoint.pop("model")
    if cls:
        state_dict = dict(state_dict)
    else:
        state_dict = {
            name: tensor
            for name, tensor in state_dict.items()
            if "classification_head" not in name
        }
    model.load_state_dict(state_dict)
    print(wt, checkpoint["best_metric"])
    # Release the CPU copy of the checkpoint before the model goes to the GPU.
    del checkpoint
    gc.collect()
    model.eval()
    return model.cuda()

# Ensemble members in the same order as CFG.WEIGHTS; inference() relies on
# this order to decide which input stack each model receives.
all_models = []

for i, weight_path in enumerate(CFG.WEIGHTS):
    # The model family is selected by a substring of the checkpoint path.
    if "resnest200_unet_aspp" in weight_path:
        model = TimmUNetASPP(CFG.ENCODER[i], CFG.img_size[i], CFG.slices[i], CFG.strides[i])
        model = load_weight(weight_path, model, cls=True)
    elif "ssformer" in weight_path:
        model = TimmssFormerASPP(CFG.img_size[i], CFG.slices[i], CFG.strides[i])
        model = load_weight(weight_path, model, cls=True)
    elif "convnext_base_unet" in weight_path:
        cfg = get_cfg()
        cfg.merge_from_file("../input/uwmgiseg/configs/cb-unet.yaml")
        cfg.MODEL.BACKBONE.PRETRAINED = False
        model = build_model(cfg)
        model = load_weight(weight_path, model)
    elif "convnext_base_acs_unet" in weight_path:
        cfg = get_cfg()
        cfg.merge_from_file("../input/uwmgiseg/configs/cb-acs-unet.yaml")
        cfg.MODEL.BACKBONE.PRETRAINED = False
        model = build_model(cfg)
        model = load_weight(weight_path, model)
    elif "segformer_b5" in weight_path:
        # Start from a fresh config like the other yaml-driven branches do.
        # Without this the branch would either hit an undefined `cfg` or merge
        # on top of whatever config a previous iteration left behind.
        cfg = get_cfg()
        cfg.merge_from_file("../input/uwmgiseg/configs/segb5.yaml")
        cfg.MODEL.BACKBONE.ARCH = "../input/uwmgiseg/segformer-b5-finetuned-ade-640-640/"
        cfg.MODEL.BACKBONE.PRETRAINED = False
        model = build_model(cfg)
        model = load_weight(weight_path, model)
    else:
        # Everything else is built from CFG.ARCH[i] / CFG.ENCODER[i].
        model = BaselineSegTimm(i, CFG.slices[i])
        model = load_weight(weight_path, model)
    all_models.append(model)
    # Drop the loop-local reference; the model stays alive inside all_models.
    del model

../input/newgiseg/deeplabv3plus_convnext_xlarge_17_608_fold-1_e28.pth 0
../input/uwmgiweights/deeplabv3plus_v2l_17_608_fold-1_e28.pth 0
../input/uwmgiseg/weights/convnext_base_acs_unet_epoch30.pth 0.13078680634498596
../input/uwmgiweights/unetplus_convnext_xlarge_fold-1_e28.pth 0
../input/uwmgiweights/unetplus_v2l_fold-1_e28.pth 0
../input/uwmgiseg/weights/convnext_base_unet_epoch10.pth 0.13846728205680847
../input/uwmgiweights/unetplus_ecaresnet269d_fold-1_e98.pth 0


In [16]:
def masks2rles(masks, ids):
    """Threshold per-class score maps and run-length encode them.

    Args:
        masks: sequence of tensors indexed as ``mask[class_idx]`` -> 2-D score
            map; the first three channels are used. The tensors are NOT
            modified (thresholding happens on a separate tensor).
        ids: sequence of sample ids, ``ids[i]`` belonging to ``masks[i]``.

    Returns:
        (pred_strings, pred_ids, pred_classes): three flat lists with one
        entry per (sample, class), classes in the order
        large_bowel, small_bowel, stomach.
    """
    class_names = ['large_bowel', 'small_bowel', 'stomach']
    pred_strings = []; pred_ids = []; pred_classes = []

    for idx, mask in enumerate(masks):
        # Channel c is binarised with its own threshold CFG.THRESHOLDS[c].
        # Building a new tensor (instead of writing into `mask`) leaves the
        # caller's score maps untouched and needs a single device->host copy.
        binary = torch.stack([
            mask[class_idx] >= CFG.THRESHOLDS[class_idx]
            for class_idx in range(len(class_names))
        ]).to(torch.uint8).cpu().numpy()

        for class_idx, class_name in enumerate(class_names):
            pred_strings.append(mask2rle(binary[class_idx]))
            pred_ids.append(ids[idx])
            pred_classes.append(class_name)

    return pred_strings, pred_ids, pred_classes

# Module-level lists; the inference code in this notebook section never appends
# to them (their only other mention is a commented-out `del IMAGES, MASKS`).
IMAGES = []
MASKS = []

@torch.no_grad()
def inference(models, test_loader):
    """Run the ensemble over ``test_loader`` and RLE-encode the averaged masks.

    Args:
        models: sequence of loaded models. The POSITION of a model in this
            sequence decides which input stack it receives (see
            ``input_for_model`` below), so the order must match CFG.WEIGHTS.
        test_loader: loader yielding
            (imgs172, imgs331, imgs92, imgs93, imgs51, ids, heights, widths).

    Returns:
        (pred_strings, pred_ids, pred_classes) collected from ``masks2rles``
        over all batches.

    Raises:
        ValueError: if more models are passed than ``input_for_model`` has
            entries. Previously such a model silently reused the output of the
            model before it and was averaged in a second time.
    """
    # Ensemble position -> key of the batch tensor that model consumes.
    input_for_model = ("172", "172", "331", "92", "93", "51", "92")
    if len(models) > len(input_for_model):
        raise ValueError(
            f"inference() knows the input for {len(input_for_model)} models "
            f"but received {len(models)}"
        )

    pred_strings = []; pred_ids = []; pred_classes = []
    for imgs172, imgs331, imgs92, imgs93, imgs51, ids, heights, widths in tqdm(test_loader):
        batch = {
            "172": imgs172.half().cuda(non_blocking=True),
            "331": imgs331.half().cuda(non_blocking=True),
            "92": imgs92.half().cuda(non_blocking=True),
            "93": imgs93.half().cuda(non_blocking=True),
            "51": imgs51.half().cuda(non_blocking=True),
        }

        # Running ensemble average, one entry per sample of the batch.
        masks = [0.0] * len(ids)
        with autocast():
            for input_key, model in zip(input_for_model, models):
                out = model(batch[input_key])
                # Each sample is resized on its own because the target
                # height/width is given per sample.
                for idx, (_, mask, height, width) in enumerate(zip(ids, out, heights, widths)):
                    mask = F.interpolate(mask.unsqueeze(0), size=(height.item(), width.item()),
                                         mode='bilinear', align_corners=True)[0]
                    masks[idx] += mask / len(models)

        result = masks2rles(masks, ids)
        pred_strings.extend(result[0])
        pred_ids.extend(result[1])
        pred_classes.extend(result[2])
        del result
        gc.collect()
    return pred_strings, pred_ids, pred_classes

In [17]:
# Build the test dataset and loader, then run the whole ensemble over it.
# shuffle=False keeps the predictions in the row order of sub_df.
test_dataset = GISegDataset(sub_df, target_size=CFG.img_size[0])
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
pred_strings, pred_ids, pred_classes = inference(all_models, test_loader)

100%|██████████| 108/108 [08:42<00:00,  4.84s/it]


In [18]:
# Force a garbage-collection pass after inference; as the cell's last
# expression, the number returned by gc.collect() is shown as its output.
# del IMAGES, MASKS
gc.collect()

21

# 📝 Submission

In [19]:
# One row per (id, class) with its run-length-encoded prediction.
pred_df = pd.DataFrame({
    "id": pred_ids,
    "class": pred_classes,
    "predicted": pred_strings,
})

# Load the frame that fixes which (id, class) rows the submission must contain
# and drop its placeholder / ground-truth column.
if not debug:
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
    del sub_df['predicted']
else:
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')
    del sub_df['segmentation']
    # Keep only the rows whose id starts with one of the selected cases.
    sub_df = sub_df[sub_df['id'].apply(lambda x: any(x.startswith(case_id) for case_id in cases))]

assert len(sub_df) == len(pred_df)
assert len(test_dataset.df) * 3 == len(pred_df)

# Equal lengths do not prove that the (id, class) keys line up: the merge is an
# inner join and would drop unmatched rows without complaint, so check the row
# count again afterwards.
expected_rows = len(sub_df)
sub_df = sub_df.merge(pred_df, on=['id', 'class'])
assert len(sub_df) == expected_rows, (
    f"merge changed the row count from {expected_rows} to {len(sub_df)}; "
    "prediction ids/classes do not match the submission template"
)

In [20]:
# Write the merged predictions to submission.csv (index column omitted)
# and print the first 10 rows.
sub_df.to_csv('submission.csv',index=False)
print(sub_df.head(10))

                         id        class predicted
0  case123_day20_slice_0001  large_bowel          
1  case123_day20_slice_0001  small_bowel          
2  case123_day20_slice_0001      stomach          
3  case123_day20_slice_0002  large_bowel          
4  case123_day20_slice_0002  small_bowel          
5  case123_day20_slice_0002      stomach          
6  case123_day20_slice_0003  large_bowel          
7  case123_day20_slice_0003  small_bowel          
8  case123_day20_slice_0003      stomach          
9  case123_day20_slice_0004  large_bowel          
