In [1]:
%matplotlib inline

import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import scipy.ndimage as ndimage
import torch

from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

from model import OctNet

In [2]:
PROCESSED_DATA_PATH = '../oct-data/data-1'

In [3]:
class OctDataset(Dataset):
    def __init__(self):
        pos_dir = os.path.join(PROCESSED_DATA_PATH, 'pos')
        pos_paths = [os.path.join(pos_dir, f) for f in os.listdir(pos_dir)]
        
        neg_dir = os.path.join(PROCESSED_DATA_PATH, 'neg')
        neg_paths = [os.path.join(neg_dir, f) for f in os.listdir(neg_dir)]
        
        self.cube_paths = pos_paths + neg_paths
        self.labels = [1] * len(pos_paths) + [0] * len(neg_paths)
        
        self.transforms = T.Compose([
            T.Resize([200, 200]),
            T.ToTensor()
        ])
        
        assert len(self.labels) == len(self.cube_paths)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, i):
        cube = np.load(self.cube_paths[i])
        label = self.labels[i]
        
        slice_ = cube[:, :, np.random.randint(cube.shape[-1])]
        img = Image.fromarray(slice_)
        return torch.FloatTensor(self.transforms(img)), label

In [4]:
dataset = OctDataset()

In [5]:
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

In [6]:
for img, label in loader:
    print(img, label)
    break


( 0 , 0 ,.,.) = 
  0.0863  0.1020  0.1059  ...   0.0824  0.1020  0.0784
  0.0627  0.0667  0.0353  ...   0.0902  0.0980  0.0941
  0.1333  0.1216  0.1255  ...   0.1216  0.1020  0.0784
           ...             ⋱             ...          
  0.0980  0.1137  0.0784  ...   0.0667  0.0667  0.1098
  0.0941  0.0824  0.1255  ...   0.1137  0.0824  0.1020
  0.0863  0.0784  0.0980  ...   0.0941  0.0784  0.0863
      ⋮  

( 1 , 0 ,.,.) = 
  0.0549  0.1098  0.1098  ...   0.0863  0.0863  0.0980
  0.0235  0.0863  0.0863  ...   0.1176  0.0941  0.1137
  0.0275  0.1098  0.0980  ...   0.1020  0.1137  0.1490
           ...             ⋱             ...          
  0.0196  0.1373  0.1294  ...   0.1020  0.0980  0.1333
  0.0118  0.0941  0.0667  ...   0.0941  0.0902  0.0941
  0.0235  0.0941  0.0980  ...   0.1020  0.1059  0.1412
      ⋮  

( 2 , 0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0824  0.0863  0.1137
  0.0000  0.0000  0.0039  ...   0.0627  0.0706  0.1137
  0.0000  0.0000  0.0039  ...   0.1255  0.0941

In [None]:
net = OctNet().cuda()

for epoch in range(200):
    print(f'Epoch {epoch}')
    for X, y in loader:
        y = y.float()
        X_var = Variable(X, requires_grad=False).cuda()
        y_var = Variable(y, requires_grad=False).cuda()
        loss = net.train_step(X_var, y_var)
        print(f'BCE loss = {loss.data[0]}')

Epoch 0
