In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append('..')
import util

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

fine_patcher = util.UNet3D(
    layers=[16, 32, 64, 128, 16],
    Conv3d=util.nn.Conv3DNormed,
    block_depth=4,
    connect_depth=8,
    dropout=0.2,
).to(device)
fine_patcher.load_state_dict(torch.load('./bin/models/unet3d.pt', map_location=device))
fine_patcher.requires_grad_(False)
fine_patcher.eval()

coarse_patcher = util.UNet3D(
    layers=[16, 32, 64, 64, 128, 128, 256],
    Conv3d=util.nn.Conv3DNormed,
    block_depth=8,
    connect_depth=4,
    dropout=0.2,
).to(device)
coarse_patcher.load_state_dict(torch.load('./bin/_tmp_models/wide_patcher_finder_IN_PROGRESS.pt', map_location=device))
coarse_patcher.requires_grad_(False)
coarse_patcher.eval()

patch_size = 64
chunk_size = 16
nchunks = patch_size // chunk_size
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=[
        # "/train/kidney_1_dense",
        # "/train/kidney_3_sparse",
        "/train/kidney_2",
    ],
)

Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [2]:
inference = util.ScanInferece(
    coarse_patcher,
    fine_patcher,
    batch_size=128,
)
with torch.no_grad():
    out_mask = inference(train.scans[0])

  0%|          | 0/98 [00:00<?, ?it/s]

In [3]:
import os

os.makedirs('/root/data/model_outputs', exist_ok=True)

torch.save(out_mask, '/root/data/model_outputs/kidney_2_3dmask.pt')

In [3]:
msk_2d = torch.load("./bin/output/2d_segmentation.pt").unsqueeze(0) > 75
full_msk = torch.clamp(msk_2d + out_mask, 0, 1)

torch.save(full_msk, '/root/data/output/combined_msk.pt')

In [2]:
full_msk = torch.load('/root/data/output/kidney_2_3dmask.pt')
util.DiceScore()(full_msk, train.labels[0])

FileNotFoundError: [Errno 2] No such file or directory: './bin/output/kidney_2_3dmask.pt'

In [3]:
util.Display(full_msk.squeeze())(), util.Display(train.labels[0].squeeze())()

NameError: name 'full_msk' is not defined

In [None]:
from typing import Callable
from util import SweepCube
from tqdm.auto import tqdm

class FilterNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # assert len(layers) >= 2, "layers must have at least 2 elements"

        kernel_size = 9
        layers = [2, 2, 2, 2, 2,2,2,2, 3]

        c = layers[0]
        # L = []
        self.layers = nn.ModuleList()
        for i, l in enumerate(layers[1:-1]):
            self.layers.append(nn.Sequential(
                util.FFTConv3d(c, l, kernel_size),
                torch.nn.GELU(),
            ))
            c = l
        self.layers.append(
            util.FFTConv3d(c, layers[-1], kernel_size)
        )

    def forward(self, x):
        z = self.layers[0](x)
        for layer in self.layers[1:]:
            z = layer(z)
        return z

class ProcessScan(nn.Module):
    def __init__(self, patching_fn: Callable[[torch.Tensor], torch.Tensor], in_patch_size: int, out_patch_size: int, batch_size: int = 32):
        super().__init__()
        assert in_patch_size % out_patch_size == 0, "in_patch_size must be a multiple of out_patch_size"

        self.patcher = patching_fn
        self.in_patch_size = in_patch_size
        self.out_patch_size = out_patch_size
        self.batch_size = batch_size

    def forward(self, x: torch.Tensor, working_device: str = "cuda") -> torch.Tensor:
        agg_pred = torch.zeros_like(x)
        scan_loader = DataLoader(
            SweepCube(x, self.in_patch_size, self.out_patch_size),
            batch_size=self.batch_size,
            shuffle=True
        )
    
        for xs, positions in tqdm(scan_loader):
            xs = xs.to(working_device).float()
            pred_mask = self.patcher(xs)

            for p, pos in zip(pred_mask.cpu(), positions):
                agg_pred[
                    :,
                    pos[0]:pos[0] + self.coarse_patch_size,
                    pos[1]:pos[1] + self.coarse_patch_size,
                    pos[2]:pos[2] + self.coarse_patch_size,
                ] = p
        
        return agg_pred


In [13]:
def to_chunks(x: torch.Tensor) -> torch.Tensor:
    return x.unfold(2, chunk_size, chunk_size) \
        .unfold(3, chunk_size, chunk_size) \
        .unfold(4, chunk_size, chunk_size)
        # .reshape(-1, 1, chunk_size, chunk_size, chunk_size)

def assemble_patch(x : torch.Tensor) -> torch.Tensor:
    return x.reshape(batch_size, nchunks, nchunks, nchunks, 1, chunk_size, chunk_size, chunk_size) \
        .permute(0, 4, 1, 5, 2, 6, 3, 7) \
        .reshape(batch_size, 1, patch_size, patch_size, patch_size)

dice_fn = util.DiceScore().to(device)

batch_size = 10
train_data = DataLoader(train, batch_size=batch_size, shuffle=True)

# Coarse patcher prediction
x, y, pos = next(iter(train_data))
x, y = x.to(device).float(), y.to(device).float()

# cutoff_density = 0.1
# y_filter = y.mean(dim=(1, 2, 3, 4)) < cutoff_density
# x, y = x[y_filter], y[y_filter]

p_y = [F.sigmoid(logits) for logits in coarse_patcher(x, up_depth=2)]
pred_masks = [(p > 0.5).float() for p in p_y]
masks = coarse_patcher.deep_masks(y)[1:3]
acc = [torch.eq(pred_masks[i], masks[i]).float().mean() for i in range(len(pred_masks))]
dice = [dice_fn(pred_masks[i], masks[i]) for i in range(len(pred_masks))]

# Filter fine patcher input
fine_input = to_chunks(x)[pred_masks[-1].bool()].unsqueeze(1)

# Fine patcher prediction
p_y_fine = [F.sigmoid(logits) for logits in fine_patcher(fine_input)]
pred_masks_fine = [(p > 0.5).float() for p in p_y_fine]
masks_fine = fine_patcher.deep_masks(y)[1:]

# Stitch fine patches together
full_pred = torch.zeros_like(y)
to_chunks(full_pred)[pred_masks[-1].bool()] = pred_masks_fine[-1].squeeze(1)


# Logging
print(dice_fn(full_pred, y))
def fp(pred, true):
    true_positives = (pred * true).sum()
    false_positives = (pred * (1 - true)).sum()
    false_negatives = ((1 - pred) * true).sum()
    return torch.stack([false_positives, false_negatives, true_positives, true.mean() * 1000])

print(torch.stack([torch.cat([torch.tensor([i], device=device), fp(p, t), pos[i].to(device)]).int() for i,(p, t) in enumerate(zip(full_pred, y))]))
print(len(full_pred))

tensor(0.7861, device='cuda:0')
tensor([[    0,    17,   166,   280,     1,   858,   380,   579],
        [    1,    76,     2,   234,     0,   200,  1296,   734],
        [    2,     0,     0,     0,     0,   874,    98,   890],
        [    3,   269,    94,  2868,    11,   247,   758,   994],
        [    4,  7345,   939, 10840,    44,   600,   712,   700],
        [    5,    47,    60,     5,     0,   692,   619,   353],
        [    6,   475,   551,  4229,    18,   194,   688,   571],
        [    7,     0,     0,     0,     0,   447,   987,   375],
        [    8,     0,     0,     0,     0,   718,   595,   104],
        [    9,     0,     0,     0,     0,   814,  1307,   725]],
       device='cuda:0', dtype=torch.int32)
10


In [14]:
cutoff_density = 0.1 # 0.5 ?

i = 4
print(y[i].mean())
util.Display(full_pred[i].squeeze().cpu())(), util.Display(x[i].squeeze().cpu(), y[i].squeeze().cpu())()

tensor(0.0449, device='cuda:0')


interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

(None, None)

In [None]:
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=[
        # "/train/kidney_1_dense",
        # "/train/kidney_3_sparse",
        "/train/kidney_2",
    ],
    data_dir = "../../data",
)

In [28]:
util.Display(train.labels[0].squeeze())()

interactive(children=(IntSlider(value=0, description='idx_td', max=1510), Output()), _dom_classes=('widget-int…

In [None]:
nruns = 1000

class AggDice(nn.Module):
    def __init__(self):
        super().__init__()
        # self.agg_intersection = 0
        # self.agg_cardinality = 0

        self.true_positives = 0
        self.false_positives = 0
        self.false_negatives = 0

    def forward(self, y_pred, y_true):
        # self.agg_intersection += (y_pred * y_true).sum()
        # self.agg_cardinality += (y_pred + y_true).sum()
        # return 2 * self.agg_intersection / self.agg_cardinality

        self.true_positives += (y_pred * y_true).sum()
        self.false_positives += (y_pred * (1 - y_true)).sum()
        self.false_negatives += ((1 - y_pred) * y_true).sum()

        return 2 * self.true_positives / (2 * self.true_positives + self.false_positives + self.false_negatives)

dice_fn = AggDice()

with torch.no_grad():
    for i in range(nruns):
        x, y, _ = next(iter(valid_data))
        x, y = x.float().to(device), y.float().to(device)

        # Compute output
        p_y = fine_patcher(x, threshold=0.51)
        # p_y[-1] += p_y[-1] * F.interpolate(p_y[-2] - 0.5, size=16, mode='nearest') * 0.5
        # p_y[-1] += F.interpolate(p_y[-3] - 0.5, size=16, mode='nearest') * 0.1
        pred_masks = [(p > 0.5).float() for p in p_y]
        masks = fine_patcher.unet.deep_masks(y)

        # p_y[-1] -= F.interpolate(((p_y[-2]) < 0.5).float() - 0.5, size=16, mode='nearest') * 0.1
        p_y[-1] -= F.interpolate((p_y[-2] < 0.1).float(), size=16, mode='nearest') * \
             F.interpolate((p_y[-3] < 0.1).float(), size=16, mode='nearest') * \
            F.interpolate((p_y[-4] < 0.1).float(), size=16, mode='nearest') * 0.01
        p_y[-1] += F.interpolate((p_y[-2] > 0.5).float(), size=16, mode='nearest') * \
             F.interpolate((p_y[-3] > 0.5).float(), size=16, mode='nearest') * \
            F.interpolate((p_y[-4] > 0.5).float(), size=16, mode='nearest') * 0.01
        # p_y[-1] -= F.interpolate((p_y[-4] < 0.5).float(), size=16, mode='nearest') * 0.01

        pred_masks = [(p > 0.5).float() for p in p_y]

        if i % 10 == 0:
            print('Dice: ', dice_fn(pred_masks[-1], masks[-1]).item())
            print('true_positives: ', dice_fn.true_positives)# / (i + 1))
            print('false_positives: ', dice_fn.false_positives)# / (i + 1))
            print('false_negatives: ', dice_fn.false_negatives)# / (i + 1))

tensor(0.8576, device='cuda:0')


interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

(None, None)

In [3]:

# to_chunks(x)[pred_masks[-1].bool()].unsqueeze(1).shape

In [85]:
pred_masks_fine[-1].shape

torch.Size([275, 1, 16, 16, 16])

In [7]:
full_pred.shape

torch.Size([22, 1, 64, 64, 64])

In [8]:
x = assemble_patch(x)
y = assemble_patch(y)
pred_masks = assemble_patch(pred_masks[-1])

In [12]:
dice_fn(full_pred, y)

tensor(0.8498, device='cuda:0')

interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

interactive(children=(IntSlider(value=0, description='idx_td', max=63), Output()), _dom_classes=('widget-inter…

(None, None)

In [1]:
from time import time

import sys
sys.path.append('..')
import util

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = util.UNet3D(
#     Conv3d=util.nn.Conv3DNormed,
#     block_depth=8,
#     dropout=0.2,
# ).to(device)
# model.load_state_dict(torch.load('./bin/models/unet3d.pt', map_location=device))
# model.eval()


fine_patcher = util.Patcher(
    util.UNet3D(
        Conv3d=util.nn.Conv3DNormed,
        block_depth=8,
        dropout=0.2,
    ),
    layers = [5, 64, 128, 256, 512, 1024, 1],
    Conv3d=util.nn.Conv3DNormed,
    dropout=0.2,
    depth=8,
).to(device)
fine_patcher.load_state_dict(torch.load('./bin/models/patcher.pt', map_location=device))
fine_patcher.eval()

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0114 22:45:24.482801 140600046679232 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


Patcher(
  (unet): UNet3D(
    (input_norm): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (down_blocks): ModuleList(
      (0): Sequential(
        (0): ConvBlock(
          (layers): ModuleList(
            (0): Sequential(
              (0): Dropout3d(p=0.2, inplace=False)
              (1): Conv3DNormed(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): GELU(approximate='none')
            )
            (1-7): 7 x Sequential(
              (0): Dropout3d(p=0.2, inplace=False)
              (1): Conv3DNormed(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): GELU(approximate='none')
            )
          )
        )
      )
      (1): Sequential(
        (0): MaxPool3d(kernel_size

In [2]:
patch_size = 16
batch_size = 32

test = util.data.SenNet(
    patch_size,
    guarantee_vessel=0,
    samples=[
        "/train/kidney_2"
        # "/train/kidney_3_sparse"
    ]
)
valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)

Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [4]:
nruns = 1000

class AggDice(nn.Module):
    def __init__(self):
        super().__init__()
        # self.agg_intersection = 0
        # self.agg_cardinality = 0

        self.true_positives = 0
        self.false_positives = 0
        self.false_negatives = 0

    def forward(self, y_pred, y_true):
        # self.agg_intersection += (y_pred * y_true).sum()
        # self.agg_cardinality += (y_pred + y_true).sum()
        # return 2 * self.agg_intersection / self.agg_cardinality

        self.true_positives += (y_pred * y_true).sum()
        self.false_positives += (y_pred * (1 - y_true)).sum()
        self.false_negatives += ((1 - y_pred) * y_true).sum()

        return 2 * self.true_positives / (2 * self.true_positives + self.false_positives + self.false_negatives)

dice_fn = AggDice()

with torch.no_grad():
    for i in range(nruns):
        x, y, _ = next(iter(valid_data))
        x, y = x.float().to(device), y.float().to(device)

        # Compute output
        p_y = fine_patcher(x, threshold=0.51)
        # p_y[-1] += p_y[-1] * F.interpolate(p_y[-2] - 0.5, size=16, mode='nearest') * 0.5
        # p_y[-1] += F.interpolate(p_y[-3] - 0.5, size=16, mode='nearest') * 0.1
        pred_masks = [(p > 0.5).float() for p in p_y]
        masks = fine_patcher.unet.deep_masks(y)

        # p_y[-1] -= F.interpolate(((p_y[-2]) < 0.5).float() - 0.5, size=16, mode='nearest') * 0.1
        p_y[-1] -= F.interpolate((p_y[-2] < 0.1).float(), size=16, mode='nearest') * \
             F.interpolate((p_y[-3] < 0.1).float(), size=16, mode='nearest') * \
            F.interpolate((p_y[-4] < 0.1).float(), size=16, mode='nearest') * 0.01
        p_y[-1] += F.interpolate((p_y[-2] > 0.5).float(), size=16, mode='nearest') * \
             F.interpolate((p_y[-3] > 0.5).float(), size=16, mode='nearest') * \
            F.interpolate((p_y[-4] > 0.5).float(), size=16, mode='nearest') * 0.01
        # p_y[-1] -= F.interpolate((p_y[-4] < 0.5).float(), size=16, mode='nearest') * 0.01

        pred_masks = [(p > 0.5).float() for p in p_y]

        if i % 10 == 0:
            print('Dice: ', dice_fn(pred_masks[-1], masks[-1]).item())
            print('true_positives: ', dice_fn.true_positives)# / (i + 1))
            print('false_positives: ', dice_fn.false_positives)# / (i + 1))
            print('false_negatives: ', dice_fn.false_negatives)# / (i + 1))

Dice:  0.1882352977991104
true_positives:  tensor(8., device='cuda:0')
false_positives:  tensor(42., device='cuda:0')
false_negatives:  tensor(27., device='cuda:0')
Dice:  0.6639999747276306
true_positives:  tensor(83., device='cuda:0')
false_positives:  tensor(57., device='cuda:0')
false_negatives:  tensor(27., device='cuda:0')
Dice:  0.654275119304657
true_positives:  tensor(88., device='cuda:0')
false_positives:  tensor(59., device='cuda:0')
false_negatives:  tensor(34., device='cuda:0')
Dice:  0.0640350878238678
true_positives:  tensor(146., device='cuda:0')
false_positives:  tensor(4133., device='cuda:0')
false_negatives:  tensor(135., device='cuda:0')
Dice:  0.06378331035375595
true_positives:  tensor(146., device='cuda:0')
false_positives:  tensor(4151., device='cuda:0')
false_negatives:  tensor(135., device='cuda:0')
Dice:  0.26547175645828247
true_positives:  tensor(785., device='cuda:0')
false_positives:  tensor(4197., device='cuda:0')
false_negatives:  tensor(147., device='c

KeyboardInterrupt: 

In [160]:
pred_masks[-1].squeeze().shape

torch.Size([32, 16, 16, 16])

tensor(0.8762, device='cuda:0')

In [331]:
i = 10
print(util.DiceScore()(pred_masks[-1][i], masks[-1][i]))
print(masks[-1][i].sum())
util.Display(pred_masks[-1][i].squeeze().cpu())(), util.Display(x[i].cpu().squeeze(), masks[-1][i].squeeze().cpu())()

tensor(3.9683e-09, device='cuda:0')
tensor(0., device='cuda:0')


interactive(children=(IntSlider(value=0, description='idx_td', max=15), Output()), _dom_classes=('widget-inter…

interactive(children=(IntSlider(value=0, description='idx_td', max=15), Output()), _dom_classes=('widget-inter…

(None, None)