In [27]:
import random
import os
from time import time

import click
import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter

from model.IsoRSUNet import UNetModel
from model.io import save_chkpt, log_tensor
from model.loss import BinomialCrossEntropyWithLogits
from dataset.affinity import Dataset

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
def train(seed: int, training_split_ratio: float, patch_size: tuple,
          iter_start: int, iter_stop: int, output_dir: str,
          in_channels: int, out_channels: int, learning_rate: float,
          training_interval: int, validation_interval: int):

    random.seed(seed)

    writer = SummaryWriter(log_dir=os.path.join(output_dir, 'log'))

    model = UNetModel(in_channels, out_channels)
    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss_module = BinomialCrossEntropyWithLogits()
    dataset = Dataset(
        patch_size=patch_size,
        training_split_ratio=training_split_ratio
    )

    patch_voxel_num = np.product(patch_size)
    accumulated_loss = 0.
    for iter_idx in range(iter_start, iter_stop):

        patch = dataset.random_training_patch

  
        image = np.expand_dims(patch.image, axis=0) # add dim for batch
        target = np.expand_dims(patch.target, axis=0) # add dim for batch

        image = torch.from_numpy(image)
        target = torch.from_numpy(target)

        # Transfer Data to GPU if available
        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

      
        logits = model(image)
        loss = loss_module(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accumulated_loss += loss.cpu().tolist()

        if iter_idx % training_interval == 0 and iter_idx > 0:
            per_voxel_loss = accumulated_loss / training_interval / patch_voxel_num
            print(f'training loss {round(per_voxel_loss, 3)}')
            accumulated_loss = 0.
            predict = torch.sigmoid(logits)
            writer.add_scalar('Loss/train', per_voxel_loss, iter_idx)
            log_tensor(writer, 'train/image', image, iter_idx)
            # log_tensor(writer, 'train/prediction', predict, iter_idx)
            # log_tensor(writer, 'train/target', target, iter_idx)

        if iter_idx % validation_interval == 0 and iter_idx > 0:
            fname = os.path.join(output_dir, f'model_{iter_idx}.chkpt')
            print(f'save model to {fname}')
            save_chkpt(model, output_dir, iter_idx, optimizer)

            print('evaluate prediction: ')
            patch = dataset.random_validation_patch
            print('evaluation patch shape: ', patch.shape)
            validation_image = torch.from_numpy(patch.image)
            validation_target = torch.from_numpy(patch.label)
            # Transfer Data to GPU if available
            if torch.cuda.is_available():
                validation_image = validation_image.cuda()
                validation_target = validation_target.cuda()

            with torch.no_grad():
                validation_logits = UNetModel(validation_image)
                validation_predict = torch.sigmoid(validation_logits)
                validation_loss = loss_module(
                    validation_logits, validation_target)
                per_voxel_loss = validation_loss.cpu().tolist() / patch_voxel_num
                print(
                    f'iter {iter_idx}: validation loss: {round(per_voxel_loss, 3)}')
                writer.add_scalar('Loss/validation', per_voxel_loss, iter_idx)
                log_tensor(writer, 'evaluate/image',
                           validation_image, iter_idx)
                log_tensor(writer, 'evaluate/prediction',
                           validation_predict, iter_idx)
                log_tensor(writer, 'evaluate/target',
                           validation_target, iter_idx)

    writer.close()

In [35]:
seed = 7
training_split_ratio = 0.8
patch_size = (64, 64, 64)
iter_start = 0
iter_stop = 200000
output_dir = './output'
in_channels = 1
out_channels = 13
learning_rate = 0.001
training_interval = 1
validation_interval = 1000

train(seed, training_split_ratio, patch_size,
          iter_start, iter_stop, output_dir,
          in_channels, out_channels, learning_rate,
          training_interval, validation_interval)

training loss 19.594


RuntimeError: expand(torch.FloatTensor{[13, 64, 64]}, size=[64, 64]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)