In [1]:
%load_ext autoreload
%autoreload 2

In [60]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset 
from torch_tools.datasets import *
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

In [30]:
# initialize dataloader for testing
loader = flair_dataset("./tables/nii_data_paths.csv", transform=None)
loader[0][0].shape

torch.Size([1, 1, 100, 100, 100])

In [66]:
layer = nn.Sequential( 
    nn.Conv3d(1, 5, 5, stride=5, padding=0),
    nn.BatchNorm3d(5),
    nn.ReLU(inplace=True),
)
layer2 = nn.Sequential(
    nn.Conv3d(5, 3, 3, stride=3, padding=2),
    nn.BatchNorm3d(3),
    nn.ReLU(inplace=True),
    # maxpool reduces to a 4x4x4x3 hypervolume
    nn.MaxPool3d(2),
)
x = layer(loader[[1,2]][0])
x = layer2(x)
x.shape

torch.Size([2, 3, 4, 4, 4])

In [61]:
# define neural net
class austin_net(nn.Module):
    """ 
    this class implements an extremely simple 3d conv net
    inference using .forward()
    """
    def __init__(self):
        super(austin_net, self).__init__()
        self.conv = nn.Sequential(
            # first conv outputs a 1x5x20x20x20 hypervolume
            nn.Conv3d(1, 5, 5, stride=5, padding=0),
            nn.BatchNorm3d(5),
            nn.ReLU(inplace=True),
            # 2nd conv conv outputs a 8x8x8x3 hypervolume 
            nn.Conv3d(5, 3, 3, stride=3, padding=2),
            nn.BatchNorm3d(3),
            nn.ReLU(inplace=True),
            # maxpool reduces to a 4x4x4x3 hypervolume
            nn.MaxPool3d(2),
        )
        self.fc = nn.Sequential (
            nn.Linear(192, 512),
            nn.Dropout(),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1)
        )
    
    def forward(self, x):
        out_conv1 = self.conv(x)
        out_flat = out_conv1.view(-1, 192)
        outy = self.fc(out_flat)
        
        return outy

def init_weights(m):
    if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

In [62]:
model = austin_net()
model.forward(loader[0][0])

tensor([[0.1297]], grad_fn=<AddmmBackward>)

In [63]:
# send model to device if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

# instance a criterion and optimizer 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.8)

# inialize weights
model.apply(init_weights)

model.to(device)

device: cpu


austin_net(
  (conv): Sequential(
    (0): Conv3d(1, 5, kernel_size=(5, 5, 5), stride=(5, 5, 5))
    (1): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(5, 3, kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))
    (4): BatchNorm3d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=192, out_features=512, bias=True)
    (1): Dropout(p=0.5, inplace=False)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=512, out_features=1, bias=True)
  )
)

tensor([[0.1275, 0.1275, 0.1275, 0.1275, 0.1275, 0.1275, 0.1275, 0.1275, 0.1275,
         0.1275, 0.1275, 0.1275, 0.1589, 0.1589, 0.1589, 0.1589, 0.1275, 0.1275,
         0.1275, 0.1275, 0.1275, 0.1279, 0.1275, 0.1275, 0.1275, 1.2521, 3.5860,
         0.1275, 0.1589, 4.3155, 2.4296, 0.1589, 0.1275, 0.1275, 0.1276, 0.1275,
         0.1275, 2.8150, 0.0000, 0.1275, 0.1275, 6.4556, 6.1753, 0.1275, 0.1589,
         2.0883, 6.5583, 0.1589, 0.1366, 0.1366, 0.1366, 0.1366, 0.1275, 0.1275,
         0.1275, 0.1275, 0.1275, 1.8375, 0.1298, 0.1275, 0.1589, 0.1589, 0.1589,
         0.1589, 0.2165, 0.2165, 0.2165, 0.2165, 0.2165, 0.2165, 0.2165, 0.2165,
         0.2165, 0.2165, 0.2165, 0.2165, 0.2548, 0.2548, 0.2548, 0.2548, 0.2165,
         0.2165, 0.2165, 0.2165, 0.2165, 0.2171, 0.2165, 0.2165, 0.2165, 2.4848,
         7.9245, 0.2165, 0.2548, 0.2548, 0.5428, 0.2548, 0.2165, 0.2165, 0.2165,
         0.2165, 0.2165, 2.1089, 3.5586, 0.2165, 0.2165, 1.7481, 6.9195, 0.2165,
         0.2548, 0.2548, 3.9