In [1]:
from time import time
from glob import glob
import tifffile as tiff
import pandas as pd
import numpy as np
import cv2

import sys
sys.path.append('..')
import util

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Launch the diagnostics log board on port 6005 (this cell's output shows it
# serving TensorBoard there).
log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input locations. NOTE(review): absolute paths — consider a configurable DATA_DIR.
data_dir = '/root/data'
model_pth = './bin/models/edge_model_trained5.pt'
# Each entry is a scan directory with an `images/` sub-folder of .tif slices
# (trailing slash required by the glob in the inference cell).
scan_folders = [f'{data_dir}/train/kidney_2/']

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
def rle_encode(mask):
    """Run-length encode a binary mask as 1-indexed ``start length`` pairs.

    The mask is flattened in row-major order. Returns the runs as one
    space-separated string, or ``'1 0'`` when the mask has no foreground.
    """
    flat = np.concatenate([[0], mask.flatten(), [0]])
    # 1-based positions where the value changes: alternating run starts / ends.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    # Convert each (start, end) pair into (start, length).
    edges[1::2] = edges[1::2] - edges[::2]
    encoded = ' '.join(map(str, edges))
    return encoded if encoded else '1 0'

def id_from_pth(pth: str):
    """Build a submission id such as ``kidney_2_0000`` from a slice path.

    Expects ``.../<scan>/<subdir>/<slice>.<ext>`` (e.g.
    ``/root/data/train/kidney_2/images/0000.tif``): the sub-directory is
    dropped, and the scan name and slice stem are joined with ``_``.

    Raises:
        ValueError: if the path has fewer than three ``/``-separated parts.
    """
    scan_name, _subdir, file_name = pth.split("/")[-3:]
    # Strip the extension whatever its length; the old ``[:-4]`` only handled
    # 3-letter extensions such as ``.tif``.
    stem = file_name.rsplit(".", 1)[0]
    return f"{scan_name}_{stem}"

def proprocess(_scan: np.ndarray):
    """Min-max normalise a raw slice to uint8, invert it, and apply CLAHE.

    Args:
        _scan: raw slice array; assumed single-channel 2-D, as required by
            OpenCV's CLAHE ``apply``.
    Returns:
        uint8 array with the same shape as ``_scan``.
    """
    scan = _scan.astype(np.float32)
    smin, smax = np.min(scan), np.max(scan)
    if smax > smin:
        scan = (255 * (scan - smin) / (smax - smin)).astype(np.uint8)
    else:
        # Constant (or NaN-containing) slice: min-max scaling would divide by
        # zero / propagate NaN, and casting NaN to uint8 is undefined.
        # Treat it as all-zero intensity instead.
        scan = np.zeros(scan.shape, dtype=np.uint8)
    # Invert intensities (255 - x) before contrast equalisation.
    scan = 255 - scan
    clahe = cv2.createCLAHE(clipLimit=40.0, tileGridSize=(8, 8))
    return clahe.apply(scan)

def load_slice(pth, preprocess_fn):
    """Read one TIFF slice, preprocess it, and wrap the result in a tensor.

    Args:
        pth: path to a single-slice TIFF file.
        preprocess_fn: callable applied to the raw ``tiff.imread`` array
            (e.g. ``proprocess``).
    Returns:
        ``torch.Tensor`` built from the preprocessed array.
    """
    # The argument used to be ignored in favour of a hard-wired ``proprocess``.
    return torch.tensor(preprocess_fn(tiff.imread(pth)))


def fill_outline(outline: torch.Tensor) -> torch.Tensor:
    """Fill the interior of every outermost contour in a 2-D mask.

    ``outline`` is assumed to be a CPU tensor shaped ``(H, W)`` or ``(1, H, W)``
    (a leading singleton dim is squeezed away). The result is a bool tensor
    shaped ``(1, H, W)``.
    """
    canvas = (outline.squeeze(0) * 255).numpy().astype(np.uint8)
    filled = np.zeros_like(canvas)

    outer_contours, _hierarchy = cv2.findContours(
        canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # Paint each outer contour solid (255) onto the blank canvas.
    cv2.drawContours(filled, outer_contours, -1, (255), thickness=cv2.FILLED)

    return torch.Tensor(filled).bool().unsqueeze(0)

class Patch2p5D(nn.Module):
    """Run a 2-D network on 5-D patches by merging dims 1 and 2 into channels.

    Presumably the input is ``(B, C, D, H, W)``, so a stack of neighbouring
    slices reaches the 2-D model as extra input channels — TODO confirm.
    """

    def __init__(self, model: "util.UNet3P"):
        super().__init__()
        # The model's output is indexed with [-1]; assumed to be a sequence of
        # heads whose last entry is the final prediction — TODO confirm.
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge dims 1 and 2 of ``x`` and return the model's last output."""
        # flatten() copies when needed, so non-contiguous (e.g. permuted)
        # inputs also work; the previous .view() raised on them.
        z = x.flatten(1, 2)
        return self.model(z)[-1]


class ScanInference2p5D(nn.Module):
    """Sliding-window inference over a whole scan in several axis orders.

    For every axis permutation in ``self.perms`` the (permuted) scan is swept
    with ``self.patch_size`` windows at half-window in-plane stride. Each
    window's sigmoid probability is quantised to a small uint8 vote and summed
    into a vote map. That map is permuted back to the input orientation and
    added to the total.

    Args:
        patch_fn: callable mapping a batch of patches to logits (e.g. ``Patch2p5D``).
        batch_size: number of patches per forward pass.
        quick: if True, only the first (identity) permutation is used.
    """

    def __init__(self, patch_fn, batch_size: int, quick: bool = False):
        super().__init__()
        self.patch_fn = patch_fn
        self.batch_size = batch_size

        self.patch_size = (1, 256, 256)
        # Orders of the three spatial axes the scan is swept in.
        # Other candidates tried: (0, 2, 1), (1, 2, 0), (2, 0, 1).
        self.perms = torch.Tensor([
            (0, 1, 2),
            (1, 0, 2),
            (2, 1, 0),
        ]).int()
        if quick:
            self.perms = self.perms[:1]

        # Per-window vote weight. The "* 4" budgets for each voxel being covered
        # by up to 4 overlapping windows per pass (half-window stride in H and
        # W), keeping the uint8 total <= 255. This assumes util.SweepCube yields
        # no extra edge windows — TODO confirm.
        self.register_buffer('pass_max', torch.tensor(255 / (len(self.perms) * 4)))

    def _votes(self, prob: torch.Tensor) -> torch.Tensor:
        """Quantise probabilities in [0, 1] to uint8 votes in ``0..floor(pass_max)``."""
        # round() alone can reach ceil(pass_max), e.g. 64 when quick=True.
        # Four such votes sum to 256 and wrap to 0 in uint8, so cap at the floor.
        # With the default 3 perms this is a no-op (round(21.25) == 21).
        cap = float(self.pass_max.floor())
        return (prob * self.pass_max).round().clamp(max=cap).byte()

    def _forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        """Sweep ``scan`` once in its current orientation.

        Returns a uint8 tensor shaped like ``scan`` holding the summed votes of
        every window that covered each voxel.
        """
        agg_pred = torch.zeros_like(scan, dtype=torch.uint8)
        stride = (1, self.patch_size[1] // 2, self.patch_size[2] // 2)
        scan_loader = DataLoader(
            util.SweepCube(scan, self.patch_size, stride=stride),
            batch_size=self.batch_size,
            # Accumulation is order-independent; shuffling only cost time and
            # advanced the global torch RNG.
            shuffle=False,
        )
        depth_offset = self.patch_size[0] // 2

        for x, positions in scan_loader:
            x = x.to(device).float()
            pred = torch.sigmoid(self.patch_fn(x)).cpu()
            for p, pos in zip(pred, positions):
                agg_pred[
                    :,
                    pos[0] + depth_offset,
                    pos[1]:pos[1] + self.patch_size[1],
                    pos[2]:pos[2] + self.patch_size[2],
                ] += self._votes(p.squeeze(1))

        return agg_pred

    def forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        """Sweep ``scan`` once per axis permutation and sum the vote maps.

        Args:
            scan: 4-D tensor whose dims 1..3 are permuted. Dim 0 stays in place
                (presumably a singleton batch/channel axis).
            device: device the patch model runs on.
        Returns:
            uint8 tensor with the same shape as ``scan``.
        """
        agg_pred = torch.zeros_like(scan, dtype=torch.uint8)
        for perm in self.perms.tolist():
            order = [axis + 1 for axis in perm]
            # Inverse permutation (argsort of perm) restores the input orientation.
            inverse = [axis + 1 for axis in sorted(range(len(perm)), key=perm.__getitem__)]
            agg_pred += self._forward(scan.permute(0, *order), device).permute(0, *inverse)
        return agg_pred

In [3]:
# Ground-truth kidney_2 data, used later to compare predictions with labels.
# Other available samples: "/train/kidney_3_dense", "/train/kidney_3_sparse".
patch_size = (1, 256, 256)
train_data = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=["/train/kidney_2"],
)

Loading /train/kidney_2/images from cache


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6005/ (Press CTRL+C to quit)


Loading /train/kidney_2/labels from cache


In [4]:
# 2.5-D adapter around a 2-D UNet3+ backbone, moved to the target device.
patcher = Patch2p5D(
    util.UNet3P(
        in_f=1,
        layers=[32, 64, 128, 256, 512],
        block_depth=4,
        connect_depth=24,
        conv=util.nn.Conv2DNormed,
        pool_fn=nn.MaxPool2d,
        resize_kernel=(2, 2),
        upsample_mode='bilinear',
        norm_fn=nn.InstanceNorm2d,
    )
).to(device)

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (or pass weights_only=True on torch >= 1.13).
patcher.model.load_state_dict(torch.load(model_pth, map_location=device))
# Inference only: freeze parameters and switch the module to eval mode.
patcher.requires_grad_(False).eval()

inference = ScanInference2p5D(patcher, batch_size=64, quick=False)

In [5]:
th = 0.5  # NOTE(review): unused; the active cut-off is `threshold` below.
# Cut-off on the aggregated uint8 vote map (0..255 scale), not on a probability.
# Hand-tuned; int(255 * th) would be 127.
threshold = 50

submission_list = []  # one {'id', 'rle'} record per slice
for scan_fn in scan_folders:
    slices, ids = [], []
    print(f"loading scan {scan_fn}")
    for pth in sorted(glob(scan_fn + "images/*.tif")):
        slices.append(load_slice(pth, proprocess))
        ids.append(id_from_pth(pth))
    if not slices:
        continue
    scan = torch.stack(slices).unsqueeze(0)

    print("doing inference...")
    pmask = inference(scan, device).squeeze(0)

    print("aggregating slice rle encodings...")
    # `slice_id` avoids shadowing the builtin `id` for the rest of the notebook.
    submission_list.extend(
        {'id': slice_id, 'rle': rle_encode((smask > threshold).numpy())}
        for slice_id, smask in zip(ids, pmask)
    )

print("saving aggregate dataframe of rle encodings...")
# Build one frame from records; unlike pd.concat([]), this also works when no
# scan produced any rows.
submission_df = pd.DataFrame(submission_list, columns=['id', 'rle'])
submission_df.to_csv('submission.csv', index=False)

loading scan /root/data/train/kidney_2/
doing inference...
aggregating slice rle encodings...
saving aggregate dataframe of rle encodings...


In [6]:
# Cache `pmask` (the uint8 vote map from inference) so post-processing can be
# re-run without repeating inference.
torch.save(pmask, f='pmask.pt')

In [6]:
def _fill_all_axis(seg: torch.Tensor) -> torch.Tensor:
    """Fill outer contours slice-by-slice along each of the three axes, in place.

    Every 2-D slice along axis 0, then axis 1, then axis 2 is replaced by
    ``fill_outline`` of itself. ``seg`` is modified in place and also returned.
    """
    for axis in range(3):
        for index in range(seg.shape[axis]):
            # (index,), (:, index) or (:, :, index) depending on the axis.
            selector = (slice(None),) * axis + (index,)
            seg[selector] = fill_outline(seg[selector])
    return seg

def fill_all_axis(seg: torch.Tensor, iters: int) -> torch.Tensor:
    """Run ``_fill_all_axis`` on ``seg`` ``iters`` times (in place) and return it."""
    for _ in range(iters):
        _fill_all_axis(seg)
    return seg

In [20]:
# Reload the cached vote map and binarise it at 50 on the 0..255 vote scale.
seg = torch.gt(torch.load('pmask.pt'), 50)

In [21]:
# Three rounds of per-axis contour filling on the binary mask (in place).
seg = fill_all_axis(seg, iters=3)

In [None]:

# Sanity check: the mask should still be bool after hole filling.
seg.dtype

In [11]:
# Alias the hole-filled mask as `pmask`. NOTE(review): `pmask` becomes a bool
# tensor here, so a later `pmask > 50` threshold would be all False.
pmask = seg

In [23]:
# Re-export the submission from the hole-filled mask `seg`. It is already
# boolean, so no threshold is applied here. `ids` comes from the inference cell.
submission_list = [
    {'id': slice_id, 'rle': rle_encode(smask.numpy())}
    for slice_id, smask in zip(ids, seg)
]
# Unlike pd.concat([]), this also works when there are no rows.
submission_df = pd.DataFrame(submission_list, columns=['id', 'rle'])
submission_df.to_csv('submission.csv', index=False)

In [12]:
# Side-by-side viewer: predicted mask vs. kidney_2 ground-truth labels.
# `pmask` is either the raw uint8 vote map or, after `pmask = seg`, an already
# binary bool mask. Thresholding a bool tensor with `> 50` gave all False.
pmask_binary = pmask if pmask.dtype == torch.bool else pmask > 50
util.Display.view([
    util.Display(pmask_binary.unsqueeze(0)),
    util.Display(train_data.labels[0]),
    # util.Display(train_data.scans[0]),  # optionally also show the raw scan
])

interactive(children=(IntSlider(value=0, description='i', max=2216), Output()), _dom_classes=('widget-interact…

In [11]:
# Round-trip check: reload the written submission so it can be decoded below.
submission_df = pd.read_csv(filepath_or_buffer='submission.csv')

def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string back into a binary mask.

    mask_rle: space-separated "start length" pairs with 1-indexed starts, as
              produced by ``rle_encode`` (``'1 0'`` decodes to an empty mask).
    shape: shape of the array to return, e.g. (height, width). Any number of
           dimensions is accepted; the mask is filled in row-major order.
    Returns numpy uint8 array, 1 - mask, 0 - background
    '''
    values = [int(token) for token in mask_rle.split()]
    starts = np.asarray(values[0::2], dtype=int) - 1  # to 0-indexed
    lengths = np.asarray(values[1::2], dtype=int)
    ends = starts + lengths
    # np.prod instead of shape[0] * shape[1], so n-D shapes also work.
    img = np.zeros(int(np.prod(shape)), dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

# Decode every submission row into a (1, n_slices, 1303, 912) boolean volume.
# NOTE(review): 1303 x 912 is a hard-coded (H, W); it must match the scan that
# was encoded.
pmask2 = torch.stack([
    torch.tensor(rle_decode(rle, (1303, 912)))
    for rle in submission_df['rle']
]).unsqueeze(0).bool()

In [14]:
# Expected (1, number of slices, 1303, 912), given the decode cell's hard-coded slice size.
pmask2.shape

torch.Size([1, 2279, 1303, 912])

In [12]:
# Inspect the decoded submission volume slice by slice.
util.Display.view([util.Display(pmask2)])

interactive(children=(IntSlider(value=0, description='i', max=2278), Output()), _dom_classes=('widget-interact…