In [None]:
%matplotlib inline

import os
import random
import statistics
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import seaborn as sns
from sklearn.metrics import average_precision_score, confusion_matrix, roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import constants

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
class OctSliceDataset(Dataset):
    """Dataset of single 2-D slices cut from 3-D cubes saved with NumPy.

    `data_dir` must contain two sub-directories, `pos` and `neg`. Every
    non-hidden entry in `pos` is labelled 1.0 and every one in `neg` 0.0.
    Each item loads one cube with `np.load`, picks a slice index uniformly at
    random from [slice_min, slice_max] (inclusive) along the last axis, and
    returns (resized image tensor, float32 label tensor).
    """

    def __init__(self, data_dir, slice_min=80, slice_max=120):
        # Slice bounds are checked against a fixed depth of 200.
        # NOTE(review): assumes every cube has >= 200 entries on axis 2 — confirm.
        assert 0 <= slice_min <= slice_max < 200
        self.slice_min = slice_min
        self.slice_max = slice_max

        pos_paths = self._list_files(os.path.join(data_dir, 'pos'))
        neg_paths = self._list_files(os.path.join(data_dir, 'neg'))

        self.cube_paths = pos_paths + neg_paths
        self.labels = [1.] * len(pos_paths) + [0.] * len(neg_paths)

        self.transforms = T.Compose([
            T.Resize([200, 200]),
            T.ToTensor()
        ])

        assert len(self.labels) == len(self.cube_paths)
        print(f'Number of cubes: {len(self.labels)}')

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the non-hidden entries of `directory`."""
        # os.listdir order is filesystem-dependent; sorting makes item indices
        # reproducible. Hidden entries (e.g. .DS_Store) are not cube files.
        return [os.path.join(directory, name)
                for name in sorted(os.listdir(directory))
                if not name.startswith('.')]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        cube = np.load(self.cube_paths[i])
        label = self.labels[i]

        # With slice_min == slice_max this is deterministic.
        slice_idx = random.randint(self.slice_min, self.slice_max)

        img = Image.fromarray(cube[:, :, slice_idx])
        # ToTensor already returns a tensor; re-wrapping it in torch.tensor()
        # would copy it and emit a copy-construct UserWarning.
        return self.transforms(img), torch.tensor(label, dtype=torch.float32)

In [None]:
class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension: (N, *) -> (N, prod(*))."""

    def forward(self, x):
        # reshape (unlike view) also handles non-contiguous inputs, and no fixed
        # 4-D unpacking is needed, so any tensor with a leading batch dim works.
        return x.reshape(x.size(0), -1)

class OctSliceNet(nn.Module):
    """Small 3-stage CNN emitting one logit per input slice; owns its optimizer.

    Each stage is Conv2d(3x3, no padding) -> BatchNorm2d -> LeakyReLU ->
    MaxPool2d(2). For a 1x200x200 input the spatial size goes
    200 -> 198 -> 99 -> 97 -> 48 -> 46 -> 23, so the final Linear layer
    expects 16 * 23 * 23 = 8464 features; other input sizes will not fit.
    """

    def __init__(self):
        super(OctSliceNet, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.BatchNorm2d(num_features=16),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 16, kernel_size=3),
            nn.BatchNorm2d(num_features=16),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2),
            Flatten(),
            nn.Linear(8464, 1)
        )
        # NOTE: built before any .to(device) call; this relies on Module.to()
        # updating the existing Parameter objects in place (PyTorch's default).
        self.optimizer = torch.optim.Adam(self.parameters())
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, slice_):
        """Return logits of shape (N,) by dropping the Linear layer's size-1 dim."""
        return self.cnn(slice_).squeeze(dim=1)

    def train_step(self, slice_, targets):
        """Run one optimisation step on a batch.

        Args:
            slice_: input batch on the model's device.
            targets: float targets in {0, 1}, one per example.

        Returns:
            The batch loss, detached from the autograd graph.
        """
        # Guarantee training behaviour (BatchNorm batch stats) even if an
        # evaluation pass left the model in eval mode.
        self.train()
        logits = self(slice_)
        loss = self.loss_fn(logits, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach()

In [None]:
def evaluate(net, loader, verbose=True):
    """Score `net` on every batch of `loader` and compute ranking metrics.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics and does not update them with validation data. Afterwards it is
    restored to whatever mode it was in before the call.

    Args:
        net: model exposing `loss_fn` and returning one logit per example.
        loader: iterable of (inputs, labels) batches.
        verbose: print metrics and elapsed time when True.

    Returns:
        (val_loss, auprc, auroc). val_loss is the unweighted mean of per-batch
        losses, which is approximate when the last batch is smaller. The two
        scores are computed by sklearn on the raw logits.
    """
    start = time.time()
    all_logits = []
    all_labels = []
    all_losses = []

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for X, y in loader:
                logits = net(X.to(device))
                loss = net.loss_fn(logits, y.to(device))
                all_logits.extend(logits.cpu().tolist())
                all_labels.extend(y.tolist())
                all_losses.append(loss.item())
    finally:
        # Put the model back so a following training loop is unaffected.
        net.train(was_training)

    val_loss = statistics.mean(all_losses)
    auprc = average_precision_score(all_labels, all_logits)
    auroc = roc_auc_score(all_labels, all_logits)

    if verbose:
        print(f'Average precision score: {auprc}')
        print(f'AUROC: {auroc}')
        print(f'Validation loss (approximate): {val_loss}')
        print(f'Elapsed: {time.time() - start}')
    return val_loss, auprc, auroc

In [None]:
train_dir = os.path.join(constants.PROCESSED_DATA_PATH, 'train')
val_dir = os.path.join(constants.PROCESSED_DATA_PATH, 'val')
test_dir = os.path.join(constants.PROCESSED_DATA_PATH, 'test')


def train_with_slice(slice_idx, num_epochs=10, verbose=True,
                     batch_size=8, num_workers=8, seed=None):
    """Train a fresh OctSliceNet on one fixed slice index and report metrics.

    Args:
        slice_idx: slice index used for every cube, in both train and val.
        num_epochs: number of passes over the training set (0 = evaluate only).
        verbose: print progress and metrics.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        seed: if given, seeds Python's `random` and torch before building the
            model, so runs for different slices start from the same state.

    Returns:
        (train_loss, val_loss, auprc, auroc) from the last epoch. With
        num_epochs == 0, train_loss is NaN and the val metrics describe the
        untrained model.
    """
    print('==============')
    print(f'Training with slice index: {slice_idx}')

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    train_dataset = OctSliceDataset(train_dir, slice_min=slice_idx, slice_max=slice_idx)
    val_dataset = OctSliceDataset(val_dir, slice_min=slice_idx, slice_max=slice_idx)
    # The test split (test_dir) is deliberately not loaded: it is not used here.

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    net = OctSliceNet().to(device)

    if verbose:
        print('------ Evaluating ------')
    # Baseline of the untrained model; also the result when num_epochs == 0.
    val_loss, auprc, auroc = evaluate(net, val_loader, verbose)
    train_loss = float('nan')

    for epoch in range(1, num_epochs + 1):
        if verbose:
            print(f'====== Epoch {epoch} ======')
        net.train()  # evaluation may have switched the model to eval mode
        losses = [net.train_step(X.to(device), y.to(device)).item()
                  for X, y in train_loader]
        train_loss = statistics.mean(losses)
        if verbose:
            print(f'Train loss (approximate): {train_loss}')

        if verbose:
            print('------ Evaluating ------')
        val_loss, auprc, auroc = evaluate(net, val_loader, verbose)

    return train_loss, val_loss, auprc, auroc

In [None]:
# Sweep over slice positions through the cube; one freshly trained model per index.
slice_idxs = [0, 25, 50, 75, 100, 125, 150, 175, 199]
train_losses, val_losses, auprcs, aurocs = [], [], [], []

for slice_idx in slice_idxs:
    train_loss, val_loss, auprc, auroc = train_with_slice(slice_idx)
    for series, value in zip((train_losses, val_losses, auprcs, aurocs),
                             (train_loss, val_loss, auprc, auroc)):
        series.append(value)

In [None]:
# Final-epoch training loss for each slice index.
fig, ax = plt.subplots()
ax.scatter(slice_idxs, train_losses)
ax.set(title='Training loss by slice index', xlabel='Slice index', ylabel='Train loss (approximate)');

In [None]:
# Final-epoch validation loss for each slice index.
fig, ax = plt.subplots()
ax.scatter(slice_idxs, val_losses)
ax.set(title='Validation loss by slice index', xlabel='Slice index', ylabel='Validation loss (approximate)');

In [None]:
# Validation average precision (AUPRC) for each slice index.
fig, ax = plt.subplots()
ax.scatter(slice_idxs, auprcs)
ax.set(title='Validation AUPRC by slice index', xlabel='Slice index', ylabel='Average precision');

In [None]:
# Validation AUROC for each slice index.
fig, ax = plt.subplots()
ax.scatter(slice_idxs, aurocs)
ax.set(title='Validation AUROC by slice index', xlabel='Slice index', ylabel='AUROC');