In [1]:
# --- Standard library -------------------------------------------------------
import sys
from time import time

# --- Project-local helpers --------------------------------------------------
# `util` lives one directory above this notebook, so extend the import path
# before importing it.
sys.path.append('..')
import util

# --- PyTorch ----------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# Start the LogBoard dashboard for 'log_dir' on port 6005 (TensorBoard-backed,
# judging by the captured output). If the port is already taken, the error is
# reported in the output and the notebook keeps running.
log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Patch size forwarded to util.data.SenNet for both splits.
patch_size = 64

# Options shared by both splits. guarantee_vessel=1.0 is passed straight through
# to SenNet (presumably the fraction of sampled patches that must contain vessel
# voxels - TODO confirm in util.data).
_sennet_kwargs = {'guarantee_vessel': 1.0}

# Training split: kidneys 1 and 3.
train = util.data.SenNet(
    patch_size,
    samples=["/train/kidney_1_dense", "/train/kidney_3_sparse"],
    **_sennet_kwargs,
)

# Held-out split (used as the validation loader in the training cell): kidney 2.
test = util.data.SenNet(
    patch_size,
    samples=["/train/kidney_2"],
    **_sennet_kwargs,
)

Loading /train/kidney_1_dense/images from cache


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0116 06:36:37.304590 140395134047424 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


Loading /train/kidney_1_dense/labels from cache
Loading /train/kidney_3_sparse/images from cache
Loading /train/kidney_3_sparse/labels from cache
Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [3]:
# Clear the 'train' and 'val' log runs before a new session
# (presumably discarding events written by earlier runs).
for run_name in ('train', 'val'):
    log_board.clear(run_name)

In [4]:
import os

from torch.utils.data import DataLoader

# --- Training configuration -------------------------------------------------
save_every = 1000      # optimisation steps between in-progress checkpoints
val_every = 150        # optimisation steps between validation / image logging
epochs = 25
batch_size = 128
checkpoint_dir = './bin/_tmp_models'

t_logger = log_board.get_logger('train')
v_logger = log_board.get_logger('val')

# Indices into the model's deep-supervision outputs that feed each loss term.
dlayers = [0, 1, 2, 3]   # (1 - Dice) loss on these outputs
flayers = [0]            # focal loss on these outputs

train_data = DataLoader(train, batch_size=batch_size, shuffle=True)
# Shuffled so every `next(iter(valid_data))` below draws a different batch.
valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)

model = util.UNet3D(
    layers=[16, 32, 64, 128, 16],
    Conv3d=util.nn.Conv3DNormed,
    block_depth=4,
    connect_depth=8,
    dropout=0.2,
).to(device)
# model.load_state_dict(torch.load('./bin/models/unet3d.pt', map_location=device))
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
dice_fn = util.DiceScore().to(device)
focal_loss = util.BinaryFocalLoss(
    alpha=0.5,
    gamma=1.5,
)

# torch.save does not create parent directories. Without this, the first
# checkpoint (after `save_every` steps) would crash training with FileNotFoundError.
os.makedirs(checkpoint_dir, exist_ok=True)


def _performance(pred_masks, masks, layers, dice_fn):
    """Per-layer accuracy and Dice of thresholded predictions.

    Args:
        pred_masks: sequence of thresholded prediction tensors (bool or float),
            indexed by layer.
        masks: sequence of target mask tensors, indexed by layer.
        layers: layer indices to evaluate.
        dice_fn: callable ``dice_fn(pred, target)`` returning a Dice score.

    Returns:
        Dict with ``acc_{k}`` entries followed by ``dice_{k}`` entries for
        each ``k`` in ``layers``.
    """
    # Cast to float so train (float masks) and val (bool masks) are scored
    # identically.
    preds = {k: pred_masks[k].float() for k in layers}
    perf = {f'acc_{k}': torch.eq(preds[k], masks[k]).float().mean() for k in layers}
    perf.update({f'dice_{k}': dice_fn(preds[k], masks[k]) for k in layers})
    return perf


t = time()
step = 0
for epoch in range(epochs):
    for x, y, _ in train_data:
        step += 1  # 1-based: the first optimisation step is step 1

        # Prepare data. A new augmenter per batch is applied to both image and
        # label (presumably so both get the same transform - TODO confirm in util).
        aug = util.PatchAugment().to(device)
        x, y = aug(x.float().to(device)), aug(y.float().to(device))

        # Forward pass: one logit tensor per deep-supervision output.
        logits = model(x)
        p_y = [torch.sigmoid(logit) for logit in logits]
        pred_masks = [(p > 0.5).float() for p in p_y]
        masks = model.deep_masks(y)[1:]

        # Loss: (1 - Dice) on every layer in dlayers + focal loss on flayers.
        dloss = torch.stack([1 - dice_fn(p_y[k], masks[k]) for k in dlayers])
        floss = torch.stack([focal_loss(p_y[k], masks[k]) for k in flayers])
        loss = dloss.sum() + floss.sum()

        # Optimisation step; gradients are clipped element-wise to [-1, 1].
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()

        # Logging & validation
        with torch.no_grad():
            # Per-layer loss terms.
            loss_dict = {f'dice_{k}': dloss[j] for j, k in enumerate(dlayers)}
            loss_dict.update({f'focal_{k}': floss[j] for j, k in enumerate(flayers)})
            t_logger.add_scalars('loss', loss_dict, step)

            # Training-batch accuracy / Dice.
            t_logger.add_scalars(
                'performance', _performance(pred_masks, masks, dlayers, dice_fn), step
            )

            # Probability / prediction / target statistics for the deeper outputs.
            for layer, (prob, pred, mask) in enumerate(
                zip(p_y[1:], pred_masks[1:], masks[1:]), start=1
            ):
                t_logger.add_scalars(f'stats_{layer}', {
                    'prob_std': prob.std(),
                    'prob_mean': prob.mean(),
                    'pred_std': pred.std(),
                    'pred_mean': pred.mean(),
                    'mask_std': mask.std(),
                    'mask_mean': mask.mean(),
                }, step)

            # Validation on one random held-out batch every `val_every` steps.
            if step % val_every == 0:
                model.eval()
                x, y, _ = next(iter(valid_data))
                x, y = x.float().to(device), y.float().to(device)

                logits = model(x)
                p_y = [torch.sigmoid(logit) for logit in logits]
                pred_masks = [p > 0.5 for p in p_y]  # bool masks, as given to mask_plots
                masks = model.deep_masks(y)[1:]

                v_logger.add_scalars(
                    'performance', _performance(pred_masks, masks, dlayers, dice_fn), step
                )
                v_logger.add_image('masks', util.mask_plots(x, masks, pred_masks), step, dataformats='HWC')

                # Wall-clock time since the previous validation.
                t_logger.add_scalar('time', time() - t, step)
                t = time()
                model.train()

            # In-progress checkpoint every `save_every` steps.
            if step % save_every == 0:
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'unet_IN_PROGRESS.pt'))

In [6]:
import os

# Persist the trained weights of `model` under ./bin/models, creating the
# directory first because torch.save does not.
final_model_dir = './bin/models'
os.makedirs(final_model_dir, exist_ok=True)
torch.save(model.state_dict(), f'{final_model_dir}/unet3d2.pt')

In [6]:
import torch

# Scratch check: taking amax of a 5-D tensor over dims 2, 3 and 4 leaves
# only the first two dims, i.e. shape (32, 1).
spatial_dims = (2, 3, 4)
x = torch.rand((32, 1, 16, 16, 16))
x.amax(dim=spatial_dims).shape

torch.Size([32, 1])