In [1]:
from time import time

import sys
sys.path.append('..')
import util

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Create the project log board for directory 'log_dir' on port 6005 and start it.
# NOTE(review): re-running this cell while a board is already up appears to
# trigger a "could not bind to port" error from TensorBoard — confirm in util.
log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [2]:
# Patch extent handed to both datasets.
# NOTE(review): presumably (depth, height, width) — confirm against util.data.SenNet.
patch_size = (1, 256, 256)

# Training volumes: scans first, then the matching labels in the same order.
train_scans = [
    util.data.kidney_from_cache("kidney_1_dense", "images"),
    util.data.kidney_from_cache("kidney_3_dense", "images"),
]
train_labels = [
    util.data.kidney_1_fixed(),
    util.data.kidney_from_cache("kidney_3_dense", "labels"),
]
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    data=[train_scans, train_labels],
    # Alternative (disabled): load by sample path instead of pre-built volumes.
    # samples=[
        # "/train/kidney_1_dense",
        # "/train/kidney_3_dense",
        # "/train/kidney_3_sparse"
    # ]
)

# Held-out data comes from a different kidney than the training volumes.
test = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    samples=[
        "/train/kidney_2",
        # "/train/kidney_3_sparse"
    ],
)

# Keep only indices 900 and up along axis 1 of the first held-out volume,
# applying the identical crop to scans and labels so they stay aligned.
held_out_region = (slice(None), slice(900, None), slice(None), slice(None))
test.scans[0] = test.scans[0][held_out_region]
test.labels[0] = test.labels[0][held_out_region]

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0131 02:16:44.540774 140487363224192 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


Loading /train/kidney_2/images from cache
Loading /train/kidney_2/labels from cache


In [3]:
# Clear the 'train' and 'val' entries on the log board before a new run.
for run_tag in ('train', 'val'):
    log_board.clear(run_tag)

In [4]:
from torchvision import transforms as T
from torchvision.transforms import functional as VF
import albumentations as A
import random

# augment


# Appearance-only transforms: these are meant for the scan alone, never the
# label, because they change intensities rather than geometry.
_scan_only_transforms = [
    T.RandomInvert(p=0.5),
    T.GaussianBlur(3, sigma=(0.1, 2.0)),
    # T.ColorJitter(brightness=0.15, contrast=0.15)
]
augment_just_scan = T.Compose(_scan_only_transforms)

def augment(x: torch.Tensor, y: torch.Tensor, do_just_x: bool = True) -> "tuple[torch.Tensor, torch.Tensor]":
    """Randomly augment a scan/label pair.

    Args:
        x: Scan tensor fed to the model.
        y: Label tensor belonging to ``x``.
            NOTE(review): assumed to share the spatial size of ``x`` so that
            one displacement field fits both — confirm against the caller.
        do_just_x: When True, first run ``augment_just_scan`` (appearance-only
            transforms) on the scan; the label is left untouched by that step.

    Returns:
        ``(x_aug, y_aug)``. With probability 0.5 an elastic deformation is
        applied to both tensors; otherwise ``y`` is returned unchanged.
    """
    _x, _y = x, y
    if do_just_x:
        _x = augment_just_scan(_x)

    if random.random() < 0.5:
        seed = random.randint(0, 2**32)
        elastic_transform = T.ElasticTransform(
            alpha=random.random() * 80.,
            sigma=15.
        )

        # Re-seed the global RNGs before each call so that both calls are
        # intended to draw the same random displacement, keeping scan and
        # label geometrically aligned.
        random.seed(seed)
        torch.manual_seed(seed)
        # Deform the already-augmented scan. Deforming the raw `x` here would
        # throw away the result of `augment_just_scan` whenever this branch runs.
        _x = elastic_transform(_x)

        random.seed(seed)
        torch.manual_seed(seed)
        _y = elastic_transform(_y)

    return _x, _y

    # train_aug_list = [
    #     A.Rotate(limit=270, p= 0.5),
    #     A.RandomScale(scale_limit=(0.8,1.25),interpolation=cv2.INTER_CUBIC,p=p_augm),
    #     A.RandomCrop(input_size, input_size,p=1),
    #     A.RandomGamma(p=p_augm*2/3),
    #     A.RandomBrightnessContrast(p=p_augm,),
    #     A.GaussianBlur(p=p_augm),
    #     A.MotionBlur(p=p_augm),
    #     A.GridDistortion(num_steps=5, distort_limit=0.3, p=p_augm),
    #     ToTensorV2(transpose_mask=True),
    # ]
    # train_aug = A.Compose(train_aug_list)
    # valid_aug_list = [
    #     ToTensorV2(transpose_mask=True),
    # ]
    # valid_aug = A.Compose(valid_aug_list)

# def augment(x: torch.Tensor) -> torch.Tensor:
#     # random dimming
#     _x = x * torch.rand(1, device=x.device) * 0.8 + 0.5
#     # _x = _x ** (torch.rand(1, device=x.device) * 2 + 0.5)
#     return _x

In [5]:
from torch.utils.data import DataLoader

# Cadence (counted in optimizer steps) for checkpoints and validation logging.
save_every, extra_log_every = 500, 25
# Length of the run and number of patches per batch.
epochs, batch_size = 800, 8

# One logger per run tag on the shared log board.
t_logger, v_logger = (log_board.get_logger(tag) for tag in ('train', 'val'))

# Both loaders shuffle; the validation loader is sampled one batch at a time.
train_data, valid_data = (
    DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in (train, test)
)

# Module applied to masks before computing the 'surface dice' metric.
edge = util.Edge().to(device)

# TODO: use this layout next time (small down size, large cross size)
# Network hyper-parameters collected in one place so they are easy to compare
# between experiments.
unet_config = dict(
    in_f=1,
    layers=[32, 64, 128, 256, 512],
    block_depth=4,
    connect_depth=24,
    conv=util.nn.Conv2DNormed,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2,2),
    upsample_mode='bilinear',
    norm_fn=nn.InstanceNorm2d,
    # dropout=(nn.Dropout2d, 0.2)
)
model = util.UNet3P(**unet_config).to(device)
# model = util.UNet3P(
#     in_f=3,
#     layers=[32, 64, 64, 128, 128, 128, 128],
#     block_depth=6,
#     connect_depth=8,
#     conv=util.nn.Conv2DNormed,
#     pool_fn=nn.MaxPool2d,
#     resize_kernel=(2,2),
#     upsample_mode='bilinear',
#     norm_fn=nn.BatchNorm2d,
#     frozen_norm=True,
#     dropout=(nn.Dropout2d, 0.1)
# ).to(device)
# model.load_state_dict(torch.load('./bin/models/pretrain_model_base.pt', map_location=device))
model.train()

# Optional (disabled): freeze BatchNorm parameters before collecting the
# trainable set.
# for module in model.modules():
#     if isinstance(module, nn.BatchNorm2d):
#         for param in module.parameters():
#             param.requires_grad = False

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
# Detached copies of the starting weights; only the disabled divergence
# penalty in the training loop refers to them.
init_weights = [weight.detach().clone() for weight in trainable_params]


optimizer = torch.optim.Adam(trainable_params, lr=0.001)

# Metric / loss modules used by the training loop.
dice_fn = util.DiceScore().to(device)
edge_weighted_dice_loss = util.EdgeWeightedDiceLoss(alpha=0.95).to(device)
focal_loss = util.BinaryFocalLoss(alpha=0.8, gamma=1.5)

# Main training loop: one optimizer step per batch, per-step training metrics,
# periodic validation on a single held-out batch, and periodic checkpoints.
t = time()      # reference point for the 'time' scalar (reset after each validation)
step = 0        # global step counter shared by all loggers
for epoch in range(epochs):
    for i, (x, y, _) in enumerate(train_data):
        step += 1

        # prepare data
        # NOTE(review): the same PatchAugment instance is applied to x and y in
        # two separate calls; scan and label only stay aligned if its random
        # choices are fixed per instance — confirm in util.PatchAugment.
        aug = util.PatchAugment(no_rotate=True).to(device)
        x, y = aug(x.float().to(device)), aug(y.float().to(device))
        # x: merge dims 1 and 2 into a single channel axis (5-D -> 4-D).
        # y: keep only the middle index along dim 2.
        x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
        x, y = augment(x, y)

        # compute output
        # one probability map per output returned by model(x)
        p_y = [torch.sigmoid(logits) for logits in model(x)]
        # hard masks at threshold 0.5, used for metrics only (detached)
        pred_masks = [(p > 0.5).float().detach() for p in p_y]
        # target masks produced from y by the model, indexed like p_y below
        masks = [m for m in model.deep_masks(y)]

        # compute loss
        # indices into p_y / masks that contribute to the dice and focal terms
        dlayers = [0,1,2,3]
        flayers = [0,1,2,3]
        # With dlayers = [0,1,2,3] the edge-weighted branch (indices 4 and 5) is never taken.
        dloss = torch.stack([                                   # dice loss
            1 - dice_fn(p_y[i], masks[i]) if i not in [4,5]
            else edge_weighted_dice_loss(p_y[i], masks[i])
            for i in dlayers
        ])
        # up-weight the dice term of the last listed layer by 10%
        dloss[-1] *= 1.1
        floss = torch.stack([                                   # focal loss
            focal_loss(p_y[i], masks[i])
            for i in flayers
        ])
        # penalty_weight = 0.001
        # divergence_loss = torch.stack([                         # KL divergence loss
        #     torch.norm(p - init, 2)
        #     for p, init in zip(trainable_params, init_weights)
        # ]).mean() * penalty_weight

        loss = dloss.sum() + floss.sum() #+ divergence_loss      # total loss

        # step grad
        loss.backward()
        # clamp every gradient element to [-1, 1] before the update
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()

        # Logging & validation
        with torch.no_grad():
            # log loss
            loss_dict = { f'dice_{k}' : dloss[i] for i, k in enumerate(dlayers) }
            loss_dict.update({ f'focal_{k}' : floss[i] for i, k in enumerate(flayers) })
            # loss_dict['divergence'] = divergence_loss
            t_logger.add_scalars('loss', loss_dict, step)

            # performance logging
            t_logger.add_scalars('acc', { f'layer_{i}' : torch.eq(pred_masks[i], masks[i]).float().mean() for i in dlayers }, step)
            t_logger.add_scalars('dice', { f'layer_{i}' : dice_fn(pred_masks[i], masks[i]) for i in dlayers }, step)
            # surface dice is only computed for the last two listed layers
            t_logger.add_scalars('surface dice', { f'layer_{i}' : dice_fn(edge(pred_masks[i]), edge(masks[i])) for i in dlayers[-2:] }, step)

            # stats logging
            # index 0 is skipped; tags are stats_1, stats_2, ...
            for i, (prob, pred, mask) in enumerate(zip(p_y[1:], pred_masks[1:], masks[1:])):
                t_logger.add_scalars(f'stats_{i + 1}',{
                    'prob_std': prob.std(),
                    'prob_mean': prob.mean(),
                    'pred_std': pred.std(),
                    'pred_mean': pred.mean(),
                    'mask_std': mask.std(),
                    'mask_mean': mask.mean(),
                }, step)

            # validation logging
            # step is incremented before this check, so the first validation
            # happens at step == extra_log_every - 1.
            if (step + 1) % extra_log_every == 0:
                model.eval()

                # prepare data
                # a fresh iterator is built on every call; with shuffle=True
                # this draws one randomly ordered batch each time
                x, y, _ = next(iter(valid_data))

                x, y = x.float().to(device), y.float().to(device)
                # same reshaping as the training batch, without augmentation
                x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
                # x, y = augment(x, y, False)
                # x, y = x.float().to(device), y.float().to(device)
                # x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
                # x = augment(x) # extra color augment on x

                # compute output
                # NOTE: from here on p_y / pred_masks / masks refer to the validation batch
                p_y = [torch.sigmoid(logits) for logits in model(x)]
                pred_masks = [(p > 0.5).float() for p in p_y]
                masks = [m for m in model.deep_masks(y)]

                # performance logging
                v_logger.add_scalars('acc', { f'layer_{i}' : torch.eq(pred_masks[i], masks[i]).float().mean() for i in dlayers }, step)
                v_logger.add_scalars('dice', { f'layer_{i}' : dice_fn(pred_masks[i], masks[i]) for i in dlayers }, step)
                v_logger.add_scalars('surface dice', { f'layer_{i}' : dice_fn(edge(pred_masks[i]), edge(masks[i])) for i in dlayers[-2:] }, step)

                # other logging
                # image built from the middle channel of x plus target and predicted masks
                v_logger.add_image('masks', util.mask_plots(x[:,x.shape[1]//2,:,:].unsqueeze(1), masks, pred_masks), step, dataformats='HWC')
                # wall-clock seconds since the previous validation (or since loop start)
                t_logger.add_scalar('time', time() - t, step)
                t = time()
                # switch back to training mode for the next optimizer step
                model.train()

            # Save model
            # the same file is overwritten on every checkpoint
            if (step + 1) % save_every == 0:
                torch.save(model.state_dict(), f'./bin/_tmp_models/unet2.5d_IN_PROGRESS.pt')

KeyboardInterrupt: 

In [10]:
# Clear the 'val-(with aug)' entry on the log board (left over from an earlier experiment).
log_board.clear('val-(with aug)')

In [14]:

# Test just validation

from torch.utils.data import DataLoader

# Validation-only run: load a saved checkpoint and repeatedly score it on
# random held-out batches, using a 0.1 probability threshold.
extra_log_every = 50
batch_size = 32
v_logger = log_board.get_logger('val ( > 0.1)')

valid_data = DataLoader(test, batch_size=batch_size, shuffle=True)
edge = util.Edge().to(device)

# The architecture must match the one the checkpoint was trained with.
model_pth = './bin/models/edge_model_trained4.pt'
model = util.UNet3P(
    in_f=1,
    layers=[32, 64, 128, 256, 512],
    block_depth=4,
    connect_depth=24,
    conv=util.nn.Conv2DNormed,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2,2),
    upsample_mode='bilinear',
    norm_fn=nn.InstanceNorm2d,
).to(device)
model.load_state_dict(torch.load(model_pth, map_location=device))
model.requires_grad_(False)
model.eval()

dice_fn = util.DiceScore().to(device)
edge_weighted_dice_loss = util.EdgeWeightedDiceLoss(alpha=0.95).to(device)
focal_loss = util.BinaryFocalLoss(
    alpha=0.8,
    gamma=1.5,
)

t = time()
for step in range(100000):
    # Logging & validation
    with torch.no_grad():
        # prepare data
        # a fresh iterator per step; with shuffle=True this draws one random batch
        x, y, _ = next(iter(valid_data))
        x, y = x.float().to(device), y.float().to(device)
        # x: merge dims 1 and 2 into one channel axis; y: keep the middle index of dim 2
        x, y = x.view(x.shape[0],x.shape[1]*x.shape[2],x.shape[3],x.shape[4]), y[:,:,y.shape[2]//2,:,:]
        # Extra appearance-only augment on x. `augment` needs both (x, y) and
        # returns a pair, so calling it with x alone fails; the scan-only
        # pipeline is what this step is meant to apply.
        x = augment_just_scan(x)

        # compute output
        # only the last output of the model is scored here
        p_y = torch.sigmoid(model(x)[-1])
        pred_mask = (p_y > 0.1).float()
        mask = y

        # performance logging
        v_logger.add_scalar('acc', torch.eq(pred_mask, mask).float().mean(), step)
        v_logger.add_scalar('dice', dice_fn(pred_mask, mask), step)
        v_logger.add_scalar('surface dice', dice_fn(edge(pred_mask), edge(mask)), step)

        # other logging
        if (step + 1) % extra_log_every == 0:
            # Plot this step's mask and prediction. The training cell passes
            # lists of masks to mask_plots, so wrap the single tensors in lists.
            # NOTE(review): assumes mask_plots accepts one-element lists — confirm in util.
            v_logger.add_image('masks', util.mask_plots(x[:,x.shape[1]//2,:,:].unsqueeze(1), [mask], [pred_mask]), step, dataformats='HWC')

KeyboardInterrupt: 