In [1]:
from time import time

import sys
sys.path.append('..')
import util

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Start the project's log board; 6005 is the port TensorBoard is asked to bind to.
# NOTE(review): 'log_dir' is presumably the directory the runs are written to — confirm
# against util.diagnostics.LogBoard.
log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Patch dimensions handed to the dataset as its first positional argument.
patch_size = (3, 64, 64)

# Cached image volumes, in the order they are passed to the dataset.
image_paths = (
    '/root/data/cache/train/kidney_1_dense/images.pt',
    '/root/data/cache/train/kidney_2/images.pt',
    '/root/data/cache/train/kidney_3_sparse/images.pt',
)

# NOTE(review): both inner lists are loaded from the same images.pt files, so the
# second group is a separately loaded copy of the first. If the second group is
# meant to hold segmentation targets, the file name needs changing — confirm.
train = util.data.SenNet(
    patch_size,
    guarantee_vessel=0.5,
    data=[
        [torch.load(path) for path in image_paths],
        [torch.load(path) for path in image_paths],
    ],
)

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0128 01:18:07.434986 140359506990272 program.py:298] TensorBoard could not bind to port 6005, it was already in use
ERROR: TensorBoard could not bind to port 6005, it was already in use


In [3]:
# Clear both runs on the log board, 'train' first and then 'val'.
for run_name in ('train', 'val'):
    log_board.clear(run_name)

In [4]:
# Shuffled loader over the training patches, 12 samples per batch.
batch_size = 12
train_data = DataLoader(
    dataset=train,
    batch_size=batch_size,
    shuffle=True,
)


In [5]:
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image

def mask_plots(targets: list[torch.Tensor], preds: list[torch.Tensor]):
    """Draw targets and predictions as a two-row figure and return it as an image array.

    Column ``i`` shows the first element of ``targets[i]`` in the top row and the
    first element of ``preds[i]`` in the bottom row, each drawn through
    ``util.Display(scan=...)._view_slice``.

    Args:
        targets: Tensors drawn in the top row, one column per entry.
        preds: Tensors drawn in the bottom row, paired with ``targets`` by
            position (surplus entries in the longer list are ignored by ``zip``).

    Returns:
        The figure rendered to PNG and decoded into a NumPy array.
    """
    # squeeze=False keeps `axes` two-dimensional even when there is only one column,
    # so no manual reshaping is needed before indexing axes[row][col].
    fig, axes = plt.subplots(2, len(targets), figsize=(20, 10), squeeze=False)

    try:
        for col, (target, pred) in enumerate(zip(targets, preds)):
            # Keep only the first element along dim 0, restoring a leading axis of size 1.
            first_target = target[0].unsqueeze(0)
            first_pred = pred[0].unsqueeze(0)
            util.Display(scan=first_target.cpu())._view_slice(0, 0, axes[0][col])
            util.Display(scan=first_pred.cpu())._view_slice(0, 0, axes[1][col])

        # Render into memory rather than to disk.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
    finally:
        # Close this figure specifically (not whichever one happens to be current),
        # and do so even if drawing or saving raised, so figures do not accumulate.
        plt.close(fig)

    return np.array(Image.open(buf))

In [6]:
from torch.utils.data import DataLoader
import torchvision.transforms as T

# Run configuration for the training loop.
epochs = 800        # passes over the training loader
batch_size = 128    # replaces any batch_size assigned by an earlier cell
valid_every = 25    # step interval for the image/time logging
save_every = 500    # step interval for checkpointing

# One logger per run: 'train' first, then 'val'.
t_logger, v_logger = (log_board.get_logger(run_name) for run_name in ('train', 'val'))

# Shuffled loader rebuilt with the batch size above.
train_data = DataLoader(dataset=train, batch_size=batch_size, shuffle=True)

# Settings for the project's UNet3P, built with 2D convolution, pooling,
# normalisation and dropout layers.
unet_kwargs = dict(
    in_f=3,
    layers=[32, 64, 64, 128, 128, 128, 128],
    block_depth=6,
    connect_depth=8,
    conv=util.nn.Conv2DNormed,
    pool_fn=nn.MaxPool2d,
    resize_kernel=(2, 2),
    upsample_mode='bilinear',
    norm_fn=nn.BatchNorm2d,
    dropout=(nn.Dropout2d, 0.1),
)

model = util.UNet3P(**unet_kwargs).to(device)
# Uncomment to resume from the in-progress checkpoint:
# model.load_state_dict(torch.load('./bin/_tmp_models/unet2.5d_IN_PROGRESS.pt', map_location=device))
model.train()

# Optimiser over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Metric and loss objects. Only `loss_fn` is used by the training loop in this
# cell; `dice_fn` and `focal_loss` are constructed but not referenced there.
dice_fn = util.DiceScore().to(device)
focal_loss = util.BinaryFocalLoss(alpha=0.8, gamma=1.5)
loss_fn = nn.MSELoss()

# Random erasing applied to every input batch (p=1).
rand_erase = T.RandomErasing(p=1)

# Training loop: one optimiser step per batch, scalar logging every step, image and
# timing logging every `valid_every` steps, and a checkpoint every `save_every` steps.
checkpoint_path = './bin/_tmp_models/unet2.5d_IN_PROGRESS.pt'

t = time()
step = 0
for epoch in range(epochs):
    for x, y, _ in train_data:
        step += 1

        # Prepare data.
        # NOTE(review): the augmenter is rebuilt every step and applied to x and y in
        # two separate calls — presumably one instance applies the same transform to
        # both; confirm against util.PatchAugment.
        aug = util.PatchAugment(no_rotate=True).to(device)
        x, y = aug(x.float().to(device)), aug(y.float().to(device))

        # Merge dim 2 into dim 1: (d0, d1, d2, d3, d4) -> (d0, d1*d2, d3, d4).
        # flatten gives the same result as the equivalent view() but also accepts
        # non-contiguous tensors, which view() rejects.
        x = x.flatten(1, 2)
        # Target is the middle index along dim 2, rescaled by /255 and shifted by -0.5
        # (presumably mapping 8-bit intensities to [-0.5, 0.5] — confirm).
        y = (y[:, :, y.shape[2] // 2, :, :] / 255) - 0.5
        x = rand_erase(x)

        # Compute output
        preds = model(x)
        masks = list(model.deep_masks(y))

        # One loss per prediction, each compared with the mask at the same index.
        loss = torch.stack([
            loss_fn(pred, masks[level])
            for level, pred in enumerate(preds)
        ])

        # Step grad
        loss.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

        # Logging & validation
        with torch.no_grad():
            # Per-level losses, keyed '_0', '_1', ...
            loss_dict = {f'_{level}': level_loss for level, level_loss in enumerate(loss)}
            t_logger.add_scalars('loss', loss_dict, step)

            # Summary statistics of each prediction and its mask.
            for level, (pred, mask) in enumerate(zip(preds, masks)):
                t_logger.add_scalars(f'stats_{level + 1}', {
                    'pred_std': pred.std(),
                    'pred_mean': pred.mean(),
                    'mask_std': mask.std(),
                    'mask_mean': mask.mean(),
                }, step)

            # `step` is already 1-based here, so test it directly: the triggers fire
            # on exact multiples (25, 50, ...) instead of one step early.
            if step % valid_every == 0:
                v_logger.add_image('masks', mask_plots(masks, preds), step, dataformats='HWC')
                # Wall-clock time since the previous trigger (or since the loop started).
                t_logger.add_scalar('time', time() - t, step)
                t = time()
                # Nothing visible in this loop leaves training mode; kept as a guard.
                model.train()

            # Save model
            if step % save_every == 0:
                torch.save(model.state_dict(), checkpoint_path)

: 