In [1]:
%matplotlib inline

import os
import statistics

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import scipy.ndimage as ndimage
import torch

from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

from model import OctNet

In [2]:
# Root directory of the processed data. OctDataset joins this with a split
# name and then 'pos' / 'neg' to locate the files for each class.
PROCESSED_DATA_PATH = '../oct-data/data-2'

In [3]:
class OctDataset(Dataset):
    """Binary-labelled dataset of ``np.load``-able files for one data split.

    Files are read from ``<PROCESSED_DATA_PATH>/<split>/pos`` (label ``1``)
    and ``<PROCESSED_DATA_PATH>/<split>/neg`` (label ``0``). Positive samples
    come first, and each directory is listed in sorted order so that a given
    index always refers to the same file across runs and machines.

    Args:
        split: Name of the split sub-directory, e.g. ``'train'``.
    """

    def __init__(self, split):
        pos_paths = self._list_files(os.path.join(PROCESSED_DATA_PATH, split, 'pos'))
        neg_paths = self._list_files(os.path.join(PROCESSED_DATA_PATH, split, 'neg'))

        self.cube_paths = pos_paths + neg_paths
        self.labels = [1] * len(pos_paths) + [0] * len(neg_paths)

        self.transforms = T.Compose([
            T.ToTensor()
        ])

        assert len(self.labels) == len(self.cube_paths)

    @staticmethod
    def _list_files(directory):
        """Return the full paths of every entry in ``directory``, sorted by name.

        ``os.listdir`` makes no ordering guarantee, so the names are sorted to
        keep the dataset's index-to-file mapping deterministic.
        """
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

    def __len__(self):
        """Return the number of samples (positives plus negatives)."""
        return len(self.labels)

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(float tensor, label)``.

        The stored array is converted to a PIL image and passed through
        ``self.transforms``; the label is ``1`` for 'pos' files and ``0`` for
        'neg' files.
        """
        slice_ = np.load(self.cube_paths[i])
        label = self.labels[i]
        img = Image.fromarray(slice_)
        # .float() is a no-op for float tensors and casts any other dtype,
        # unlike re-wrapping the tensor with the legacy FloatTensor constructor.
        return self.transforms(img).float(), label

In [4]:
# One dataset per split, each wrapped in a loader with identical options.
loader_options = {'batch_size': 8, 'shuffle': True, 'num_workers': 8}

train_dataset, val_dataset, test_dataset = map(OctDataset, ('train', 'val', 'test'))

train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

In [5]:
# Train OctNet on the GPU and report validation accuracy after every epoch.
NUM_EPOCHS = 200

net = OctNet().cuda()

# NOTE(review): `Variable`, `volatile=True` and `loss.data[0]` are pre-0.4
# PyTorch idioms. They are kept so the cell runs in the environment that
# produced the outputs below; on newer PyTorch the equivalents are plain
# tensors, `torch.no_grad()` and `loss.item()`.
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch}')

    # --- Training pass ---
    # Switch mode-dependent layers (e.g. dropout, batch-norm) to training
    # behaviour; this must be re-applied because validation below switches
    # the model to eval mode. Assumes OctNet is a torch.nn.Module -- TODO confirm.
    net.train()
    losses = []
    for X, y in train_loader:
        y = y.float()
        X_var = Variable(X, requires_grad=False).cuda()
        y_var = Variable(y, requires_grad=False).cuda()
        loss = net.train_step(X_var, y_var)
        losses.append(loss.data[0])
    avg_loss = statistics.mean(losses)
    print(f'Avg BCE loss = {avg_loss}')

    # --- Validation pass ---
    # Evaluate with mode-dependent layers in inference behaviour so the
    # reported accuracy does not depend on dropout noise or batch statistics.
    net.eval()
    num_correct, num_total = 0, 0
    for X, y in val_loader:
        preds = (net(Variable(X, volatile=True).cuda()) > 0.5).data.cpu().numpy()
        y = y.numpy()
        # Flatten the predictions so an output shaped (batch, 1) cannot
        # broadcast against the (batch,) labels, and count matches with the
        # array's own sum so the result is always a plain integer.
        num_correct += int((preds.reshape(-1) == y).sum())
        num_total += len(y)
    print(f'Val accuracy: {num_correct} of {num_total} | {num_correct / num_total:.2f}')

Epoch 0
Avg BCE loss = 1.3795386367485303
Val accuracy: 33 of 65 | 0.51
Epoch 1
Avg BCE loss = 0.9116366449273935
Val accuracy: 23 of 65 | 0.35
Epoch 2
Avg BCE loss = 0.8483347982919517
Val accuracy: 38 of 65 | 0.58
Epoch 3
Avg BCE loss = 0.7111889717458677
Val accuracy: 34 of 65 | 0.52
Epoch 4
Avg BCE loss = 0.5373851070884897
Val accuracy: 31 of 65 | 0.48
Epoch 5
Avg BCE loss = 0.4469903622980879
Val accuracy: 36 of 65 | 0.55
Epoch 6
Avg BCE loss = 0.3753434031137398
Val accuracy: 37 of 65 | 0.57
Epoch 7
Avg BCE loss = 0.29712805577686857
Val accuracy: 28 of 65 | 0.43
Epoch 8
Avg BCE loss = 0.2972377981020122
Val accuracy: 33 of 65 | 0.51
Epoch 9
Avg BCE loss = 0.11591688241465252
Val accuracy: 33 of 65 | 0.51
Epoch 10
Avg BCE loss = 0.058581971500788914
Val accuracy: 32 of 65 | 0.49
Epoch 11
Avg BCE loss = 0.02728327382121141
Val accuracy: 35 of 65 | 0.54
Epoch 12
Avg BCE loss = 0.016628504063574082
Val accuracy: 34 of 65 | 0.52
Epoch 13
Avg BCE loss = 0.012246684705400291
Val accur

Avg BCE loss = 2.2024840266284696e-06
Val accuracy: 38 of 65 | 0.58
Epoch 109
Avg BCE loss = 2.0485903742523453e-06
Val accuracy: 39 of 65 | 0.60
Epoch 110
Avg BCE loss = 2.2861279290247863e-06
Val accuracy: 35 of 65 | 0.54
Epoch 111
Avg BCE loss = 2.122844666901609e-06
Val accuracy: 35 of 65 | 0.54
Epoch 112
Avg BCE loss = 2.505380750876191e-06
Val accuracy: 38 of 65 | 0.58
Epoch 113
Avg BCE loss = 1.973457602509691e-06
Val accuracy: 36 of 65 | 0.55
Epoch 114
Avg BCE loss = 1.976202954302462e-06
Val accuracy: 35 of 65 | 0.54
Epoch 115
Avg BCE loss = 0.685965531920738
Val accuracy: 36 of 65 | 0.55
Epoch 116
Avg BCE loss = 0.16136878742953809
Val accuracy: 34 of 65 | 0.52
Epoch 117
Avg BCE loss = 0.030150328193470334
Val accuracy: 33 of 65 | 0.51
Epoch 118
Avg BCE loss = 0.011740373094378337
Val accuracy: 33 of 65 | 0.51
Epoch 119
Avg BCE loss = 0.002669982588654306
Val accuracy: 34 of 65 | 0.52
Epoch 120
Avg BCE loss = 0.0020143606637270943
Val accuracy: 34 of 65 | 0.52
Epoch 121
Avg B