In [6]:
import os
import os.path as osp
import sys
from functools import partial
from pathlib import Path

import albumentations.augmentations.geometric.functional as AGF
import cv2
import numpy as np
import pandas as pd
import rasterio as rio
import torch
from omegaconf import OmegaConf
from tqdm import tqdm


#####################################################################
#                       Change only this                            #
#####################################################################
# Name of the experiment directory under HMIB_MODELS_DIR.
EXPERIMENT = "unet_resnet34"
# Checkpoint file name, looked up under <experiment>/split_0/models.
CKPT = "e28_t100_cmax_ema_0.746.pth"
# Device the input tensors are moved to; the model is sent to the GPU only
# when this is exactly "cuda".
DEVICE = "cuda"

# NOTE(review): not referenced by any cell shown in this notebook; infer_image
# has its own `debug` argument, which is left at its default.
DEBUG = True

# Block edge length (pixels) at network resolution; the reader divides it by
# NETWORK_SCALE to get the block size in source-image pixels.
BLOCK_SIZE = 512
# Factor applied to every source block before it is fed to the network.
NETWORK_SCALE = 1. / 3
# TODO: per-data-source scales (not implemented yet)
# HUBMAP_SCALE = 1.0
# HPA_SCALE = 1.0
# Context padding added on each side of a block, as a fraction of the block size.
PAD_RATIO = 0.25
# Number of blocks per inference batch.
BATCH_SIZE = 6
# Threshold applied to the sigmoid output to obtain a binary mask.
THRESHOLD = 0.5
# Enables test-time augmentation inside `infer`.
TTA = False

# Rows whose data_source / organ is not in these sets are not inferred and
# receive DUMMY_RLE instead.
DATA_SOURCES = {'Hubmap', 'HPA'}
# DATA_SOURCES = {'HPA'}
ORGANS = {'kidney', 'prostate', 'largeintestine', 'spleen', 'lung'}


#####################################################################
#                         Do not change                             #
#####################################################################
# NOTE(review): absolute, machine-specific path; must be edited to run elsewhere.
HMIB_MODELS_DIR = "/home/ubuntu/asnorkin/background_matting/robustvideomatting/tmp/hmib/output/"
EXP_DIR = osp.join(HMIB_MODELS_DIR, EXPERIMENT)
SRC_DIR = osp.join(EXP_DIR, "src")

# Imports from src: SRC_DIR has to be on sys.path before the two imports below,
# because `network` and `infer` are resolved from the experiment's src directory.
sys.path.append(SRC_DIR)
import network  # From src
from infer import init_model, norm_2d


# Config and Model
CONFIG_FILE = osp.join(SRC_DIR, "configs/u.yaml")
MODEL_FILE = osp.join(EXP_DIR, "split_0/models", CKPT)


# Images and output paths
# NOTE(review): absolute, machine-specific path; must be edited to run elsewhere.
DATA_DIR = "/home/ubuntu/asnorkin/background_matting/robustvideomatting/tmp/hmib/input/hmib/"
TEST_IMAGES_DIR = osp.join(DATA_DIR, "test_images")
TRAIN_IMAGES_DIR = osp.join(DATA_DIR, "train_images")
TEST_CSV_FILE = osp.join(DATA_DIR, "test.csv")
TRAIN_CSV_FILE = osp.join(DATA_DIR, "train.csv")
# Only used by the commented-out `to_csv` call in the inference cell.
OUTPUT_FILE = "/kaggle/working/submission.csv"
# RLE written for rows that are skipped by the DATA_SOURCES / ORGANS filter.
DUMMY_RLE = ""

In [7]:
#####################################################################
#                             Model                                 #
#####################################################################
def preprocess(batch):
    """Turn a batch of raw image blocks into a normalized network input.

    Params
    ------
    batch: np.array of shape (batch, height, width, channels); batch of raw images.

    Returns
    -------
    X: torch.Tensor of shape (batch, channels, height, width); float tensor on
        DEVICE, normalized by `norm_2d` with the config's AUGS.MEAN / AUGS.STD.
    """
    # Channels-last -> channels-first, cast to float and move to the device.
    channels_first = batch.transpose((0, 3, 1, 2))
    tensor = torch.from_numpy(channels_first).float().to(DEVICE)

    # norm_2d returns a 3-tuple; only its first element (the tensor) is needed.
    normalized, _, _ = norm_2d(tensor, mean=cfg.AUGS.MEAN, std=cfg.AUGS.STD)
    return normalized


def tta_4rot90_encode(batch):
    """Stack the batch with its 90, 180 and 270 degree rotations.

    Rotations are applied in the plane of dims 2 and 3 and concatenated along
    dim 0 in the order k = 0, 1, 2, 3, so the result holds 4x the input batch.
    """
    rotated_views = []
    for quarter_turns in range(4):
        rotated_views.append(torch.rot90(batch, k=quarter_turns, dims=[2, 3]))
    return torch.cat(rotated_views)


def tta_4rot90_decode(y):
    """Rotate the four rotation groups back and average them.

    `y` is split along dim 0 into four consecutive groups (the last group takes
    whatever remains); group i is rotated by -i quarter turns in the plane of
    dims 2 and 3, and the four results are averaged.
    """
    bs = y.shape[0] // 4
    group_bounds = [(0, bs), (bs, 2 * bs), (2 * bs, 3 * bs), (3 * bs, None)]

    restored = []
    for quarter_turns, (lo, hi) in enumerate(group_bounds):
        restored.append(torch.rot90(y[lo:hi], k=-quarter_turns, dims=[2, 3]))

    return torch.stack(restored).mean(dim=0)


def tta_encode(batch):
    """Build the 8-view TTA batch: 4 rotations of the input and of its flip.

    The output is concatenated along dim 0 as: rotations k = 0..3 of `batch`,
    followed by rotations k = 0..3 of `batch` flipped along dim 2.
    """
    mirrored = torch.flip(batch, dims=(2,))
    views = [
        torch.rot90(source, k=quarter_turns, dims=[2, 3])
        for source in (batch, mirrored)
        for quarter_turns in range(4)
    ]
    return torch.cat(views)


def tta_decode(y):
    """Undo the 8-view TTA layout produced by `tta_encode` and average the views.

    Params
    ------
    y: torch.Tensor of shape (8 * batch, channels, height, width)
        Predictions laid out along dim 0 as 8 consecutive groups: rotations
        k = 0..3 of the input, then rotations k = 0..3 of the input flipped
        along dim 2.

    Returns
    -------
    torch.Tensor of shape (batch, channels, height, width)
        Mean of the 8 views after mapping each back to the input orientation.

    Raises
    ------
    ValueError
        If the size of dim 0 is not a multiple of 8.
    """
    n_views = 8
    if y.shape[0] % n_views != 0:
        raise ValueError(
            f"TTA batch size must be a multiple of {n_views}, got {y.shape[0]}"
        )
    bs = y.shape[0] // n_views

    restored = []
    for view in range(n_views):
        pred = y[view * bs: (view + 1) * bs]
        # The views were built as rot90(flip(x)), so the inverse has to undo
        # the rotation FIRST and the flip second. Flip and rot90 do not
        # commute: flipping first would leave the 90/270 degree views rotated
        # by 180 degrees.
        pred = torch.rot90(pred, k=-(view % 4), dims=[2, 3])
        if view >= 4:
            pred = torch.flip(pred, dims=(2,))
        restored.append(pred)

    return torch.stack(restored).mean(dim=0)


def infer(batch, model, threshold=None, sigmoid=True, tta=False):
    """Run the model on a preprocessed batch and return channels-last masks.

    Params
    ------
    batch: torch.Tensor of shape (batch, channels, height, width)
        Batch of preprocessed images
    model: callable
        Called as ``model({"xb": batch})``; must return a mapping whose "yb"
        entry is the prediction tensor of shape (batch, channels, height, width)
    threshold: float or None
        Confidence threshold. When given, the output is a binary mask;
        when None, the raw scores are returned
    sigmoid: bool
        Apply sigmoid to the model output or not
    tta: bool
        Apply the 8-view TTA (4 rotations x optional flip, see `tta_encode`)
        and average the predictions, or not

    Returns
    -------
    yb: np.array of shape (batch, height, width, channels)
        uint8 array of 0/1 when `threshold` is given; otherwise float32 scores
        (probabilities if `sigmoid` is True)
    """
    # Extend batch with the augmented views
    if tta:
        batch = tta_encode(batch)

    with torch.no_grad():
        # CUDA autocast is only meaningful for CUDA tensors; skip it on CPU.
        with torch.cuda.amp.autocast(enabled=batch.is_cuda):
            batch_pred = model({"xb": batch})

    yb = batch_pred["yb"].float()

    # Average predictions over the augmented views
    if tta:
        yb = tta_decode(yb)

    if sigmoid:
        yb = yb.sigmoid()

    binarize = threshold is not None
    if binarize:
        yb = yb > threshold

    yb = yb.cpu().numpy().transpose((0, 2, 3, 1))

    # Only the binary mask is cast to uint8. Casting raw scores would truncate
    # every probability below 1.0 to 0, so they are returned as float32.
    if binarize:
        yb = yb.astype(np.uint8)

    return yb


def get_inferer(model, **inferer_kw):
    """Bind a model and inference options into a one-argument callable.

    The returned function takes a raw batch, runs it through `preprocess` and
    then through `infer` with `model` and the given keyword options.
    """

    def inferer(batch):
        prepared = preprocess(batch)
        return infer(prepared, model=model, **inferer_kw)

    return inferer


#####################################################################
#                            Blocks                                 #
#####################################################################
def generate_block_coords(H, W, block_size):
    """Yield the top-left corner and size of every block of a tiling grid.

    Params
    ------
    H, W: int
        Height and width of the area to cover
    block_size: tuple (h, w)
        Block height and width

    Yields
    ------
    (cy, cx, h, w): block origin (row, column) and block size. Blocks are
        produced column by column (x outer, y inner). The last row/column of
        blocks may extend past H / W.
    """
    h, w = block_size
    # Integer ceil division: number of blocks needed to cover each axis.
    n_rows = int(-(-H // h))
    n_cols = int(-(-W // w))

    for col in range(n_cols):
        cx = col * w  # x advances by the block WIDTH
        for row in range(n_rows):
            cy = row * h  # y advances by the block HEIGHT
            yield cy, cx, h, w
            
            
def pad_block(y, x, h, w, pad):
    """Grow a (y, x, h, w) block by `pad` on every side; returns a np.array of 4 values."""
    grow = 2 * pad
    top, left = y - pad, x - pad
    return np.array([top, left, h + grow, w + grow])


def crop(src, y, x, h, w):
    """Return the (h, w) window of the last two axes of `src` starting at (y, x)."""
    rows = slice(y, y + h)
    cols = slice(x, x + w)
    return src[..., rows, cols]


def paste(src, block, y, x, h, w):
    """Write `block` in place into the (h, w) window of `src`'s last two axes at (y, x)."""
    rows = slice(y, y + h)
    cols = slice(x, x + w)
    src[..., rows, cols] = block
    
    
def paste_crop(src, part, block_cd, pad):
    """Strip the padding from `part` and write its core into `src` in place.

    `block_cd` is the unpadded block (y, x, h, w) in `src` coordinates. The
    core is taken from `part` starting at offset (pad, pad), with its size
    clipped so that it does not run past the bottom/right edge of `src`.
    """
    full_h, full_w = src.shape[-2:]
    top, left, block_h, block_w = block_cd

    # Part of the block that actually lies inside `src`.
    visible_h = min(block_h, full_h - top)
    visible_w = min(block_w, full_w - left)

    core = part[..., pad: pad + visible_h, pad: pad + visible_w]
    # The target slice is clipped by the array bounds, so it matches `core`.
    src[..., top: top + block_h, left: left + block_w] = core
    
    
# TODO: check in public notebooks
def mask2rle(img):
    """Run-length encode a binary mask.

    Params
    ------
    img: np.array
        Mask; non-zero values are foreground. It is transposed and flattened,
        so a (height, width) or (1, height, width) mask is scanned column by
        column.

    Returns
    -------
    str
        Space separated "start length" pairs with 1-based starts;
        empty string for an empty mask.
    """
    img = np.asarray(img)

    # Pad with one background pixel on each side instead of overwriting the
    # first/last pixel: runs touching either end of the mask are then encoded
    # correctly, and an empty input no longer raises.
    padded = np.zeros(img.size + 2, dtype=bool)
    padded[1:-1] = img.T.ravel()

    # Positions where the value changes; thanks to the padding they always
    # come in (run start, run end) pairs and starts are already 1-based.
    runs = np.where(padded[1:] != padded[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rescale(batch_img, scale):
    """Bilinearly resize a batch tensor by the same factor along both spatial axes."""
    resize_kw = dict(
        scale_factor=(scale, scale),
        mode='bilinear',
        align_corners=False,
    )
    return torch.nn.functional.interpolate(batch_img, **resize_kw)

            
#####################################################################
#                      TiffReader class                             #
#####################################################################
class TiffReader:
    """Reads tiff files.

    If subdatasets are available, then use them, otherwise just handle as usual.

    The reader can be used as a context manager, which closes the underlying
    datasets on exit::

        with TiffReader(path) as reader:
            image = reader.read()
    """

    def __init__(self, path_to_tiff_file: str):
        # Define the attributes first so that close() / __del__ are safe even
        # when opening the file (or one of its subdatasets) fails.
        self.ds = None
        self.subds_list = []

        self.ds = rio.open(path_to_tiff_file)
        try:
            for subds_path in self.ds.subdatasets:
                self.subds_list.append(rio.open(subds_path))
        except Exception:
            # Do not leak the handles opened so far.
            self.close()
            raise

    def read(self, window=None, boundless=True):
        """Read the whole image or a window of it.

        Params
        ------
        window: ((row_start, row_stop), (col_start, col_stop)) or None
            Region to read; None reads the full image
        boundless: bool
            Forwarded to the dataset's `read` together with `window`

        Returns
        -------
        output: np.array of shape (height, width, channels)
            Result image
        """

        # TODO: rescale window
        ds_kwargs = {}
        if window is not None:
            ds_kwargs.update({'window': window, 'boundless': boundless})

        if self.is_subsets_avail:
            # Each subdataset contributes its bands; stack them along axis 0.
            output = np.vstack(
                [ds.read(**ds_kwargs) for ds in self.subds_list])
        else:
            output = self.ds.read(**ds_kwargs)

        # Bands-first -> bands-last.
        output = output.transpose((1, 2, 0))
        return output

    def read_block(self, y, x, h, w, boundless=True):
        """Read the (h, w) block whose top-left corner is at (y, x)."""
        return self.read(
            window=((y, y + h), (x, x + w)),
            boundless=boundless
        )

    def read_batch(self):
        raise NotImplementedError

    @property
    def is_subsets_avail(self):
        """True when the file exposes at least one subdataset."""
        return len(self.subds_list) > 0

    @property
    def shape(self):
        """Shape reported by the first subdataset if any, else by the main dataset."""
        if self.is_subsets_avail:
            return self.subds_list[0].shape
        else:
            return self.ds.shape

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions raised inside the `with` body

    def __del__(self):
        # Best-effort cleanup. The instance may be only partially constructed,
        # and exceptions raised from __del__ cannot be handled by anyone.
        try:
            self.close()
        except Exception:
            pass

    def close(self):
        """Close the main dataset and every subdataset that was opened."""
        ds = getattr(self, "ds", None)
        if ds is not None:
            ds.close()
        for subds in getattr(self, "subds_list", None) or []:
            subds.close()
            

#####################################################################
#                   BatchedTiffReader class                         #
#####################################################################
class BatchedTiffReader(TiffReader):
    """TiffReader that serves the image as batches of padded, rescaled blocks.

    The image is tiled with square blocks of `scaled_block_size` source pixels.
    Each block is read with `scaled_pad_size` extra pixels on every side and
    then scaled by `network_scale`. Iterating over the reader yields
    `(blocks, coords)` batches until all blocks have been served.
    """

    def __init__(
        self,
        path_to_tiff_file: str,
        block_size: int,
        network_scale: float,
        pad_ratio: float,
        batch_size: int
    ):
        super().__init__(path_to_tiff_file)

        # The properties below read these, so they are set before the tiling.
        self.block_size = block_size
        self.network_scale = network_scale
        self.pad_ratio = pad_ratio
        self.batch_size = batch_size

        image_h, image_w = self.shape
        side = self.scaled_block_size
        self.blocks_coords = list(
            generate_block_coords(image_h, image_w, block_size=(side, side))
        )
        # Index of the next block that read_batch will serve.
        self.next_block = 0

    def __len__(self):
        """Number of batches needed to serve every block."""
        n_batches = np.ceil(self.total_blocks / self.batch_size)
        return int(n_batches)

    @property
    def scaled_block_size(self):
        """Block edge length in source-image pixels."""
        source_side = self.block_size / self.network_scale
        return int(round(source_side))

    @property
    def pad_size(self):
        """Padding in network-resolution pixels."""
        return int(round(self.block_size * self.pad_ratio))

    @property
    def scaled_pad_size(self):
        """Padding in source-image pixels."""
        return int(round(self.scaled_block_size * self.pad_ratio))

    @property
    def total_blocks(self):
        return len(self.blocks_coords)

    @property
    def inv_network_scale(self):
        return 1.0 / self.network_scale

    def has_next_block(self):
        return self.next_block < len(self.blocks_coords)

    def read_batch(self):
        """Read up to `batch_size` blocks; returns None once all are consumed.

        Returns
        -------
        (blocks, coords): stacked rescaled padded blocks and the matching
            unpadded block coordinates (y, x, h, w) in source pixels.
        """
        if not self.has_next_block():
            return None

        blocks, coords = [], []
        while len(blocks) < self.batch_size and self.has_next_block():
            block_cd = self.blocks_coords[self.next_block]

            # Read the block with context padding, then bring it to network scale.
            padded_cd = pad_block(*block_cd, self.scaled_pad_size)
            pixels = AGF.scale(self.read_block(*padded_cd), self.network_scale)

            blocks.append(pixels)
            coords.append(block_cd)
            self.next_block += 1

        return np.stack(blocks), np.stack(coords)

    def __iter__(self):
        # Calls read_batch repeatedly until it returns the None sentinel.
        # The cursor is not reset, so the blocks can be iterated only once.
        return iter(self.read_batch, None)


#####################################################################
#                       Image Inference                             #
#####################################################################
def infer_image(image_file, inferer, debug=False):
    """Segment one TIFF image block by block.

    Params
    ------
    image_file: str
        Path to the TIFF image
    inferer: callable
        Maps a batch of raw blocks to a batch of channels-last masks
        (see `get_inferer`)
    debug: bool
        Return the full mask instead of its RLE string

    Returns
    -------
    np.array of shape (1, height, width) and dtype float32 if `debug`,
    otherwise the RLE string of the boolean mask.
    """
    # Create Batched Reader
    image_reader = BatchedTiffReader(
        image_file,
        block_size=BLOCK_SIZE,
        network_scale=NETWORK_SCALE,
        pad_ratio=PAD_RATIO,
        batch_size=BATCH_SIZE
    )

    # try/finally: the reader must be closed even when inference fails.
    try:
        H, W = image_reader.shape

        # Allocate directly in the target dtype; going through the default
        # float64 first costs 8 bytes per pixel for a transient array.
        mask = np.zeros((1, H, W), dtype=np.float32 if debug else bool)

        # Infer batch by batch
        for batch_blocks, batch_coords in image_reader:
            batch_masks = inferer(batch_blocks)
            for block_mask, block_cd in zip(batch_masks, batch_coords):
                # Back to source resolution, then channels-first for pasting.
                block_mask = AGF.scale(block_mask, image_reader.inv_network_scale)
                block_mask = block_mask.transpose((2, 0, 1))
                paste_crop(mask, block_mask, block_cd, image_reader.scaled_pad_size)
    finally:
        # Close the Reader
        image_reader.close()

    # Build the result
    return mask if debug else mask2rle(mask)


In [8]:
# Create model
# Load the experiment configuration from the experiment's src directory.
cfg = OmegaConf.load(CONFIG_FILE)
# NOTE(review): presumably prevents the encoder from loading pretrained weights,
# since the checkpoint in MODEL_FILE is loaded below - confirm in src.
cfg.MODEL.ENCODER.pretrained = False

# `network` is the module imported from the experiment's src directory.
model = init_model(cfg, MODEL_FILE, network, to_gpu=(DEVICE == "cuda"))
# Bind the model and the inference options into a single-argument callable.
inferer = get_inferer(model=model, threshold=THRESHOLD, tta=TTA)

In [86]:
# Validation split to evaluate; None runs on every row of the CSV.
SPLIT = 0

images_dir = TRAIN_IMAGES_DIR
df = pd.read_csv(TRAIN_CSV_FILE)

if SPLIT is not None:
    # The split file holds one row index per line, without a header.
    split_indices = pd.read_csv(f"../input/splits/{SPLIT}.csv", header=None)
    split_indices = split_indices.iloc[:, 0].values
    # .copy(): columns are added to this frame below, which must not be a
    # view of the full frame (SettingWithCopyWarning).
    df = df[df.index.isin(split_indices)].copy()

records = []
for row in tqdm(df.itertuples(), total=len(df), desc="Inference"):
    # Rows outside the selected data sources / organs keep the dummy RLE.
    rle = DUMMY_RLE
    if row.data_source in DATA_SOURCES and row.organ in ORGANS:
        image_file = osp.join(images_dir, f"{row.id}.tiff")
        rle = infer_image(image_file, inferer)

    records.append({
        "id": row.id,
        "rle": rle
    })

# Assign the predictions as one column instead of enlarging the frame
# cell by cell through `.loc`.
df["rle_pred"] = [record["rle"] for record in records]

result = pd.DataFrame(records)
# Uncomment to write the submission file:
# result.to_csv(OUTPUT_FILE, index=False)

Inference: 100%|██████████| 88/88 [00:52<00:00,  1.69it/s]


In [87]:
def row_dice(row):
    """Dice score between a row's ground-truth RLE and its predicted RLE.

    Reads `rle`, `rle_pred`, `img_height` and `img_width` from the row.
    """
    shape = (row.img_height, row.img_width)
    decoded_gt = rle2mask(row.rle, shape=shape)
    decoded_pred = rle2mask(row.rle_pred, shape=shape)
    return dice(decoded_gt, decoded_pred)


def dice(mask_gt, mask_pred, eps=1e-6):
    """Dice coefficient of two masks; `eps` keeps the ratio defined when both are empty."""
    overlap = (mask_gt * mask_pred).sum()
    total = mask_gt.sum() + mask_pred.sum()
    return (2 * overlap + eps) / (total + eps)


def rle2mask(rle, shape):
    """Decode a run-length encoded mask (the inverse of `mask2rle`).

    Params
    ------
    rle: str
        Space separated "start length" pairs; starts are 1-based and index the
        mask flattened column by column. An empty string, or a non-string
        value such as the NaN pandas produces for an empty CSV cell, decodes
        to an empty mask.
    shape: tuple (height, width)
        Shape of the mask to build

    Returns
    -------
    np.array of shape (height, width), dtype uint8, with 1 inside the runs.
    """
    height, width = shape
    flat = np.zeros(height * width, dtype=np.uint8)

    if isinstance(rle, str):
        seq = [int(token) for token in rle.split()]
        assert len(seq) % 2 == 0, "RLE must consist of (start, length) pairs"

        for start, length in zip(seq[0::2], seq[1::2]):
            begin = start - 1  # RLE starts are 1-based
            flat[begin: begin + length] = 1

    # Runs follow the column-major scan order: fill a (width, height) array
    # and transpose it to obtain the (height, width) mask.
    return flat.reshape(width, height).T


In [88]:
# Score every prediction against its ground truth. The column is assigned in
# one go rather than enlarging the frame cell by cell through `.loc`.
df["dice"] = [row_dice(row) for row in tqdm(df.itertuples(), total=len(df))]

In [90]:
# Mean Dice over every evaluated row.
df["dice"].mean()

0.7353670442811283

In [93]:
# Mean Dice per data source. A source with no rows in `df` prints nan (the
# mean of an empty selection). DATA_SOURCES is a set, so the print order is
# not fixed.
for source in DATA_SOURCES:
    is_source = df.data_source == source
    source_dice = df.loc[is_source, "dice"].mean()
    print(f"{source} dice: {source_dice:.3f}")

HPA dice: 0.735
Hubmap dice: nan


In [94]:
# Mean Dice per organ. ORGANS is a set, so the print order is not fixed.
for organ in ORGANS:
    is_organ = df.organ == organ
    organ_dice = df.loc[is_organ, "dice"].mean()
    print(f"{organ} dice: {organ_dice:.3f}")

spleen dice: 0.716
lung dice: 0.142
prostate dice: 0.818
kidney dice: 0.933
largeintestine dice: 0.892
