# Preliminaries

Mount the drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Import libraries

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits import mplot3d
import warnings
warnings.filterwarnings(action='ignore', category=FutureWarning)
import torch 
import torch.nn as nn
import time
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader
import pickle

!pip install torchio
import torchio as tio

device='cuda' if torch.cuda.is_available() else 'cpu'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchio
  Downloading torchio-0.18.90-py2.py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.7/172.7 KB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
Collecting SimpleITK!=2.0.*,!=2.1.1.1
  Downloading SimpleITK-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Collecting colorama<0.5.0,>=0.4.3
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting shellingham<2.0.0,>=1.3.0
  Downloading shellingham-1.5.0.post1-py2.py3-none-any.whl (9.4 kB)
Collecting rich<13.0.0,>=10.11.0
  Downloading rich-12.6.0-py3-none-any.whl (237 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m237.5/237.5

In [3]:
print(device)

cuda


Load Data

In [4]:
X_Guys=np.load('/content/drive/MyDrive/healthcare data/X_Guys.npy')
X_HH=np.load('/content/drive/MyDrive/healthcare data/X_HH.npy')
X_IOP=np.load('/content/drive/MyDrive/healthcare data/X_IOP.npy')

y_Guys=np.load('/content/drive/MyDrive/healthcare data/y_Guys.npy')
y_HH=np.load('/content/drive/MyDrive/healthcare data/y_HH.npy')
y_IOP=np.load('/content/drive/MyDrive/healthcare data/y_IOP.npy')

Split into training, and testing sets

In [5]:
training_prop=0.8

#Randomly shuffle the sets

np.random.seed(2718281828)
permutation_Guys=np.random.permutation(len(y_Guys))
permutation_HH=np.random.permutation(len(y_HH))
permutation_IOP=np.random.permutation(len(y_IOP))

X_Guys=X_Guys[permutation_Guys]
X_HH=X_HH[permutation_HH]
X_IOP=X_IOP[permutation_IOP]

y_Guys=y_Guys[permutation_Guys]
y_HH=y_HH[permutation_HH]
y_IOP=y_IOP[permutation_IOP]

#Take subsets so that there are 100 images with an equal number from each source

n=33

X_Guys=X_Guys[:n+1]
X_HH=X_HH[:n]
X_IOP=X_IOP[:n]

y_Guys=y_Guys[:n+1]
y_HH=y_HH[:n]
y_IOP=y_IOP[:n]

#Separate into training and test

X_train=[X[:int(len(X)*training_prop**2)] for X in [X_Guys,X_HH,X_IOP]]
y_train=[y[:int(len(y)*training_prop**2)] for y in [y_Guys,y_HH,y_IOP]]

X_val=[X[int(len(X)*training_prop**2):int(len(X)*training_prop)] for X in [X_Guys,X_HH,X_IOP]]
y_val=[y[int(len(y)*training_prop**2):int(len(y)*training_prop)] for y in [y_Guys,y_HH,y_IOP]]

X_test=[X[int(len(X)*training_prop):] for X in [X_Guys,X_HH,X_IOP]]
y_test=[y[int(len(y)*training_prop):] for y in [y_Guys,y_HH,y_IOP]]

In [6]:
del X_Guys
del X_HH
del X_IOP
del y_Guys
del y_HH
del y_IOP

# Data Pre-Processing

Centre the data separately for each source (to handle the different intensities generated by the different machines) and reshape to the pytorch convention \begin{equation}Num Samples\times Num Channels \times Height\times Width \times Depth
\end{equation}

In [7]:
def centring(X):
    epsilon = 1e-7 # To prevent division by 0
    mean=np.mean(X)
    std=np.std(X)+epsilon
    X=(X-mean)/std
    return X

X_train=np.concatenate([centring(X) for X in X_train])
X_val=np.concatenate([centring(X) for X in X_val])
X_test=np.concatenate([centring(X) for X in X_test])

X_train=X_train.reshape((len(X_train),1,40,128,128))
X_val=X_val.reshape((len(X_val),1,40,128,128))
X_test=X_test.reshape((len(X_test),1,40,128,128))

Reshape the labels

In [8]:
y_train=np.concatenate(y_train)
y_val=np.concatenate(y_val)
y_test=np.concatenate(y_test)

y_train = np.concatenate(y_train).reshape(len(y_train),1,40,128,128)
y_val = np.concatenate(y_val).reshape(len(y_val),1,40,128,128)
y_test = np.concatenate(y_test).reshape(len(y_test),1,40,128,128)

Define the random data augmentation transformation

In [9]:
transform=tio.Compose([tio.RandomBiasField(coefficients=0.1),tio.RandomBlur(std=1),tio.RandomNoise(),tio.RandomGamma()])

Define a dataset and create the data loaders

In [10]:
class base_dataset(Dataset):
    def __init__(self, data,target):
        self.data = torch.Tensor(data)
        self.target = torch.Tensor(target)

    def __getitem__(self, index):
        x = transform(self.data[index]).to(device)
        y = self.target[index].to(device)
        return x, y

    def __len__(self):
        return len(self.data)
    
train_dataset = base_dataset(X_train,y_train)
val_dataset = base_dataset(X_val,y_val)
test_dataset = base_dataset(X_test,y_test)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [11]:
def id(x):
    return x.repeat(1,2,1,1,1)

class discriminator_dataset(Dataset):
    def __init__(self,data,segmentation_target,segmentation_model=id):
        self.data=torch.Tensor(data)
        self.segmentation_target=torch.Tensor(segmentation_target)
        self.segmentation_model=segmentation_model

    def __getitem__(self,index):
        X=transform(self.data[index]).to(device)
        segmentation_target=self.segmentation_target[index].to(device)
        true=torch.cat((X,segmentation_target),0)
        segmentation_pred=self.segmentation_model(X.unsqueeze(0))[0]
        false=torch.cat((X,segmentation_pred),0)
        return true,false

    def __len__(self):
        return len(self.data)

    def set_model(self,model):
        self.segmentation_model=model

class DiscriminatorDataLoader(DataLoader):
    def set_model(self,model):
        self.dataset.set_model(model)

train_discriminator_dataset = discriminator_dataset(X_train,y_train)

train_discriminator_dataloader = DiscriminatorDataLoader(train_discriminator_dataset, batch_size=8, shuffle=True, drop_last=True)

In [12]:
del X_train
del X_val
del X_test
del y_train
del y_val
del y_test

# Slice-wise model

Below is various transformation functions used to convert a list of mri images to a list of slices in each of the axial, sagittal and coronal planes

In [13]:
def to_sagittal(x):
    n,c,h,w,d=x.shape
    return torch.stack([x[i,:,j,:,:] for i in range(n) for j in range(h)]),n

def from_sagittal(x,n):
    if x.dim()==4:
        t,c,w,d=x.shape
        h=t//n
        return torch.stack([torch.stack([x[j*h+i,:,:,:] for i in range(h)],1) for j in range(n)])
    else:
        t,c=x.shape
        h=t//n
        return torch.stack([torch.stack([x[j*h+i,:] for i in range(h)]).mean(0) for j in range(n)])

def to_axial(x):
    n,c,h,w,d=x.shape
    return torch.stack([x[i,:,:,:,j] for i in range(n) for j in range(d)]),n

def from_axial(x,n):
    if x.dim()==4:
        t,c,h,w=x.shape
        d=t//n
        return torch.stack([torch.stack([x[j*d+i,:,:,:] for i in range(d)],3) for j in range(n)])
    else:
        t,c=x.shape
        d=t//n
        return torch.stack([torch.stack([x[j*d+i,:] for i in range(d)]).mean(0) for j in range(n)])

def to_coronal(x):
    n,c,h,w,d=x.shape
    return torch.stack([x[i,:,:,j,:] for i in range(n) for j in range(w)]),n

def from_coronal(x,n):
    if x.dim()==4:
        t,c,h,d=x.shape
        w=t//n
        return torch.stack([torch.stack([x[j*w+i,:,:,:] for i in range(w)],2) for j in range(n)])
    else:
        t,c=x.shape
        w=t//n
        return torch.stack([torch.stack([x[j*w+i,:] for i in range(w)]).mean(0) for j in range(n)])

This is code to allow a model to be run slicewise over mri images along slices in each of the axial, sagittal and cornal planes and return the average result of these

In [14]:
from copy import deepcopy as copy

class SliceWiseModel(nn.Module):
    def __init__(self,axial_model,sagittal_model,coronal_model):
        super(SliceWiseModel,self).__init__()
        self.axial_model=axial_model
        self.sagittal_model=sagittal_model
        self.coronal_model=coronal_model

    def forward(self,x):
        axial,n=to_axial(x)
        axial_result=from_axial(self.axial_model(axial),n)
        sagittal,n=to_sagittal(x)
        sagittal_result=from_sagittal(self.sagittal_model(sagittal),n)
        coronal,n=to_coronal(x)
        coronal_result=from_coronal(self.coronal_model(coronal),n)
        return (axial_result+sagittal_result+coronal_result)/3

    @classmethod
    def same_model(cls,model):
        return SliceWiseModel(model,copy(model),copy(model))

# Segmentation Models

This is an architecture based on UNet but is shallower and with fewer features

In [15]:
class UNet(nn.Module):
    def __init__(self, in_channels=1, init_features=4, out_channels=1,dropout_rate=0):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels,features,"conv1",dropout_rate=dropout_rate)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = UNet._block(features,2*features,"conv2",dropout_rate=dropout_rate)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = UNet._block(2*features,4*features,"conv3",dropout_rate=dropout_rate)
        self.upconv2 = nn.ConvTranspose2d(4*features,4*features,2,stride=2)
        self.decoder2 = UNet._block(6*features,2*features,"conv4",dropout_rate=dropout_rate)
        self.upconv1 = nn.ConvTranspose2d(2*features,2*features,2,stride=2)
        self.decoder1 = UNet._block(3*features,features,"conv5",dropout_rate=dropout_rate)
        self.conv = nn.Conv2d(features,out_channels,1)
        self.activation=nn.Sigmoid()

    def forward(self, x):
        res1=self.encoder1(x)
        res2=self.encoder2(self.pool1(res1))
        res3=self.bottleneck(self.pool2(res2))

        temp1=self.upconv2(res3)
        res4=self.decoder2(torch.cat((temp1,res2),dim=1))
        
        temp2=self.upconv1(res4)
        res5=self.decoder1(torch.cat((temp2,res1),dim=1))

        return self.activation(self.conv(res5))

    @staticmethod
    def _block(in_channels, features, name,dropout_rate=0):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.Conv2d(in_channels,features,3,padding=1)),
                 ('dropout',nn.Dropout(dropout_rate)),
                 ("relu",nn.ReLU()),
                 ("batchnorm",nn.BatchNorm2d(features))
                ]))

This is the same as the above UNet model without the residual connections

In [16]:
class CDNN(nn.Module):
    def __init__(self, in_channels=1, init_features=4, out_channels=1,dropout_rate=0):
        super(CDNN, self).__init__()

        features = init_features
        self.encoder1 = CDNN._block(in_channels,features,"conv1",dropout_rate=dropout_rate)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = CDNN._block(features,2*features,"conv2",dropout_rate=dropout_rate)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = CDNN._block(2*features,4*features,"conv3",dropout_rate=dropout_rate)
        self.upconv2 = nn.ConvTranspose2d(4*features,4*features,2,stride=2)
        self.decoder2 = CDNN._block(4*features,2*features,"conv4",dropout_rate=dropout_rate)
        self.upconv1 = nn.ConvTranspose2d(2*features,2*features,2,stride=2)
        self.decoder1 = CDNN._block(2*features,features,"conv5",dropout_rate=dropout_rate)
        self.conv = nn.Conv2d(features,out_channels,1)
        self.activation=nn.Sigmoid()

    def forward(self, x):
        res=self.encoder1(x)
        res=self.encoder2(self.pool1(res))
        res=self.bottleneck(self.pool2(res))

        res=self.upconv2(res)
        res=self.decoder2(res)
        
        res=self.upconv1(res)
        res=self.decoder1(res)

        return self.activation(self.conv(res))

    @staticmethod
    def _block(in_channels, features, name,dropout_rate=0):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.Conv2d(in_channels,features,3,padding=1)),
                 ('dropout',nn.Dropout(dropout_rate)),
                 ("relu",nn.ReLU()),
                 ("batchnorm",nn.BatchNorm2d(features))
                ]))

This is an architecture based on VNet but altered to be 2D, shallower and with fewer features.

In [17]:
class VNet(nn.Module):

    def __init__(self,in_channels=1,init_features=4,out_channels=1,dropout_rate=0):
        super(VNet,self).__init__()
        features=init_features
        self.features=features

        self.encoder1=VNet._convblock(in_channels,features,1,dropout_rate=dropout_rate)
        self.downconv1=VNet._downconvblock(features)
        self.encoder2=VNet._convblock(2*features,2*features,2,dropout_rate=dropout_rate)
        self.downconv2=VNet._downconvblock(2*features)
        self.encoder3=VNet._convblock(4*features,4*features,3,dropout_rate=dropout_rate)
        self.upconv1=VNet._upconvblock(2*features)
        self.decoder1=VNet._convblock(4*features,2*features,2,dropout_rate=dropout_rate)
        self.upconv2=VNet._upconvblock(features)
        self.decoder2=VNet._convblock(2*features,features,1,dropout_rate=dropout_rate)
        self.conv=nn.Conv2d(features,out_channels,1)
        self.activation=nn.Sigmoid()

    def forward(self,x):
        res1=self.encoder1(x)+x.repeat(1,self.features,1,1)
        temp=self.downconv1(res1)
        res2=self.encoder2(temp)+temp
        temp=self.downconv2(res2)
        temp=self.encoder3(temp)+temp
        temp=self.upconv1(temp)
        temp=self.decoder1(torch.cat((temp,res2),dim=1))+temp
        temp=self.upconv2(temp)
        temp=self.decoder2(torch.cat((temp,res1),dim=1))+temp
        return self.activation(self.conv(temp))

    @staticmethod
    def _conv(in_channels, features, name='conv',dropout_rate=0):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.Conv2d(in_channels,features,5,padding=2)),
                 ('dropout',nn.Dropout(dropout_rate)),
                 ("prelu",nn.PReLU()),
                 ("batchnorm",nn.BatchNorm2d(features))
                ]))
    def _convblock(in_channels,features,num_conv,name='conv',dropout_rate=0):
        return nn.Sequential(
            OrderedDict(
                [(name+str(0),VNet._conv(in_channels,features,dropout_rate=dropout_rate))]
                +[(name+str(i),VNet._conv(features,features,dropout_rate=dropout_rate)) for i in range(1,num_conv)]
                ))
    def _downconvblock(features,name='downconv'):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.Conv2d(features,2*features,2,stride=2))]
                +[("prelu",nn.PReLU())]
                ))
    def _upconvblock(features,name='upconv'):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.ConvTranspose2d(2*features,features,2,stride=2))]
                +[("prelu",nn.PReLU())]
                ))

In [18]:
segmentation_models={
    'CDNN':CDNN,
    'UNet':UNet,
    'VNet':VNet
}

# Discriminators

This a standard feed forward network

In [19]:
class ANN(nn.Module):

  def __init__(self,input_dim,hidden_dims,out_dim=1,act=nn.ReLU(),end=nn.Sigmoid(),dropout_rate=0.5):
    super(ANN,self).__init__()
    layers=[]
    in_dim=input_dim
    for x in hidden_dims:
      layers.append(nn.Linear(in_dim,x))
      nn.init.xavier_uniform_(layers[-1].weight)
      layers.append(nn.Dropout(dropout_rate))
      in_dim=x
    layers.append(nn.Linear(in_dim,out_dim))
    layers[-1].weight.data.fill_(0)
    layers[-1].bias.data.fill_(0)
    self.layers=nn.ModuleList(layers)
    self.act=act
    self.end=end

  def forward(self,x):
    x=x.flatten(1)
    for layer in self.layers[:-1]:
      x=self.act(layer(x))
    return self.end(self.layers[-1](x))

In [20]:
def ANNDiscriminator(hidden_dims=[1000,500,250],dropout_rate=0.5):
    c=2
    h=40
    w=128
    d=128
    axial_model=ANN(c*h*w,hidden_dims,dropout_rate=dropout_rate)
    sagittal_model=ANN(c*w*d,hidden_dims,dropout_rate=dropout_rate)
    coronal_model=ANN(c*h*d,hidden_dims,dropout_rate=dropout_rate)
    return SliceWiseModel(axial_model,sagittal_model,coronal_model)

This is a CNN

In [21]:
class CNN(nn.Module):
    def __init__(self,input_dim,in_channels=2, init_features=4,out_channels=1,fc_hidden_dims=[120,16],dropout_rate=0):
        super(CNN, self).__init__()
        # Convolutional layers
        features = init_features
        self.encoder1 = CNN._block(in_channels, features, name="conv1",dropout_rate=dropout_rate)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.encoder2 = CNN._block(features, 2*features, name="conv2",dropout_rate=dropout_rate)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.encoder3 = CNN._block(2*features, 4*features, name="conv3",dropout_rate=dropout_rate)

        # Fully connected layers
        self.fc = ANN(int(input_dim/16)*4*features,fc_hidden_dims,out_channels,dropout_rate=0.2)

    def forward(self, x):
        temp = self.encoder1(x)
        temp = self.encoder2(self.pool1(temp))
        temp = self.encoder3(self.pool2(temp))
        pred = self.fc(temp)
        return pred

    @staticmethod
    def _block(in_channels, features, name,dropout_rate=0):
        return nn.Sequential(
            OrderedDict(
                [(name,nn.Conv2d(in_channels,features,3,padding=1)),
                 ('dropout',nn.Dropout(dropout_rate)),
                 ("relu",nn.ReLU()),
                 ("batchnorm",nn.BatchNorm2d(features))
                ]))

In [22]:
def CNNDiscriminator(init_features=4,dropout_rate=0):
    h=40
    w=128
    d=128
    axial_model=CNN(h*w,init_features=init_features,dropout_rate=dropout_rate)
    sagittal_model=CNN(w*d,init_features=init_features,dropout_rate=dropout_rate)
    coronal_model=CNN(h*d,init_features=init_features,dropout_rate=dropout_rate)
    return SliceWiseModel(axial_model,sagittal_model,coronal_model)

This is a transformer model

In [23]:
class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(torch.matmul(q,k.T)/np.sqrt(self.d))
                seq_result.append(torch.matmul(attention,v))
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [24]:
class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d,mlp_ratio*hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio*hidden_d,hidden_d)
        )

    def forward(self, x):
        out = x+self.mhsa(self.norm1(x))
        temp = self.norm2(out)
        out = temp+self.mlp(temp)
        return out

In [25]:
def patchify(images, patch_size):
    n, c, h, w = images.shape

    n_patches1=h//patch_size
    n_patches2=w//patch_size

    patches=images.unfold(2,patch_size,patch_size).unfold(3,patch_size,patch_size)

    patches=patches.flatten(4)

    patches=torch.cat((patches[:,0,:,:,:],patches[:,1,:,:,:]),3)

    patches=torch.stack([patches[:,i,j,:] for i in range(n_patches1) for j in range(n_patches2)],1)

    return patches

In [26]:
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            if j%2==0:
                result[i][j]=np.sin(i/np.power(10000,j/d))
            else:
                result[i][j]=np.cos(i/np.power(10000,(j-1)/d))
    return result

In [27]:
class MyViT(nn.Module):
    def __init__(self, image_shape, patch_size=8, n_blocks=2, hidden_d=16, n_heads=2, out_d=1):
        # Super constructor
        super(MyViT, self).__init__()
        
        # Attributes
        self.image_shape = image_shape # ( C , H , W )
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        
        c,h,w=image_shape
        # Input and patches sizes
        assert h%patch_size==0
        assert w%patch_size==0
        n_patches1=h//patch_size
        n_patches2=w//patch_size

        self.patch_size = patch_size
        
        # 1) Linear mapper
        self.input_d = int(image_shape[0] * patch_size**2)
        self.linear_mapper = nn.Linear(self.input_d,hidden_d)
        
        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(hidden_d))
        
        # 3) Positional embedding - this creates a property called self.positional_embeddings
        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches1*n_patches2 + 1, hidden_d), persistent=False)
        
        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        
        # 5) Classification MLPk
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d,out_d),
            nn.Sigmoid()
        )

    def forward(self, images):
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.patch_size)
        
        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)
        
        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.repeat(n,1,1),tokens),dim=1)   
        
        # Adding positional embedding
        out = tokens+self.positional_embeddings.repeat(n,1,1)
        
        # Transformer Blocks
        for block in self.blocks:
            out = block(out)
            
        # Getting the classification token only
        out = out[:,0,:]
        
        return self.mlp(out) # Map to output dimension, output category distribution

In [28]:
def ViTDiscriminator(n_blocks=2,n_heads=2,hidden_d=8):
    c=2
    h=40
    w=128
    d=128
    axial_model=MyViT((c,h,w),n_blocks=n_blocks,n_heads=n_heads,hidden_d=hidden_d)
    sagittal_model=MyViT((c,w,d),n_blocks=n_blocks,n_heads=n_heads,hidden_d=hidden_d)
    coronal_model=MyViT((c,h,d),n_blocks=n_blocks,n_heads=n_heads,hidden_d=hidden_d)
    return SliceWiseModel(axial_model,sagittal_model,coronal_model)

In [29]:
discriminators={
    'ANN': ANNDiscriminator,
    'CNN': CNNDiscriminator,
    'ViT': ViTDiscriminator
}

# Training Code

Loss Functions

In [30]:
def dummy_discriminator(x):
    n=len(x)
    return torch.ones(n,1)

class GANSegLoss(nn.Module):
    def __init__(self,discriminator=dummy_discriminator,lamda=1):
        super(GANSegLoss,self).__init__()
        self.discriminator=discriminator
        self.lamda=lamda
        self.loss=nn.BCELoss()

    def forward(self,image,seg_pred,seg_target):
        first_term=self.loss(seg_pred,seg_target)
        dis_pred=self.discriminator(torch.cat((image,seg_pred),1))
        second_term=self.lamda*self.loss(dis_pred,torch.ones(dis_pred.shape).to(device))
        return first_term+second_term

    def set_discriminator(self,discriminator):
        self.discriminator=discriminator

class GANDisLoss(nn.Module):
    def __init__(self):
        super(GANDisLoss,self).__init__()
        self.loss=nn.BCELoss()
    
    def forward(self,true_pred,false_pred):
        true,false=true_pred.flatten(),false_pred.flatten()
        return self.loss(torch.cat((true,false),0),torch.cat((torch.ones(true.shape).to(device),torch.zeros(false.shape).to(device)),0))

Training step for the segmentation network and discriminator network in GAN training and a function to evaluate the dice score of the segmentation network on the validation set

In [31]:
from sklearn.metrics import accuracy_score,f1_score

def seg_train(seg_net,dis_net, dataloader, optim, loss_func, epoch):
    seg_net.train()  #Put the network in train mode
    dis_net.eval()
    total_loss = 0
    pred_store = []
    true_store = []
    
    batches = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):

        data, target = Variable(data), Variable(target)
        batches += 1

        # Define training process here:

        seg_net.zero_grad()
        optim.zero_grad() 

        pred=seg_net(data)
        loss=loss_func(data.float(),pred.float(),target.float())
        
        loss.backward()
        optim.step()  
        
        total_loss += loss
        pred_store.append(np.round(pred.detach().cpu().numpy()))
        true_store.append(np.round(target.detach().cpu().numpy()))

        if batch_idx % 100 == 0: #Report stats every x batches
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx+1) * len(data), len(dataloader.dataset),
                           100. * (batch_idx+1) / len(dataloader), loss.item()), flush=True)

    
    av_loss = total_loss / batches
    av_loss = av_loss.detach().cpu().numpy()

    pred_store = np.concatenate(pred_store)
    true_store = np.concatenate(true_store)
    acc = accuracy_score(pred_store.flatten(), true_store.flatten())

    print('\nTraining set: Average loss: {:.4f}'.format(av_loss,  flush=True))
    print('Training set: Average Pixel Acc: {:.4f}'.format(acc,  flush=True))
    return av_loss, acc

def dis_train(seg_net,dis_net, dataloader, optim, loss_func, epoch):
    seg_net.eval()  #Put the network in train mode
    dis_net.train()
    total_loss = 0
    pred_store = []
    true_store = []

    batches = 0
    
    for batch_idx, (true,false) in enumerate(dataloader):
        true,false = Variable(true), Variable(false)
        batches += 1

        dis_net.zero_grad()
        optim.zero_grad() 

        true_pred=dis_net(true)
        false_pred=dis_net(false)
        loss=loss_func(true_pred.float(),false_pred.float())
        
        loss.backward()
        optim.step()  
        
        total_loss += loss
        pred_store.append(np.round(true_pred.detach().cpu().numpy()))
        pred_store.append(np.round(false_pred.detach().cpu().numpy()))
        true_store.append(torch.ones(true_pred.shape).numpy())
        true_store.append(torch.zeros(false_pred.shape).numpy())

        if batch_idx % 100 == 0: #Report stats every x batches
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx+1) * len(true), len(dataloader.dataset),
                           100. * (batch_idx+1) / len(dataloader), loss.item()), flush=True)
    
    av_loss = total_loss / batches
    av_loss = av_loss.detach().cpu().numpy()

    pred_store = np.concatenate(pred_store)
    true_store = np.concatenate(true_store)
    acc = accuracy_score(pred_store, true_store)

    print('\nTraining set: Average loss: {:.4f}'.format(av_loss,  flush=True))
    print('Training set: Average Acc: {:.4f}'.format(acc,  flush=True))
    return av_loss, acc

def predict(net, test_dataloader,class_labels=True):
    pred_store = []
    true_store = []
    net.eval()
    for batch_idx, (data, target) in enumerate(test_dataloader):

        data, target = Variable(data), Variable(target)
        pred=net(data)
        if class_labels:
            pred_store.append(np.round(pred.detach().cpu().numpy()))
            true_store.append(np.round(target.detach().cpu().numpy()))
        else:
            pred_store.append(pred.detach().cpu().numpy())
            true_store.append(target.detach().cpu().numpy())


    pred_store = np.concatenate(pred_store)
    true_store = np.concatenate(true_store)
    return pred_store, true_store

def seg_val(seg_net, val_dataloader):
    pred_store,true_store=predict(seg_net,val_dataloader)
    pred_store=pred_store.flatten()
    true_store=true_store.flatten()
    temp=f1_score(true_store,pred_store)
    print('Validation Set: Dice Score ',temp)
    return temp

A function to train the networks in a GAN

In [32]:
def GANtrain(seg_net,dis_net,seg_dataloader,dis_dataloader,val_dataloader,max_epochs=50,lr1=0.01,lr2=0.01,n1=1,n2=1,lamda=1):
    dice = []
    seg_losses=[]
    dis_losses=[]
    seg_path='seg_model'
    dis_path='dis_model'
    best_dice=seg_val(seg_net,val_dataloader)
    torch.save(seg_net.state_dict(),seg_path)
    torch.save(dis_net.state_dict(),dis_path)
    seg_optim=torch.optim.Adam(seg_net.parameters(),lr=lr1)
    dis_optim=torch.optim.Adam(dis_net.parameters(),lr=lr2)
    seg_loss=GANSegLoss(dis_net,lamda).to(device)
    dis_loss=GANDisLoss().to(device)
    dis_dataloader.set_model(seg_net)
    for epoch in range(1, max_epochs+1):
        for _ in range(n1):
            train_loss, train_acc = seg_train(seg_net,dis_net, seg_dataloader, seg_optim, seg_loss, epoch)
            seg_losses.append([train_loss,train_acc])
        for _ in range(n2):
            train_loss, train_acc = dis_train(seg_net,dis_net, dis_dataloader, dis_optim, dis_loss, epoch)
            dis_losses.append([train_loss,train_acc])
        die = seg_val(seg_net, val_dataloader)
        dice.append(die)
        if die>best_dice:
            best_dice=die
            torch.save(seg_net.state_dict(),seg_path)
            torch.save(dis_net.state_dict(),dis_path)
    seg_net.load_state_dict(torch.load(seg_path))
    dis_net.load_state_dict(torch.load(dis_path))
    return dice,seg_losses,dis_losses

A function to choose the best lambda value by training the model for 3 epochs and selecting the lambda value that gave the best dice score on the validation data

In [33]:
def get_lambda(seg_name,dis_name):
    try:
        with open('/content/drive/MyDrive/healthcare data/lambda'+seg_name+dis_name,'rb') as f:
            return pickle.load(f)
    except:
        torch.manual_seed(2718281828)
        values=[0.01,0.1,0.5]
        best_dice=0
        best_value=0
        for value in values:
            seg_net=SliceWiseModel.same_model(segmentation_models[seg_name]().to(device))
            dis_net=discriminators[dis_name]()
            n=2 if (dis_name=='ViT' or dis_name=='CNN') else 1
            dice,_,_=GANtrain(seg_net,dis_net,train_dataloader,train_discriminator_dataloader,val_dataloader,max_epochs=3,lamda=value,n1=n)
            dice=max(dice)
            del seg_net
            del dis_net
            if dice>best_dice:
                best_dice=dice
                best_value=value
        with open('/content/drive/MyDrive/healthcare data/lambda'+seg_name+dis_name,'wb') as f:
            pickle.dump(best_value,f)
        return best_value

training and validation steps for the control models

In [34]:
def train(net, dataloader, optim, loss_func, epoch):
    net.train()  #Put the network in train mode
    total_loss = 0
    batches = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = Variable(data), Variable(target)
        batches += 1

        net.zero_grad()
        optim.zero_grad() 

        pred=net(data)
        loss=loss_func(pred.float(),target.float())
        
        loss.backward()
        optim.step()

        total_loss += loss
        if batch_idx % 100 == 0: #Report stats every x batches
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx+1) * len(data), len(dataloader.dataset),
                           100. * (batch_idx+1) / len(dataloader), loss.item()), flush=True)
    av_loss = total_loss / batches
    av_loss = av_loss.detach().cpu().numpy()
    print('\nTraining set: Average loss: {:.4f}'.format(av_loss,  flush=True))

    return av_loss

def val(net, val_dataloader, loss_func):
    net.eval()  #Put the model in eval mode
    total_loss = 0    
    batches = 0
    with torch.no_grad():  # So no gradients accumulate
        for batch_idx, (data, target) in enumerate(val_dataloader):
            batches += 1
            data, target = Variable(data), Variable(target)
            # Eval steps
            pred=net(data)
            loss=loss_func(pred.float(),target.float())

            total_loss += loss
        av_loss = total_loss / batches
        
    av_loss = av_loss.detach().cpu().numpy()
    print('Validation set: Average loss: {:.4f}'.format(av_loss,  flush=True))
    print('\n')
    return av_loss

def val_acc(net,val_dataloader):
    pred_store,true_store=predict(net,val_dataloader)
    temp=accuracy_score(true_store.flatten(),pred_store.flatten())
    print('Validation Set: Accuracy ',temp)
    return temp

A function to train a segmentation network

In [35]:
def controlTrain(net,train_dataloader,val_dataloader,class_loss=nn.BCELoss().to(device),lr=0.01):
    losses = []
    max_epochs = 25
    optim=torch.optim.Adam(net.parameters(),lr=lr)
    path='model'
    best_dice=seg_val(net,val_dataloader)
    torch.save(net.state_dict(),path)
    dice=[]
    for epoch in range(1, max_epochs+1):
        train_loss = train(net, train_dataloader, optim, class_loss, epoch)
        val_loss = val(net, val_dataloader,class_loss)
        losses.append([train_loss, val_loss])
        die = seg_val(net, val_dataloader)
        dice.append(die)
        if die>best_dice:
            best_dice=die
            torch.save(net.state_dict(),path)
    net.load_state_dict(torch.load(path))
    return dice,losses

A function to evaluate a segmentation model on the test data

In [36]:
from sklearn.metrics import accuracy_score, jaccard_score, f1_score,precision_score,recall_score,roc_auc_score,confusion_matrix

def evaluate(net,dataloader):
    torch.manual_seed(271828182)
    pred_store,true_store=predict(net,dataloader)
    pred_store=pred_store.flatten()
    true_store=true_store.flatten()
    print('Accuracy: ',accuracy_score(true_store,pred_store))
    print("Jaccard's Index: ",jaccard_score(true_store,pred_store))
    print('Dice Score: ',f1_score(true_store,pred_store))
    print('Precision: ',precision_score(true_store,pred_store))
    print('Recall: ',recall_score(true_store,pred_store))
    tn, fp, fn, tp = confusion_matrix(true_store,pred_store).ravel()
    print('Specificity: ',tn/(tn+fp))
    pred_store,true_store=predict(net,dataloader,False)
    pred_store=pred_store.flatten()
    true_store=true_store.flatten()
    print('AUC-ROC Score: ',roc_auc_score(true_store,pred_store))

# Training the control networks

In [None]:
def get_control_model(seg_name):
    try:
        with open('/content/drive/MyDrive/healthcare data/'+seg_name,'rb') as f:
            return pickle.load(f)
    except Exception:
        torch.manual_seed(2718281828)  #manually seed the pytorch pseudo-random generators (with a seed based on e) to ensure repeatability
        model=SliceWiseModel.same_model(segmentation_models[seg_name]().to(device))
        model_stats=controlTrain(model,train_dataloader,val_dataloader)
        with open('/content/drive/MyDrive/healthcare data/'+seg_name,'wb') as f:
            pickle.dump((model,model_stats),f)

In [None]:
for model_name in segmentation_models:
    _=get_control_model(model_name)

#Evaluating Control Models

In [None]:
for seg_name in segmentation_models:
    print('Control Model: ',seg_name)
    model,_=get_control_model(seg_name)
    evaluate(model,test_dataloader)
    print()
    del model

Control Model:  CDNN
Accuracy:  0.938238525390625
Jaccard's Index:  0.8371325381403859
Dice Score:  0.911346917830722
Precision:  0.8782722773580912
Recall:  0.9470101406529788
Specificity:  0.9338154838225938
AUC-ROC Score:  0.9748269320586742

Control Model:  UNet
Accuracy:  0.9467086065383185
Jaccard's Index:  0.8545759176422068
Dice Score:  0.9215863416674382
Precision:  0.9092912209014442
Recall:  0.9342185207431117
Specificity:  0.9530066676146601
AUC-ROC Score:  0.9782274071342195

Control Model:  VNet
Accuracy:  0.9674538748604911
Jaccard's Index:  0.9062341427143151
Dice Score:  0.9508109443721482
Precision:  0.9635908872268616
Recall:  0.9383655607532041
Specificity:  0.9821215066013244
AUC-ROC Score:  0.9866804274400738



#Training GAN Models

In [37]:
def get_GAN_model(seg_name,dis_name):
    try:
        with open('/content/drive/MyDrive/healthcare data/'+seg_name+dis_name,'rb') as f:
            return pickle.load(f)
    except Exception:
        torch.manual_seed(2718281828)  #manually seed the pytorch pseudo-random generators (with a seed based on e) to ensure repeatability
        lamda=get_lambda(seg_name,dis_name)
        seg_net=SliceWiseModel.same_model(segmentation_models[seg_name]()).to(device)
        dis_net=discriminators[dis_name]().to(device)
        n=2 if (dis_name=='ViT' or dis_name=='CNN') else 1
        results=GANtrain(seg_net,dis_net,train_dataloader,train_discriminator_dataloader,val_dataloader,lamda=lamda,n1=n)
        with open('/content/drive/MyDrive/healthcare data/'+seg_name+dis_name,'wb') as f:
            pickle.dump(((seg_net,dis_net),results),f)
        return (seg_net,dis_net),results

In [None]:
for seg_name in segmentation_models:
    for dis_name in discriminators:
        _=get_GAN_model(seg_name,dis_name)

In [38]:
_=get_GAN_model('CDNN','ViT')

Validation Set: Dice Score  0.5366234891401189

Training set: Average loss: 0.9431
Training set: Average Pixel Acc: 0.8023

Training set: Average loss: 0.7968
Training set: Average Pixel Acc: 0.8947

Training set: Average loss: 0.5691
Training set: Average Acc: 0.8036
Validation Set: Dice Score  0.8134520717396918

Training set: Average loss: 0.7514
Training set: Average Pixel Acc: 0.8646

Training set: Average loss: 0.5228
Training set: Average Pixel Acc: 0.8821

Training set: Average loss: 0.4818
Training set: Average Acc: 0.7946
Validation Set: Dice Score  0.8101106662293046

Training set: Average loss: 0.7557
Training set: Average Pixel Acc: 0.9057

Training set: Average loss: 0.4335
Training set: Average Pixel Acc: 0.9138

Training set: Average loss: 0.6738
Training set: Average Acc: 0.6161
Validation Set: Dice Score  0.7648848823541561

Training set: Average loss: 1.3578
Training set: Average Pixel Acc: 0.8662

Training set: Average loss: 1.1125
Training set: Average Pixel Acc: 0

In [None]:
for seg_name in segmentation_models:
    _=get_GAN_model(seg_name,'CNN')

# Evaluating GAN models

In [None]:
for seg_name in segmentation_models:
    for dis_name in discriminators:
        print('Segmentation model: ',seg_name)
        print('Discriminator: ',dis_name)
        (seg_net,_),_=get_GAN_model(seg_name,dis_name)
        evaluate(seg_net,test_dataloader)
        print()
        del seg_net
    print()