# In [2]:
import torch, torchinfo
import torch.nn as nn
#from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
#import torchvision.transforms.functional as TF

# Select the compute device, preferring CUDA when it is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Training this 3D model on CPU is impractically slow, so fail fast.
# RuntimeError subclasses Exception, so existing `except Exception` handlers
# still catch it.
if device.type != 'cuda':
    raise RuntimeError("GPU not available. CPU training will be too slow.")

print("device name", torch.cuda.get_device_name(0))

class DoubleConv(nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by BatchNorm3d and ReLU.

    Stride 1 and padding 1 make each convolution "same"-sized, so only the
    channel count changes (``in_channels`` -> ``out_channels``); the spatial
    size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first conv maps in->out channels, the second maps out->out.
        for stage_in in (in_channels, out_channels):
            # bias=False: the BatchNorm3d that follows has its own learnable
            # shift, so a conv bias would be redundant.
            stages.append(
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm3d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Attribute name `conv` is kept so state_dict keys (conv.0.weight, ...)
        # stay compatible with existing checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv -> BatchNorm -> ReLU stages to ``x``."""
        return self.conv(x)
    

class UNET(nn.Module):
    """3D U-Net built from DoubleConv stages with encoder/decoder skip connections.

    Args:
        in_channels: Number of channels in the input volume.
        out_channels: Number of channels produced by the final 1x1x1 conv.
            The default of 3 is intended for three classes:
            0 - Background, 1 - Myocardium, 2 - Cavity Volume.
        features: Channel widths of the successive encoder stages (any
            iterable of ints; must be non-empty). The decoder mirrors them in
            reverse and the bottleneck uses ``features[-1] * 2`` channels.
            The default (16, 32, 64) is a small test configuration.

    The output passes through an element-wise sigmoid, so every output
    channel lies in (0, 1) independently.
    NOTE(review): for mutually exclusive classes a softmax (or raw logits with
    CrossEntropyLoss) is more usual -- confirm the loss expects sigmoid outputs.
    """

    def __init__(
        self, in_channels=1, out_channels=3, features=(16, 32, 64),
    ):
        super().__init__()
        # Materialize once: features is iterated twice and indexed below, and a
        # tuple default avoids the shared-mutable-default pitfall.
        features = tuple(features)

        # nn.ModuleList (not a plain list) registers the sub-layers, so their
        # parameters are tracked and follow .to() / .train() / .eval().
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        # Halves every spatial dim (flooring odd sizes); forward() resizes the
        # decoder output to the skip connection when that happens.
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)

        # Encoder: one DoubleConv per feature width.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: ups alternates [ConvTranspose3d, DoubleConv] per stage.
        for feature in reversed(features):
            # Upsample 2x and halve channels (feature*2 -> feature).
            self.ups.append(
                nn.ConvTranspose3d(feature * 2, feature, kernel_size=(2, 2, 2), stride=2)
            )
            # Input has feature*2 channels: upsampled tensor + skip connection.
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck between encoder and decoder.
        self.bottom = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=(1, 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated predictions.

        Assumes ``x`` is (batch, in_channels, D, H, W), as Conv3d requires.
        The result has the same batch and spatial size as ``x`` and
        ``out_channels`` channels.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # saved before pooling, for the decoder
            x = self.pool(x)

        x = self.bottom(x)

        # Decoder stages run deepest-first, so pair them with the reversed skips.
        up_convs = self.ups[::2]
        double_convs = self.ups[1::2]
        for up_conv, double_conv, skip_connection in zip(
            up_convs, double_convs, reversed(skip_connections)
        ):
            x = up_conv(x)

            # If pooling floored an odd spatial dim, the upsampled tensor is
            # smaller than the skip; resize it (nearest) before concatenating.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = double_conv(concat_skip)

        return self.sigmoid(self.final_conv(x))
        
    
def test(imsize=112, batch_size=2):
    """Smoke-test UNET: a random single-channel volume must map to the same shape.

    Args:
        imsize: Edge length of the cubic input volume. The default 112 is
            divisible by 8, so the default three-stage UNET needs no resizing.
        batch_size: Number of volumes in the random batch.

    Raises:
        AssertionError: If the prediction shape differs from the input shape.
    """
    x = torch.randn((batch_size, 1, imsize, imsize, imsize))
    model = UNET(in_channels=1, out_channels=1)
    # Only the output shape is checked, so skip building the autograd graph;
    # this saves substantial memory for a full-size 3D volume.
    with torch.no_grad():
        preds = model(x)
    assert preds.shape == x.shape, (
        f"expected output shape {tuple(x.shape)}, got {tuple(preds.shape)}"
    )


if __name__ == "__main__":
    # The smoke test is disabled by default (a full-size 3D forward pass is
    # expensive); uncomment the call below to run it.
    # test()
    pass

# Output:
# cuda
# device name Quadro RTX 6000
