In [21]:
import pandas as pd
import numpy as np
import nibabel as nib
import torch
import matplotlib.pyplot as plt

def shape_change(d_array):
    p_array=np.pad(d_array, ((5,5), (0,0), (5, 5)), 'constant')[:,12:204,:]
    return p_array

In [22]:
train_data=pd.read_csv('train.csv')
val_data=pd.read_csv('val.csv')
x_train=train_data['input_file_path']+'.nii'

In [23]:
class DoubleConv(torch.nn.Module):
    """ this function is what we are calling instead of writing in the layers each time"""
    
    def __init__(self,in_channels, out_channels):
        
        super().__init__()
        self.step=torch.nn.Sequential(torch.nn.Conv3d(in_channels,out_channels,3,padding=1),
                                     torch.nn.ReLU(),
                                      torch.nn.Conv3d(out_channels,out_channels,3,padding=1),
                                      torch.nn.ReLU()
                                        )
    def forward(self,X):
        return self.step(X)
    


In [30]:
class UNet(torch.nn.Module):
    """
    This class implements a UNet for the Segmentation
    We use 3 down- and 3 UpConvolutions and two Convolutions in each step
    """

    def __init__(self):
        """Sets up the U-Net Structure
        """
        super().__init__()
        
        
        ############# DOWN #####################
        self.layer1 = DoubleConv(1, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)

        #########################################
        
        ############## UP #######################
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128+64, 64)
        self.layer7 = DoubleConv(64+32, 32)
        self.layer8 = torch.nn.Conv3d(32, 90, 1)  # Output: 90 values -> background, + Desikan Killany Atlas
        #########################################

        self.maxpool = torch.nn.MaxPool3d(2)

    def forward(self, x):
        
        ####### DownConv 1#########
        x1 = self.layer1(x)
        x1m = self.maxpool(x1)
        ###########################
        
        ####### DownConv 2#########        
        x2 = self.layer2(x1m)
        x2m = self.maxpool(x2)
        ###########################

        ####### DownConv 3#########        
        x3 = self.layer3(x2m)
        x3m = self.maxpool(x3)
        ###########################
        
        ##### Intermediate Layer ## 
        x4 = self.layer4(x3m)
        ###########################
        ####### UpCONV 1#########        
        x5 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x4)  # Upsample with a factor of 2
        print(x5.size())
        x5 = torch.cat([x5, x3], dim=1)  # Skip-Connection
        print(x5.size())
        x5 = self.layer5(x5)
        ###########################

        ####### UpCONV 2#########        
        x6 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x5)        
        x6 = torch.cat([x6, x2], dim=1)  # Skip-Connection    
        x6 = self.layer6(x6)
        ###########################
        
        ####### UpCONV 3#########        
        x7 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x6)
        x7 = torch.cat([x7, x1], dim=1)       
        x7 = self.layer7(x7)
        ###########################
        
        ####### Predicted segmentation#########        
        ret = self.layer8(x7)
        return ret,x4

# Test to see if the neural network works

In [31]:
brain0=np.asarray(nib.load(x_train[0]).get_fdata()) #Get the data from the database
brainrdy=shape_change(brain0) #Change to desired square of 192,192,192
x=torch.tensor(brainrdy) #Change to a torch tensor
x=x[None,None].float() # Get in the correct dimension and datatype
model=UNet()
with torch.no_grad():
    output,layer = model(x)
print(output.size())



torch.Size([1, 256, 48, 48, 48])
torch.Size([1, 384, 48, 48, 48])
torch.Size([1, 90, 192, 192, 192])


In [36]:
layer.view(-1,256*24*24).shape

torch.Size([24, 147456])

In [35]:
layer.flatten().shape

torch.Size([3538944])

# Now save it as model.py and move onto training

In [54]:
!curl -s https://raw.githubusercontent.com/fepegar/torchio/main/print_system.py

# flake8: noqa

import re
import sys
import platform
import torchio
import torch
import numpy
import SimpleITK as sitk


sitk_version = re.findall('SimpleITK Version: (.*?)\n', str(sitk.Version()))[0]

print('Platform:  ', platform.platform())
print('TorchIO:   ', torchio.__version__)
print('PyTorch:   ', torch.__version__)
print('SimpleITK: ', sitk_version)
print('NumPy:     ', numpy.__version__)
print('Python:    ', sys.version)
