In [5]:
import numpy as np
import csv
import os
import matplotlib.pyplot as plt
%matplotlib inline 
import re
import pandas as pd
import gc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as D

In [6]:
'''
Begin training/evaluation of model.

Questions: 
    - 1) How do I resize images to have the same depth so they can be normalized? 
        The final layer's output can simply be resized back to the original size and one-hot encoded.
    - 2) How do I deal with multiple outputs? 
        Same as above.
    - 3) How should I format the final layer to reflect the nature of the outputs?
        Same as above.
    - 4) Am I using 3D layers incorrectly? Should I switch to 2D? 
'''

class UNetDBlock(nn.Module):
    """Encoder (down) block: two 3x3x3 convolutions followed by 2x max-pooling.

    Each convolution is followed by a leaky ReLU and then batch normalization.
    ``padding=1`` preserves depth/height/width through the convolutions, and
    ``MaxPool3d(2)`` halves each of them (flooring odd sizes).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
    """

    def __init__(self, in_c, out_c):
        super(UNetDBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_c, out_c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size=3, padding=1)
        # One BatchNorm per convolution: sharing a single module would blend
        # the running statistics and affine parameters of two different
        # activations.
        self.batch = nn.BatchNorm3d(out_c)
        self.batch2 = nn.BatchNorm3d(out_c)
        self.pool = nn.MaxPool3d(2)

    def forward(self, x):
        """Map (N, in_c, D, H, W) -> (N, out_c, D//2, H//2, W//2)."""
        x = self.batch(F.leaky_relu(self.conv1(x)))
        x = self.batch2(F.leaky_relu(self.conv2(x)))
        return self.pool(x)

class UNetUBlock(nn.Module):
    """Decoder (up) block: 2x transposed-conv upsampling, then two 3x3x3 convs.

    ``ConvTranspose3d(kernel=2, stride=2)`` doubles depth/height/width while
    keeping ``in_c`` channels. Each following convolution is followed by a
    leaky ReLU and then batch normalization; the first one reduces the
    channel count to ``out_c``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
    """

    def __init__(self, in_c, out_c):
        super(UNetUBlock, self).__init__()
        self.convt1 = nn.ConvTranspose3d(in_c, in_c, 2, stride=2)
        self.conv1 = nn.Conv3d(in_c, out_c, kernel_size=3, padding=1)
        self.batch = nn.BatchNorm3d(out_c)
        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size=3, padding=1)
        # Separate BatchNorm for the second conv so the two activations do not
        # share running statistics or affine parameters.
        self.batch2 = nn.BatchNorm3d(out_c)

    def forward(self, x):
        """Map (N, in_c, D, H, W) -> (N, out_c, 2*D, 2*H, 2*W)."""
        # Upsample first; previously convt1 was built but never applied, so
        # the block did not change the spatial size at all.
        x = self.convt1(x)
        x = self.batch(F.leaky_relu(self.conv1(x)))
        x = self.batch2(F.leaky_relu(self.conv2(x)))
        return x
    
class UNetBBlock(nn.Module):
    """Bottleneck block: two 3x3x3 convs, each followed by BatchNorm and dropout.

    ``padding=1`` keeps depth/height/width unchanged; only the channel count
    changes from ``in_c`` to ``out_c``. ``nn.Dropout3d()`` uses its default
    probability and is active only in training mode.

    NOTE(review): unlike the encoder/decoder blocks there is no activation
    function here, so the bottleneck adds no nonlinearity — confirm intent.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
    """

    def __init__(self, in_c, out_c):
        super(UNetBBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_c, out_c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size=3, padding=1)
        # One BatchNorm per conv; a shared module would mix their statistics.
        self.batch = nn.BatchNorm3d(out_c)
        self.batch2 = nn.BatchNorm3d(out_c)
        # Dropout is stateless, so reusing one module for both convs is fine.
        self.dropout = nn.Dropout3d()

    def forward(self, x):
        """Map (N, in_c, D, H, W) -> (N, out_c, D, H, W)."""
        x = self.dropout(self.batch(self.conv1(x)))
        x = self.dropout(self.batch2(self.conv2(x)))
        return x

class UNet(nn.Module):
    """3D encoder-decoder network producing a per-voxel sigmoid map.

    The network has four down blocks, a bottleneck, four up blocks and a
    1x1x1 convolution followed by a sigmoid:

    - Each down block halves D, H and W. The channel count goes
      ``initf, 2*initf, 4*initf, 8*initf``.
    - The bottleneck widens the channels to ``16*initf``.
    - Each up block doubles D, H and W and halves the channel count back
      down to ``initf``.
    - The 1x1x1 convolution projects to ``out_channels``.

    NOTE(review): there are no skip connections from the encoder to the
    decoder, so despite the name this is a plain encoder-decoder rather than
    a U-Net. Adding them would require widening each up block's ``in_c``.

    Args:
        in_channels: channels of the input volume (default 1).
        out_channels: channels of the output map (default 1).
        initf: number of filters in the first down block (default 32).
    """

    def __init__(self, in_channels=1, out_channels=1, initf=32):
        super(UNet, self).__init__()
        self.d1 = UNetDBlock(in_channels, initf)
        self.d2 = UNetDBlock(initf, initf * 2)
        self.d3 = UNetDBlock(initf * 2, initf * 4)
        self.d4 = UNetDBlock(initf * 4, initf * 8)
        self.b = UNetBBlock(initf * 8, initf * 16)
        self.u1 = UNetUBlock(initf * 16, initf * 8)
        self.u2 = UNetUBlock(initf * 8, initf * 4)
        self.u3 = UNetUBlock(initf * 4, initf * 2)
        self.u4 = UNetUBlock(initf * 2, initf)
        self.final = nn.Conv3d(initf, out_channels, kernel_size=1)
        self.final_sig = nn.Sigmoid()

    def forward(self, x):
        """Map (N, in_channels, D, H, W) -> (N, out_channels, D, H, W) in [0, 1].

        D, H and W should each be divisible by 16. The four poolings floor
        odd sizes and the four 2x upsamplings cannot recover the lost voxels,
        so other sizes produce a smaller output. Example: H = W = 512.
        """
        # Call submodules directly (not .forward()) so nn.Module hooks run.
        x = self.d1(x)  # -> (N, f,    D/2,  H/2,  W/2),  f = initf
        x = self.d2(x)  # -> (N, 2f,   D/4,  H/4,  W/4)
        x = self.d3(x)  # -> (N, 4f,   D/8,  H/8,  W/8)
        x = self.d4(x)  # -> (N, 8f,   D/16, H/16, W/16)

        x = self.b(x)   # -> (N, 16f,  D/16, H/16, W/16)

        x = self.u1(x)  # -> (N, 8f,   D/8,  H/8,  W/8)
        x = self.u2(x)  # -> (N, 4f,   D/4,  H/4,  W/4)
        x = self.u3(x)  # -> (N, 2f,   D/2,  H/2,  W/2)
        x = self.u4(x)  # -> (N, f,    D,    H,    W)

        x = self.final(x)      # -> (N, out_channels, D, H, W)
        return self.final_sig(x)