In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
from typing import Union
import cv2

import sys
sys.path.append('..')
import util
import lutil

class Patch2p5D(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4])
        return self.model(z)[-1]
        # return torch.stack([self.model(z)[-1] for _ in range(5)], dim=1).mean(dim=1)

class ScanInference2p5D(nn.Module):
    def __init__(self, patch_fn, batch_size: int, quick: bool = False):
        super().__init__()
        self.patch_fn = patch_fn
        self.batch_size = batch_size
        
        self.patch_size = (8, 256, 256)
        self.perms = torch.Tensor([
            (0, 1, 2),
            (0, 2, 1),
            (1, 0, 2),
            (1, 2, 0),
            (2, 0, 1),
            (2, 1, 0),
        ]).int()
        if quick:
            self.perms = self.perms[:1]
    
    def _forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        agg_pred = torch.zeros_like(scan).float()
        scan_loader = DataLoader(
            util.SweepCube(scan, self.patch_size, stride=(1, self.patch_size[1], self.patch_size[2])),
            batch_size=self.batch_size,
            shuffle=True
        )

        for x, positions in scan_loader:
            x = x.to(device).float()
            pred = F.sigmoid(self.patch_fn(x)).cpu()
            for p, pos in zip(pred, positions):
                agg_pred[
                    :,
                    pos[0] + self.patch_size[0]//2,
                    pos[1]:pos[1] + self.patch_size[1],
                    pos[2]:pos[2] + self.patch_size[2],
                ] += p.squeeze(1)
        
        return agg_pred

    def forward(self, scan: torch.Tensor, device: torch.device = 'cpu') -> torch.Tensor:
        agg_pred = torch.zeros_like(scan).float()

        for perm in tqdm(self.perms):
            out = self._forward(scan.permute(0, *perm+1), device).permute(0, *torch.argsort(perm)+1)
            # agg_pred += self._forward(scan.permute(0, *perm+1), device).permute(0, *torch.argsort(perm)+1)
            agg_pred += out
            del out
        return agg_pred / len(self.perms)

def wrapped_inference(scan_pth: str, model, batch_size: int, quick: bool = False):
    inference = ScanInference2p5D(Patch2p5D(model), batch_size, quick)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
patch_size = 8,256,256
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=[
        # "/train/kidney_1_dense",
        "/train/kidney_2",
        # "/train/kidney_3_sparse",
    ],
)

Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [3]:
train.labels[0].shape

torch.Size([1, 2217, 1041, 1511])

In [3]:
model = lutil.UNet3P(
    in_f=8,
    layers=[16, 32, 32, 32, 64, 64, 64],
    block_depth=4,
    connect_depth=6,
    conv=util.nn.Conv2DNormed,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2,2),
    upsample_mode='bilinear',
    norm_fn=nn.BatchNorm2d,
    dropout=(nn.Dropout2d, 0.1)
).to(device)

model.load_state_dict(torch.load('./bin/models/unet2d_3dinput_50epochs.pt', map_location=device))
model.requires_grad_(False)
model.eval()
patcher = Patch2p5D(model).to(device)
patcher.requires_grad_(False)
patcher.eval()

Patch2p5D(
  (model): UNet3P(
    (input_norm): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (down_blocks): ModuleList(
      (0): Sequential(
        (0): ConvBlock(
          (layers): ModuleList(
            (0): Sequential(
              (0): Dropout2d(p=0.1, inplace=False)
              (1): Conv2DNormed(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): GELU(approximate='none')
            )
            (1-3): 3 x Sequential(
              (0): Dropout2d(p=0.1, inplace=False)
              (1): Conv2DNormed(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): GELU(approximate='none')
            )
          )
        )
      )
      (1): Sequential(
        (0): MaxPool2d(kernel_size=(2, 2), stride

In [4]:
inference = ScanInference2p5D(
    patcher,
    # patch_size=patch_size,
    batch_size=128,
    # is2d=True
)
with torch.no_grad():
    out_mask = inference(train.scans[0], device)
torch.save(out_mask, './bin/pred_masks/kidney_2_pred.pt')

  0%|          | 0/6 [00:00<?, ?it/s]

RuntimeError: [enforce fail at inline_container.cc:424] . unexpected pos 448 vs 342

In [7]:
torch.save(out_mask, './bin/pred_masks/kidney_2_pred.pt')

In [3]:
out_mask = torch.load('./bin/pred_masks/kidney_2_pred.pt')

In [4]:
pmask = out_mask > 0.5
del out_mask

In [5]:
# pmask = out_mask > 0.5
det_mask = torch.cat([ #collect TP, FP, FN
    train.labels[0] * pmask, # true positive
    (~train.labels[0]) * pmask, # false positive
    train.labels[0] * (~pmask) # false negative
], dim=0)

In [12]:
det_mask[0].sum(), det_mask[1].sum(), det_mask[2].sum()

(tensor(11631554), tensor(626744), tensor(2656955))

In [17]:
1000 * 1500 * 100

150000000

In [18]:
util.DiceScore()(pmask, train.labels[0].bool())

tensor(0.8763)

In [7]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ipywidgets as widgets

class Display:
    def __init__(self, scan: torch.Tensor = None, mask: torch.Tensor = None):
        self.scan = scan
        self.mask = mask

    def _view_slice(self, i: int, slice_dim: int, ax: plt.Axes = None):
        ax.set_facecolor('black')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        slice_idx = [slice(None), slice(None), slice(None)]
        slice_idx[slice_dim] = i

        if self.scan is not None:
            ax.imshow(self.scan[0][tuple(slice_idx)], cmap='gray')
        if self.mask is not None:
            if self.mask.shape[0] == 1:
                ax.imshow(self.mask[0][tuple(slice_idx)], cmap=mcolors.ListedColormap(['none', 'blue']), alpha=0.5)
            else:
                ax.imshow(self.mask[0][tuple(slice_idx)], cmap=mcolors.ListedColormap(['none', 'green']), alpha=0.5)
                ax.imshow(self.mask[1][tuple(slice_idx)], cmap=mcolors.ListedColormap(['none', 'red']), alpha=0.5)
                ax.imshow(self.mask[2][tuple(slice_idx)], cmap=mcolors.ListedColormap(['none', 'yellow']), alpha=0.5)

    @staticmethod
    def _view_slices(i: int, displays: list['Display'], slice_dim: int):
        _, axs = plt.subplots(1, len(displays), figsize=(15, 15))
        if len(displays) == 1:
            axs = [axs]
        for ax, display in zip(axs, displays):
            display._view_slice(i, slice_dim, ax)

    @staticmethod
    def view(displays: list['Display'], slice_dim: int = 0):
        slider_max = displays[0].scan.shape[slice_dim+1] - 1 if displays[0].scan is not None else displays[0].mask.shape[slice_dim+1] - 1
        slider  = widgets.IntSlider(min=0, max=slider_max, step=1, value=0)
        widgets.interact(Display._view_slices, i=slider, displays=widgets.fixed(displays), slice_dim=widgets.fixed(slice_dim))

In [8]:
Display.view([Display(train.scans[0]), Display(mask=det_mask)])#, Display(mask=out_mask)])

interactive(children=(IntSlider(value=0, description='i', max=2216), Output()), _dom_classes=('widget-interact…

In [8]:
import pandas as pd

ModuleNotFoundError: No module named 'pandas'