In [1]:
import torch.nn as nn
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, SubsetRandomSampler
from datasets import LUNA16DatasetFromIso, LUNA16DatasetFromCubes
from models import Classifier3D, TinyClassifier
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from pathlib import Path
import numpy as np
from losses import binary_focal_loss_with_logits
import itertools
from timeit import default_timer as dt


# linearly transform [-1000, 400] to [0, 1]
def linear_transform_to_0_1(X, min=-1000, max=400):
    """Clip ``X`` to ``[min, max]`` and rescale that window linearly onto ``[0, 1]``.

    Args:
        X: tensor of raw values.
        min: lower clip bound; values at or below it map to 0.
        max: upper clip bound; values at or above it map to 1.

    Returns:
        A new tensor with the same shape as ``X``; ``X`` itself is not modified.
    """
    window = max - min
    clipped = X.clamp(min=min, max=max)
    return (clipped - min) / window


def calc_confusion(pred_bool, target_bool):
    """Count confusion-matrix entries for boolean predictions against boolean targets.

    Args:
        pred_bool: boolean tensor, True where the positive class is predicted.
        target_bool: boolean tensor, True where the sample is actually positive.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of Python numbers.
    """
    def _count(mask):
        return mask.sum().item()

    pred_neg = ~pred_bool
    target_neg = ~target_bool
    return (
        _count(pred_bool & target_bool),   # true positives
        _count(pred_neg & target_neg),     # true negatives
        _count(pred_bool & target_neg),    # false positives
        _count(pred_neg & target_bool),    # false negatives
    )


def shuffle_wrapper(x):
    """Shuffle ``x`` in place with ``np.random.shuffle`` and return the same object.

    ``np.random.shuffle`` itself returns ``None``; handing ``x`` back makes the
    shuffle usable where a return value is required, e.g. as the callable
    passed to pandas ``.apply`` in the training cell.
    """
    np.random.shuffle(x)
    return x

In [2]:
# --- Optimisation hyper-parameters ---
LR = 0.003
MOMENTUM = 0.9  # not referenced by the Adam optimizer call in the cells shown
WEIGHT_DECAY = 1e-4

# --- Schedule / logging ---
# An "epoch" means one full pass over the negative samples; positive samples
# are reused within it.
EPOCHS = 50
LOG_INTERVAL = 100
SHOULD_AUGMENT = True

# --- Batching ---
# Each batch is sampled half from positives and half from negatives, so the
# batch size has to be even.
BATCH_SIZE = 256
assert BATCH_SIZE % 2 == 0

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Build the single-channel classifier (img_size=48 matches the 48x48x48
# reshape applied to inputs in the training cell), then move it to `device`.
model = Classifier3D(in_channels=1, img_size=48)
model = model.to(device)

In [4]:
from torchsummary import summary

In [5]:
# Subset split: 0-5 for training, 6-7 for validation, 8-9 reserved for testing
# (test_subsets is not consumed by any cell shown here).
train_subsets = ['subset0', 'subset1', 'subset2', 'subset3', 'subset4', 'subset5']
val_subsets = ['subset6', 'subset7']
test_subsets = ['subset8', 'subset9']

# Negative training samples are drawn from the "iso" dataset ...
train_neg_dataset = LUNA16DatasetFromIso(
    subsets=train_subsets,
    candidates_file='candidates_V2.csv',
    iso_root_path='/scratch/zc2357/cv/final/datasets/luna16_iso/',
)

# ... while positive training samples come from the "cubes" dataset, which
# uses its own sub-indexed candidates file.
train_pos_dataset = LUNA16DatasetFromCubes(
    subsets=train_subsets,
    candidates_file='candidates_V2_subindexed.csv',
    cube_root_path='/scratch/zc2357/cv/final/datasets/luna16_cubes',
)

# Validation reads the same "iso" data source, restricted to the validation subsets.
val_dataset = LUNA16DatasetFromIso(
    subsets=val_subsets,
    candidates_file='candidates_V2.csv',
    iso_root_path='/scratch/zc2357/cv/final/datasets/luna16_iso/',
)

In [6]:
# Positive half of every training batch: random draws restricted to the
# dataset's positive indices; a trailing short batch is dropped.
train_pos_dataloader = DataLoader(
    dataset=train_pos_dataset,
    sampler=SubsetRandomSampler(train_pos_dataset.pos_sample_idx),
    batch_size=BATCH_SIZE // 2,
    drop_last=True,
    num_workers=1,
)

# Validation is read in order: the validation set uses cached arrays, and
# sequential access saves disk hits.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
)

In [7]:
# TensorBoard writer; the comment suffix names this run's log directory.
writer = SummaryWriter(
    comment='_profiling_baseline_UNet3D3x3_randomFlipsPosNeg_focalLoss_weightDecay1e-4',
)

# Adam over all model parameters, with learning rate and weight decay taken
# from the configuration cell.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

In [8]:
# =========== TRAINING ===========
# Each outer iteration is one pass over the negative candidates. Every step
# concatenates half a batch of positives with half a batch of negatives; the
# positive loader is restarted whenever it is exhausted, and training metrics
# are written to TensorBoard at each such restart. The per-stage timing prints
# are profiling instrumentation.
# NOTE(review): CUDA kernels launch asynchronously, so on a GPU the
# 'forward'/'backward' timings probably under-report and the cost surfaces at
# the first host sync (`loss.item()`, printed as 'train_loss_mean'). Confirm
# with torch.cuda.synchronize() before trusting individual stage numbers.
pos_epoch = 1  # how many times we've gone through the positive samples
for epoch in range(1, EPOCHS+1):
    # reshuffle negative dataloader every epoch
    neg_idx_shuffled = (
        train_neg_dataset.candidates.loc[train_neg_dataset.neg_sample_idx]
        .copy().reset_index()
        .groupby('seriesuid')['index'].unique()
        .apply(shuffle_wrapper)  # shuffle within cases
        .apply(list)
    )
    neg_idx_shuffled = neg_idx_shuffled.sample(len(neg_idx_shuffled))      # shuffle case order
    neg_idx_shuffled = list(itertools.chain.from_iterable(neg_idx_shuffled.values))  # flatten

    # The pre-shuffled index list is passed as the sampler, so negatives are
    # visited case by case in the order built above.
    train_neg_dataloader = DataLoader(
        train_neg_dataset,
        batch_size=BATCH_SIZE//2,
        shuffle=False,
        sampler=neg_idx_shuffled,
        num_workers=1,
    )

    model.train()
    print('Epoch %s' % epoch)
    train_pos_dataiter = iter(train_pos_dataloader)
    # Running totals since the last TensorBoard write.
    train_loss_mean = 0
    train_batches = 0
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    start = dt()
    for batch_idx, (neg_X, neg_y) in enumerate(train_neg_dataloader):
        print('neg: %.2f' % (dt() - start))
        start = dt()
        optimizer.zero_grad()
        try:
            should_write = False
            pos_X, pos_y = next(train_pos_dataiter)
        except StopIteration:
            # Positive samples exhausted: restart the loader and flag this
            # step so the accumulated training metrics get written below.
            train_pos_dataiter = iter(train_pos_dataloader)
            pos_X, pos_y = next(train_pos_dataiter)

            pos_epoch += 1
            should_write = True
        print('pos: %.2f' % (dt() - start))
        start = dt()

        neg_X = neg_X.reshape(-1, 1, 48, 48, 48)
        train_X = torch.cat([pos_X, neg_X])
        train_y_cpu = torch.cat([pos_y, neg_y]).reshape(-1, 1).float()

        train_X = linear_transform_to_0_1(train_X, min=-1000, max=400)

        if SHOULD_AUGMENT:
            # Flip the whole batch along each spatial axis (dims 2, 3, 4)
            # independently with probability 1/2.
            flip_dims = []
            flip_x = (np.random.randint(0, 2) == 1)
            flip_y = (np.random.randint(0, 2) == 1)
            flip_z = (np.random.randint(0, 2) == 1)

            if flip_x:
                flip_dims.append(2)
            if flip_y:
                flip_dims.append(3)
            if flip_z:
                flip_dims.append(4)
            if len(flip_dims) > 0:
                train_X = torch.flip(train_X, flip_dims)

        print('aug/proc: %.2f' % (dt() - start))
        start = dt()

        train_X = train_X.to(device)
        train_y = train_y_cpu.to(device)

        pred_y = model(train_X)

        print('forward: %.2f' % (dt() - start))
        start = dt()

        loss = binary_focal_loss_with_logits(pred_y, train_y, reduction='mean')
        loss.backward()
        optimizer.step()

        print('backward: %.2f' % (dt() - start))
        start = dt()

        train_loss_mean += loss.sum().cpu().item()
        train_batches += 1

        print('train_loss_mean: %.2f' % (dt() - start))
        start = dt()

        # Detach first: the metrics below must not extend the autograd graph.
        pred_y_bool = (torch.sigmoid(pred_y.detach().cpu()) > 0.5).cpu()

        print('pred_y_bool: %.2f' % (dt() - start))
        start = dt()

        train_y_bool = (train_y_cpu == 1)

        print('train_y_bool: %.2f' % (dt() - start))
        start = dt()

        this_tp, this_tn, this_fp, this_fn = calc_confusion(pred_y_bool, train_y_bool)

        print('confusion: %.2f' % (dt() - start))
        start = dt()

        tp += this_tp
        tn += this_tn
        fp += this_fp
        fn += this_fn

        print('confusion_add: %.2f' % (dt() - start))
        start = dt()

        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t{:.0f}\t{:.0f}\t{:.0f}\t{:.0f}'.format(
                epoch, batch_idx * len(neg_X), len(train_neg_dataloader.dataset),
                100. * batch_idx / len(train_neg_dataloader), loss.item(), this_tp, this_tn, this_fp, this_fn))

        if should_write:
            # Average over the batches actually accumulated. The first window
            # of an epoch holds one more batch than len(train_pos_dataloader),
            # so dividing by the loader length would bias the logged mean.
            writer.add_scalar('loss/train', train_loss_mean / train_batches, pos_epoch)
            writer.add_scalar('accuracy/train', 100. * (tp + tn) / (tp + fp + tn + fn), pos_epoch)
            try:
                writer.add_scalar('precision/train', 100. * (tp / (tp + fp)), pos_epoch)
            except ZeroDivisionError:
                # -1 marks "undefined" (no positive predictions in the window).
                writer.add_scalar('precision/train', -1, pos_epoch)
            try:
                writer.add_scalar('recall/train', 100. * (tp / (tp + fn)), pos_epoch)
            except ZeroDivisionError:
                # -1 marks "undefined" (no positive targets in the window).
                writer.add_scalar('recall/train', -1, pos_epoch)
            train_loss_mean = 0
            train_batches = 0
            tp = 0
            tn = 0
            fp = 0
            fn = 0

        print('misc: %.2f' % (dt() - start))
        start = dt()

    # VALIDATION
    print('Validation')
    model.eval()
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    # No gradients are needed for evaluation; skipping them avoids building
    # autograd graphs for every validation batch.
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(val_dataloader):
            X = X.reshape(-1, 1, 48, 48, 48).float()
            y = y.reshape(-1, 1).float()
            # Apply the same intensity normalisation as in training so the
            # model is evaluated on inputs from the distribution it was fit on.
            X = linear_transform_to_0_1(X, min=-1000, max=400)
            X = X.to(device)
            y = y.to(device)
            pred_y = model(X)

            pred_y_bool = torch.sigmoid(pred_y) > 0.5
            y_bool = (y == 1)
            this_tp, this_tn, this_fp, this_fn = calc_confusion(pred_y_bool, y_bool)

            tp += this_tp
            tn += this_tn
            fp += this_fp
            fn += this_fn

    writer.add_scalar('accuracy/val', 100. * (tp + tn) / (tp + fp + tn + fn), epoch)
    writer.add_scalar('tp/val', tp, epoch)
    writer.add_scalar('fp/val', fp, epoch)
    writer.add_scalar('tn/val', tn, epoch)
    writer.add_scalar('fn/val', fn, epoch)
    try:
        writer.add_scalar('precision/val', 100. * (tp / (tp + fp)), epoch)
    except ZeroDivisionError:
        writer.add_scalar('precision/val', -1, epoch)
    try:
        writer.add_scalar('recall/val', 100. * (tp / (tp + fn)), epoch)
    except ZeroDivisionError:
        writer.add_scalar('recall/val', -1, epoch)

    # Checkpoint the weights after every epoch, next to the TensorBoard logs.
    model_savepath = (Path(writer.get_logdir()) / f'epoch_{epoch}.pth').as_posix()
    torch.save(model.state_dict(), model_savepath)

Epoch 1
neg: 0.55
pos: 0.00
aug/proc: 0.15
forward: 1.10
backward: 0.03
train_loss_mean: 0.41
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
confusion_add: 0.00
misc: 0.00
neg: 0.00
pos: 0.00
aug/proc: 0.14
forward: 0.03
backward: 0.01
train_loss_mean: 0.60
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
confusion_add: 0.00
misc: 0.00
neg: 0.00
pos: 0.00
aug/proc: 0.13
forward: 0.03
backward: 0.01
train_loss_mean: 0.60
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
confusion_add: 0.00
misc: 0.00
neg: 0.00
pos: 0.00
aug/proc: 0.14
forward: 0.03
backward: 0.01
train_loss_mean: 0.60
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
confusion_add: 0.00
misc: 0.00
neg: 0.00
pos: 0.00
aug/proc: 0.11
forward: 0.03
backward: 0.01
train_loss_mean: 0.60
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
confusion_add: 0.00
misc: 0.00
neg: 0.00
pos: 0.00
aug/proc: 0.14
forward: 0.03
backward: 0.01
train_loss_mean: 0.60
pred_y_bool: 0.00
train_y_bool: 0.00
confusion: 0.00
con

KeyboardInterrupt: 