In [1]:
from time import time

import sys
sys.path.append('..')
import util
import lutil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
patch_size = 8,256,256
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.4,
    samples=[
        "/train/kidney_1_dense",
        "/train/kidney_3_sparse"
    ]
)

test = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.4,
    samples=[
        "/train/kidney_2"
    ]
)

Loading /train/kidney_1_dense/images from cache


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0121 23:44:08.628577 140664648578240 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


Loading /train/kidney_3_sparse/images from cache
Loading /train/kidney_1_dense/labels from cache


In [None]:
log_board.clear('train')
log_board.clear('val')

In [4]:
from torch.utils.data import DataLoader

save_every = 500
valid_every = 25
epochs = 800
batch_size = 24
t_logger = log_board.get_logger('train')
v_logger = log_board.get_logger('val')

train_data = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)

model = lutil.UNet3P(
    in_f=8,
    layers=[16, 32, 32, 32, 64, 64, 64],
    block_depth=4,
    connect_depth=6,
    conv=util.nn.Conv2DNormed,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2,2),
    upsample_mode='bilinear',
    norm_fn=nn.BatchNorm2d,
    dropout=(nn.Dropout2d, 0.1)
).to(device)

# model = lutil.UNet3P(
#     layers=[16, 32, 32, 32, 32, 32, 16],
#     block_depth=4,
#     connect_depth=4,
#     conv=util.nn.Conv3DNormed,
#     pool_fn=nn.MaxPool3d,
#     resize_kernel=(2,2,1),
#     upsample_mode='trilinear',
#     norm_fn=nn.BatchNorm3d,
#     dropout=(nn.Dropout3d, 0.1)
# ).to(device)
model.load_state_dict(torch.load('./bin/_tmp_models/unet2.5d_IN_PROGRESS.pt', map_location=device))
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
dice_fn = util.DiceScore().to(device)
focal_loss = util.BinaryFocalLoss(
    alpha=0.8,
    gamma=1.5,
)

t = time()
step = 0
for epoch in range(epochs):
    for i, (x, y, _) in enumerate(train_data):
        step += 1

        # Prepare data
        aug = util.PatchAugment(no_rotate=True).to(device)
        x, y = aug(x.float().to(device)), aug(y.float().to(device))
        # x, y = x.squeeze(-1), y.squeeze(-1)
        x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
        # Compute output
        p_y = [torch.sigmoid(logits) for logits in model(x)]
        pred_masks = [(p > 0.5).float() for p in p_y]
        masks = [m for m in model.deep_masks(y)]

        # Compute loss
        dlayers = [0,1,2,3,4,5]
        flyaers = [0,1,2,3,4,5]
        dloss = torch.stack([
            (1 - dice_fn(p_y[i], masks[i], mode='separate'))
            for i in dlayers
        ])
        dloss[-1] *= 1.1
        floss = torch.stack([
            focal_loss(p_y[i], masks[i])
            for i in flyaers
        ])
        loss = dloss.sum() + floss.sum()

        # Step grad
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()

        # Logging & validation
        with torch.no_grad():
            # log loss
            loss_dict = { f'dice_{k}' : dloss[i] for i, k in enumerate(dlayers) }
            loss_dict.update({ f'focal_{k}' : floss[i] for i, k in enumerate(flyaers) })
            t_logger.add_scalars('loss', loss_dict, step)
            
            # performance
            acc = [torch.eq(pred_masks[i], masks[i]).float().mean() for i in dlayers]
            dice = [dice_fn(pred_masks[i], masks[i]) for i in dlayers]
            perf_dict = { f'acc_{k}' : acc[i] for i, k in enumerate(dlayers) }
            perf_dict.update({ f'dice_{k}' : dice[i] for i, k in enumerate(dlayers) })
            t_logger.add_scalars('performance', perf_dict, step)

            # stats
            for i, (prob, pred, mask) in enumerate(zip(p_y[1:], pred_masks[1:], masks[1:])):
                t_logger.add_scalars(f'stats_{i + 1}',{
                    'prob_std': prob.std(),
                    'prob_mean': prob.mean(),
                    'pred_std': pred.std(),
                    'pred_mean': pred.mean(),
                    'mask_std': mask.std(),
                    'mask_mean': mask.mean(),
                }, step)

            # Validation logging
            if (step + 1) % valid_every == 0:
                model.eval()
                x, y, _ = next(iter(valid_data))
                x, y = x.float().to(device), y.float().to(device)
                x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
                # x, y = x.squeeze(-1), y.squeeze(-1)

                # Compute output
                p_y = [torch.sigmoid(logits) for logits in model(x)]
                pred_masks = [(p > 0.5).float() for p in p_y]
                masks = [m for m in model.deep_masks(y)]

                acc = [torch.eq(pred_masks[i], masks[i]).float().mean() for i in dlayers]
                dice = [dice_fn(pred_masks[i], masks[i]) for i in dlayers]
                perf_dict = { f'acc_{k}' : acc[i] for i, k in enumerate(dlayers) }
                perf_dict.update({ f'dice_{k}' : dice[i] for i, k in enumerate(dlayers) })
                v_logger.add_scalars('performance', perf_dict, step)

                v_logger.add_image('masks', util.mask_plots2d(x[:,x.shape[1]//2,:,:].unsqueeze(1), masks, pred_masks), step, dataformats='HWC')
                t_logger.add_scalar('time', time() - t, step)
                t = time()
                model.train()
                
            # Save model
            if (step + 1) % save_every == 0:
                torch.save(model.state_dict(), f'./bin/_tmp_models/unet2.5d_IN_PROGRESS.pt')

OutOfMemoryError: CUDA out of memory. Tried to allocate 252.00 MiB. GPU 0 has a total capacty of 23.64 GiB of which 253.94 MiB is free. Process 404489 has 6.90 GiB memory in use. Process 415332 has 16.37 GiB memory in use. Of the allocated memory 15.39 GiB is allocated by PyTorch, and 539.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [7]:
import os

os.makedirs('./bin/models', exist_ok=True)

torch.save(model.state_dict(), f'./bin/models/unet2d_3dinput_450epochs.pt')

In [8]:
step

21400

In [71]:
p_size = torch.tensor([64,64,8])
k_perm = torch.randperm(3)
patch_size = p_size[k_perm]

x = torch.randn(1,*p_size[k_perm])
# x = x.permute(0, *k_perm+1)
k_perm, patch_size, torch.argsort(k_perm), x.shape, x.permute(0, *torch.argsort(k_perm)+1).shape

(tensor([1, 2, 0]),
 tensor([64,  8, 64]),
 tensor([2, 0, 1]),
 torch.Size([1, 64, 8, 64]),
 torch.Size([1, 64, 64, 8]))

In [45]:
epochs = 10
batch_size = 32

train_data = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)

x, y, _ = next(iter(train_data))
aug = util.PatchAugment(no_rotate=True).to(device)
x, y = aug(x.float().to(device)), aug(y.float().to(device))

x.shape, y.shape

# util.Display(x[0].squeeze().cpu().transpose(0,2), y[0].squeeze().cpu().transpose(0,2) )()

(torch.Size([32, 1, 64, 64, 8]), torch.Size([32, 1, 64, 64, 8]))

In [9]:
# model = lutil.UNet3P(
#     layers=[16, 32, 64, 128, 16],
#     conv=util.nn.Conv3DNormed,
#     block_depth=4,
#     connect_depth=8,
#     norm_fn=nn.BatchNorm3d,
#     resize_kernel=(2,2,1),
#     dropout=(nn.Dropout3d, 0.1)
# ).to(device)

model = lutil.UNet3P(
    layers=[16, 32, 64, 128, 16],
    block_depth=4,
    connect_depth=8,

    conv=nn.Conv2d,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2,2),
    upsample_mode='bilinear',
    norm_fn=nn.BatchNorm2d,
    dropout=(nn.Dropout2d, 0.1)
).to(device)

In [91]:
import numpy as np

patch_size = (64,64,8)
patch_size = torch.tensor(patch_size)
patch_size

tensor([64, 64,  8])

In [92]:
perm = torch.randperm(len(patch_size))
patch_size = patch_size[perm]
patch_size

tensor([64,  8, 64])

In [86]:
len(p_size)

3

In [48]:
k_perm = np.random.permutation(3)
k_perm
k_perm, p_size[k_perm]

(array([0, 1, 2]), tensor([64, 64,  8]))

In [58]:
[*p_size[k_perm]]

[tensor(8), tensor(64), tensor(64)]

In [67]:
torch.randperm(3)

tensor([2, 0, 1])

In [28]:
k_perm = torch.randperm(3)
k_perm

tensor([2, 0, 1])

In [29]:
torch.argsort(k_perm), k_perm

(tensor([1, 2, 0]), tensor([2, 0, 1]))

(tensor([1, 0, 2]),
 tensor([1, 0, 2]),
 torch.Size([1, 64, 64, 8]),
 torch.Size([1, 64, 64, 8]))

In [74]:
k_perm,torch.argsort(k_perm)

(tensor([0, 2, 1]), tensor([0, 2, 1]))

In [10]:
x, y = torch.randn(1, 1, 128, 128).to(device), torch.randn(1, 1, 128, 128).to(device)

p_y = [F.sigmoid(x) for x in model(x)]
masks = model.deep_masks(y)

[p.shape for p in p_y], [m.shape for m in masks]

([torch.Size([1, 1, 16, 16]),
  torch.Size([1, 1, 32, 32]),
  torch.Size([1, 1, 64, 64]),
  torch.Size([1, 1, 128, 128])],
 [torch.Size([1, 1, 16, 16]),
  torch.Size([1, 1, 32, 32]),
  torch.Size([1, 1, 64, 64]),
  torch.Size([1, 1, 128, 128])])

In [10]:
nn.MaxPool3d(torch.tensor([2,2,1]) ** 4)

MaxPool3d(kernel_size=tensor([16, 16,  1]), stride=tensor([16, 16,  1]), padding=0, dilation=1, ceil_mode=False)

In [None]:
x = torch.randn(1, 1, 64, 64).to(device)