In [1]:
from time import time

import sys
sys.path.append('..')
import util

import torch
import torch.nn as nn
import torch.nn.functional as F

log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
patch_size = 16
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=[
        "/train/kidney_1_dense",
        "/train/kidney_3_sparse"
    ]
)

test = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.1,
    samples=[
        "/train/kidney_2"
    ]
)

Loading /train/kidney_1_dense/images from cache


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0116 02:59:58.550872 140171716478144 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


Loading /train/kidney_1_dense/labels from cache
Loading /train/kidney_3_sparse/images from cache
Loading /train/kidney_3_sparse/labels from cache
Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [3]:
log_board.clear('train')
log_board.clear('val')

In [4]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# b = 22
# x = torch.randn(b, 1, 16, 16).to(device)
# probs = F.sigmoid(torch.randn(b, 1, 1, 1, 1)).to(device)

# mask = (probs > 0.5).squeeze().unsqueeze(1)

# mask.squeeze().unsqueeze(1).shape, mask.sum(), mask.squeeze(), x[mask].shape

In [5]:
test.num_patches / 32, test.num_patches / 32

(28408.28125, 28408.28125)

In [6]:
from torch.utils.data import DataLoader

save_every = 500
epochs = 25
batch_size = 32
t_logger = log_board.get_logger('train')
v_logger = log_board.get_logger('val')

train_data = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)

# Setup model
unet = util.UNet3D(
    Conv3d=util.nn.Conv3DNormed,
    block_depth=8,
    dropout=0.2,
)
# unet.load_state_dict(torch.load('./bin/models/unet3d2.pt', map_location=device))
patcher = util.Patcher(
    unet,
    layers = [2, 64, 128, 256, 512, 1024, 1],
    Conv3d=util.nn.Conv3DNormed,
    dropout=0.2,
    depth=8,
).to(device)
patcher.load_state_dict(torch.load('./bin/models/patcher.pt', map_location=device))
patcher.train()
# unet.requires_grad_(False)
# unet.eval()

# optimizer & loss
optimizer = torch.optim.Adam(patcher.net.parameters(), lr=0.0002)
dice_fn = util.DiceScore().to(device)
focal_loss = util.BinaryFocalLoss(
    alpha=0.5,
    gamma=1.8,
)
loss_fn = torch.nn.BCELoss()



# Training

step = 0
t = time()
try:
    for epoch in range(epochs):
        for i, (x, y, _) in enumerate(train_data):
            step += 1

            # Prepare data
            aug = util.PatchAugment().to(device)
            x, y = aug(x.float().to(device)), aug(y.float().to(device))

            
            p_y = patcher(x, full_output=False)
            pred_masks = [(p > 0.5).float() for p in p_y]
            masks = unet.deep_masks(y)

            # Compute loss
            dlayers = []
            flyaers = [0]
            dloss = torch.stack([
                (1 - dice_fn(p_y[i], masks[i]))
                for i in dlayers
            ]) if len(dlayers) > 0 else torch.zeros(0)
            floss = torch.stack([
                focal_loss(p_y[i], masks[i])
                for i in flyaers
            ]) if len(flyaers) > 0 else torch.zeros(0)
            loss = dloss.sum() * 0.1 + floss.sum()

            # Step grad
            loss.backward()
            # torch.nn.utils.clip_grad_value_(patcher.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()

            # Logging & validation
            with torch.no_grad():
                # log loss
                loss_dict = { f'dice_{k}' : dloss[i] for i, k in enumerate(dlayers) }
                loss_dict.update({ f'focal_{k}' : floss[i] for i, k in enumerate(flyaers) })
                t_logger.add_scalars('loss', loss_dict, step)
                
                # performance
                # accuracy = [torch.eq(pred, mask).float().mean() for pred, mask in zip(pred_masks, masks)]
                # dice = [dice_fn(pred, mask) for pred, mask in zip(pred_masks[1:], masks[1:])]
                # perf_dict = { f'acc_{i}' : acc for i, acc in enumerate(accuracy) }
                # perf_dict.update({ f'dice_{i+1}' : d for i, d in enumerate(dice) })
                # t_logger.add_scalars('performance', perf_dict, step)

                # Performance logging
                if (step + 1) % 10 == 0:
                    # training performance
                    sample_mask = pred_masks[0].squeeze().bool()
                    pred_masks = [pred_masks[0]] + [torch.zeros_like(m) for m in masks[1:]]
                    for i, p in enumerate(patcher.unet(x[sample_mask])):
                        pred_masks[i + 1][sample_mask] = (p > 0.5).float()

                    accuracy = [torch.eq(pred, mask).float().mean() for pred, mask in zip(pred_masks, masks)]
                    dice = [dice_fn(pred, mask) for pred, mask in zip(pred_masks[1:], masks[1:])]
                    perf_dict = { f'acc_{i}' : acc for i, acc in enumerate(accuracy) }
                    perf_dict.update({ f'dice_{i+1}' : d for i, d in enumerate(dice) })
                    t_logger.add_scalars('performance', perf_dict, step)

                    # stats
                    for i, (prob, pred, mask) in enumerate(zip(p_y, pred_masks, masks)):
                        t_logger.add_scalars(f'stats_{i + 1}',{
                            'prob_std': prob.std(),
                            'prob_mean': prob.mean(),
                            'pred_std': pred.std(),
                            'pred_mean': pred.mean(),
                            'mask_std': mask.std(),
                            'mask_mean': mask.mean(),
                        }, step)

                    # validation performance
                    patcher.eval()
                    x, y, _ = next(iter(valid_data))
                    x, y = x.float().to(device), y.float().to(device)
                    
                    p_y = patcher(x) # full output
                    pred_masks = [(p > 0.5).float() for p in p_y]
                    masks = unet.deep_masks(y)
                    
                    sample_mask = pred_masks[0].squeeze().bool()
                    pred_masks = [pred_masks[0]] + [torch.zeros_like(m) for m in masks[1:]]
                    for i, p in enumerate(patcher.unet(x[sample_mask])):
                        pred_masks[i + 1][sample_mask] = (p > 0.5).float()

                    accuracy = [torch.eq(pred, mask).float().mean() for pred, mask in zip(pred_masks, masks)]
                    dice = [dice_fn(pred, mask) for pred, mask in zip(pred_masks[1:], masks[1:])]
                    perf_dict = { f'acc_{i}' : acc for i, acc in enumerate(accuracy) }
                    perf_dict.update({ f'dice_{i+1}' : d for i, d in enumerate(dice) })
                    v_logger.add_scalars('performance', perf_dict, step)

                    v_logger.add_image('masks', util.mask_plots(x, masks, pred_masks), step, dataformats='HWC')
                    t_logger.add_scalar('time', time() - t, step)
                    t = time()
                    patcher.train()
                
                # Save model
                if (step + 1) % save_every == 0:
                    torch.save(patcher.state_dict(), f'./bin/models/patcherIN_PROGRESS.pt')
except Exception as e:
    print(e)
    # torch.save(patcher.state_dict(), f'./bin/models/patcher.pt')
    raise e

In [5]:
import os

os.makedirs('./bin/models', exist_ok=True)

torch.save(patcher.state_dict(), f'./bin/models/patcher.pt')

In [9]:
upsample = F.interpolate
x = torch.randn(1, 1, 4, 4, 4).to(device)
upsample(x).shape

AttributeError: 'int' object has no attribute 'dim'