In [None]:
# Mount Google Drive so the dataset pickles under drive/MyDrive can be read.
# The dataset paths used later are relative ('drive/MyDrive/...'), presumably
# resolved from Colab's default working directory (/content).
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

In [2]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import pickle
import math

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
# Reads data from a pickle file
# Data is stored as a lists of images and ground truths in a dictionary
# Each index in the seperate list corrispond to each other
class PickleDataset(Dataset):
  def __init__(self, pickle_path, transform=None):
    # Opens file and reads lists in
    file = open(pickle_path, 'rb')
    data = pickle.load(file)
    file.close()
    self.features = data["features"]
    self.labels = data["labels"]
    #
    self.transform = transform

  def __len__(self):
    return self.features.shape[0]

  def __getitem__(self, idx):
    feature = self.features[idx]
    label = self.labels[idx]

    x = [1*(label==0),1*(label==1),1*(label==2),1*(label==3)]
    label = torch.tensor(x)
    # label = torch.tensor(label).unsqueeze(0)


    if self.transform:
      feature = self.transform(feature)

    return feature, label


## Training Data

In [None]:
train_path = 'drive/MyDrive/College/Research/ACDC_DataSet/processed_henry/train_data.pkl'

# # Set seed for repeatability
seed_val = 1
torch.manual_seed(seed_val)
np.random.seed(seed_val)

# Define the validation ratio to be used
valid_ratio = 0.9

transform = torchvision.transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
          ])
label_transform = transforms.ToTensor()

train_dataset = PickleDataset(train_path, transform=transform)#, target_transform=label_transform)

# Number of images to use for training
nb_train = math.ceil((1.0 - valid_ratio) * len(train_dataset))
# Number of images to use for validation
nb_valid =  math.floor(valid_ratio * len(train_dataset))

# Randomly split into training and validation data
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(train_dataset, [nb_train, nb_valid])

print(f"Training dataset length: {len(train_dataset)}")
print(f"Valid dataset length: {len(valid_dataset)}")

# create your dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=20, shuffle=True)

print(f"Training batches: {len(train_dataloader)}")
print(f"Valid batches: {len(valid_dataloader)}")

## Test Data

In [None]:
test_path = 'drive/MyDrive/College/Research/ACDC_DataSet/processed_henry/test_data.pkl'

test_dataset = PickleDataset(test_path, transform=transform)#, target_transform=label_transform)
print(f"Training dataset length: {len(train_dataset)}")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
print(f"Test batches: {len(test_dataloader)}")

# Define Network

In [6]:
# Two 3x3 convolutions that maintain image size
# Each convolution is followed by a relu activation
# Inputs
#   in_c: number of channels/features inputted into the convolution
#   mid_c: number of channels/features between the two convolutions
#   out_c: number of channels/features outputted by the convolution
def double_conv(in_c, mid_c, out_c):
  conv = nn.Sequential(
      nn.Conv2d(
        in_channels = in_c,              
        out_channels = mid_c,            
        kernel_size = 3,                                 
        padding = 1                 
      ),                              
      nn.ReLU(inplace=True),
      nn.Conv2d(
        in_channels = mid_c,              
        out_channels = out_c,            
        kernel_size = 3,
        padding = 1                            
      ),                              
      nn.ReLU(inplace=True)
  )
  return conv

In [7]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    # encoder
    self.pool = nn.MaxPool2d(kernel_size=2, stride = 2) #Max pool, decreases image size by half in each dimension, by 4 total
    # each convolution increases features by 2 (exception in down_conv1)
    self.down_conv1 = double_conv(1, 64, 64)
    self.down_conv2 = double_conv(64, 128, 128)
    # bottle conv
    self.bottle_conv = double_conv(128, 256, 128) #Exapanda and contracts features
    # decoder
    # each transpose convolution increases image size by 2 in each dimension, 4 in total
    # each up_conv decreases features by 4
    self.trans1 = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.up_conv1 = double_conv(256, 128, 64)
    self.trans2 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.up_conv2 = double_conv(128, 64, 64)
    # returns number of features equal to classes (4)
    self.out = nn.Conv2d(in_channels = 64, out_channels = 4, kernel_size = 3, padding = 1)

  def forward(self,img):
    # encoder
    # a down convolution followed by max pool
    # increases features by 2, decreases image size by 4
    #------------------------
    x1 = self.down_conv1(img)
    p1 = self.pool(x1)
    #------------------------
    x2 = self.down_conv2(p1)
    p2 = self.pool(x2)
    #------------------------
    # bottleneck
    # adds more weights for network to change
    #------------------------
    bottle = self.bottle_conv(p2)
    #------------------------
    # decoder
    # First increases image size
    # Then concatanates a similar sized output from a down_conv to the results
    # Then does an up_conv decreasing features
    #------------------------
    u2 = self.trans1(bottle)
    c2 = torch.cat([x2,u2], dim=1)
    y2 = self.up_conv1(c2)
    #------------------------
    u1 = self.trans2(y2)
    c1 = torch.cat([x1,u1], dim=1)
    y1 = self.up_conv2(c1)
    #------------------------
    # retirns image with proper number of features
    out = self.out(y1)
    return out

  # Initializes model with normally distributed xavier weights for all convolutions
  # uses gain of sqrt(2)
  def initialize_weights(self):
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal(m.weight,gain=np.sqrt(2))

In [8]:
class Descriminator(nn.Module):
  def __init__(self):
    super(Descriminator, self).__init__()
    self.conv1 = nn.Sequential(
        double_conv(4,16,16),
        nn.MaxPool2d(kernel_size=2, stride = 2)
    )
    self.conv2 = nn.Sequential(
        double_conv(16,32,32),
        nn.MaxPool2d(kernel_size=2, stride = 2)
    )
    self.conv3 = nn.Sequential(
        double_conv(32,64,64),
        nn.MaxPool2d(kernel_size=2, stride = 2)
    )
    self.conv4 = nn.Sequential(
        double_conv(64,128,128),
        nn.MaxPool2d(kernel_size=2, stride = 2)
    )

    self.linear1 = nn.Sequential(
        nn.Linear(in_features = 128*16*16, out_features=256),
        nn.ReLU(inplace=True)
    )
    self.linear2 = nn.Sequential(
        nn.Linear(in_features = 256, out_features=64),
        nn.ReLU(inplace=True)
    )
    self.linear3 = nn.Sequential(
        nn.Linear(in_features = 64, out_features=16),
        nn.ReLU(inplace=True)
    )
    self.linear4 = nn.Linear(in_features = 16, out_features=1)

  def forward(self,x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)

    x = x.view(x.size(0), -1)

    x = self.linear1(x)
    x = self.linear2(x)
    x = self.linear3(x)
    x = self.linear4(x)
    return x

  def initialize_weights(self):
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal(m.weight,gain=np.sqrt(2))

In [None]:
# Selects device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [10]:
class Activation_g(nn.Module):
  def __init__(self,divergence="GAN"):
    super(Activation_g,self).__init__()
    self.divergence =divergence
    print("Activation computed using ",divergence)

  def forward(self,v):
    divergence = self.divergence
    if divergence == "KLD":
      return v
    elif divergence == "RKL":
      return -torch.exp(-v)
    elif divergence == "CHI":
      return v
    elif divergence == "SQH":
      return 1-torch.exp(-v)
    elif divergence == "JSD":
      return torch.log(torch.tensor(2.))-torch.log(1.0+torch.exp(-v))
    elif divergence == "GAN":
      return -torch.log(1.0+torch.exp(-v)) # log sigmoid

class Conjugate_f(nn.Module):
  def __init__(self,divergence="GAN"):
    super(Conjugate_f,self).__init__()
    self.divergence = divergence
    print("Conjugate computed using ",divergence)

  def forward(self,t):
    divergence= self.divergence
    if divergence == "KLD":
      return torch.exp(t-1)
    elif divergence == "RKL":
      return -1 -torch.log(-t)
    elif divergence == "CHI":
      return 0.25*t**2+t
    elif divergence == "SQH":
      return t/(torch.tensor(1.)-t)
    elif divergence == "JSD":
      return -torch.log(2.0-torch.exp(t))
    elif divergence == "GAN":
      return  -torch.log(1.0-torch.exp(t))

class VLOSS(nn.Module):
  def __init__(self,divergence="GAN"):
    super(VLOSS,self).__init__()
    self.activation = Activation_g(divergence)
  def forward(self,v):
    return torch.mean(self.activation(v))


class QLOSS(nn.Module):
  def __init__(self,divergence="GAN"):
    super(QLOSS,self).__init__()
    self.conjugate = Conjugate_f(divergence)
    self.activation = Activation_g(divergence)
  def forward(self,v):
    return torch.mean(-self.conjugate(self.activation(v))) 

In [None]:
#Model parameters
params = {'beta1': 0.5, 'beta2': 0.999,'lr_g':0.00001,'lr_d':0.001,'max_epochs':400}

# Selects loss function
# params['divergence'] = "SQH"
params['divergence'] = "GAN"
# params['divergence'] = "RKL"
# params['divergence'] = "KLD"
# params['divergence'] = "JSD"
# params['divergence'] = "CHI"

#Creates Models
D = Descriminator().to(device)
G = Generator().to(device)
#Optimizers
G_optimizer = optim.Adam(Generator.parameters(G), lr=params['lr_g'], betas=(params['beta1'], params['beta2']))
D_optimizer = optim.Adam(Descriminator.parameters(D), lr=params['lr_g'], betas=(params['beta1'], params['beta2']))

#Defining loss functions
Q_criterion = QLOSS(params['divergence']).to(device)
V_criterion = VLOSS(params['divergence']).to(device)

#Model output
Ninner = 1
train_hist = {}
train_hist = {}
train_hist['D_loss_fake'] = []
train_hist['D_loss_true'] = []
train_hist['D_loss_total'] = []
train_hist['G_loss'] = []

for epoch in range(params['max_epochs']):
  for inputs, labels in train_dataloader:
    #Sends data to device
    z_, x_ = inputs.to(device), labels.to(device)
    batch_sz = inputs.shape[0]
    y_real_, y_fake_ = torch.zeros(batch_sz, 1), torch.ones(batch_sz, 1)
    y_real_, y_fake_ = y_real_.to(device), y_fake_.to(device)
    #Update D network
    for i in range(Ninner):
      D_optimizer.zero_grad()
      D_real = D(x_.float())
      D_real_loss = -V_criterion(D_real)

      G_ = G(z_)
      D_fake = D(G_)
      D_fake_loss = -Q_criterion(D_fake)

      D_loss = D_real_loss + D_fake_loss
      D_loss.backward()
      D_optimizer.step()

    # update G network
    for i in range(Ninner):
      G_optimizer.zero_grad()
      G_ = G(z_)
      D_fake = D(G_)
      G_loss = -V_criterion(D_fake)

      G_loss.backward()
      G_optimizer.step()

    train_hist['D_loss_fake'].append(D_fake_loss.item())
    train_hist['D_loss_true'].append(D_real_loss.item())
    train_hist['D_loss_total'].append(D_loss.item())
    train_hist['G_loss'].append(G_loss.item())

  if(np.mod(epoch,50)==0):
    print("Epoch:", epoch)
    print("\tDloss =",D_loss.detach().cpu().numpy(),";Gloss=",G_loss.detach().cpu().numpy())
    z_ = z_.to(device)

plt.figure()
s=plt.plot(train_hist['D_loss_fake'],c='b')
s=plt.plot(train_hist['D_loss_true'],c='m')
s=plt.plot(train_hist['D_loss_total'],c='r')
s=plt.plot(train_hist['G_loss'],c='k')
s = plt.ylim((0,3))
s = plt.grid()
s=plt.legend(('Dloss_fake','D_loss_true','Discriminator loss','Generator loss'))

In [13]:
# Displays a tensor as an image
def imshow(img):
  img = (img - torch.min(img))/torch.max(img - torch.min(img))
  np_img = img.numpy()      # convert to numpy
  plt.imshow(np.transpose(np_img, (1,2,0)))
  plt.show()   

In [None]:
# Get random testing images
data_iter = iter(test_dataloader)
imgs, labels = data_iter.next()
# Show images (that are currently tensors) and their predictions
n_images = 0
imshow(torchvision.utils.make_grid(imgs[n_images,]))
_,labels = torch.max(labels,1)
imshow(torchvision.utils.make_grid(labels[n_images,]))
# forward pass through network
with torch.no_grad():
  outputs = G(imgs[n_images].unsqueeze(dim=1).to(device))
  _, predicted = torch.max(outputs, 1)
# predictions
imshow(torchvision.utils.make_grid(predicted.cpu().detach()))