# DCGAN128 for fetal head ultrasound images

**Author(s)**: Thea Bautista [@theabautista](https://github.com/theabautista)     
**Contributor(s)**:  Miguel Xochicale [@mxochicale](https://github.com/mxochicale)     

May 2022


## Summary
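This notebook trains a Deep Convolutional GAN (DCGAN) that synthesises 128x128 grayscale fetal head ultrasound images from a random subset of the head circumference dataset. The code cells assume a Google Colab runtime (Google Drive is mounted and the dataset is read from `/content/head_cirumference`).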

### How to run the notebook

1. Go to the repository path: `cd $HOME/repositories/xfetus/miua2022`
2. Open the repository in PyCharm and, in its terminal, type:
    ```
    git checkout master # or the branch
    git pull # to bring a local branch up-to-date with its remote version
    ```
3. Launch Notebook server  
    Go to the notebooks path: `cd $HOME/repositories/xfetus/miua2022/notebooks` and, in the PyCharm terminal, type:
    ```
    conda activate susiE 
    jupyter notebook
    ```
    which will open the notebook server in your web browser.
    
    
### References
* "Proposed Regulatory Framework for Modifications to Artificial Intelligence/Machine Learning (AI/ML)-Based Software as a Medical Device (SaMD) - Discussion Paper and Request for Feedback". https://www.fda.gov/media/122535/download 


## Dependencies

In [None]:
!pip install pytorch-gan-metrics
!pip install --quiet "torchmetrics>=0.3"

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch import optim as optim
import numpy as np
import torchvision.utils as vutils
import matplotlib.animation as animation
from IPython.display import HTML
from torchmetrics import Accuracy
from pytorch_gan_metrics import get_inception_score_and_fid
from PIL import Image

# Set random seeds for reproducibility.
# Python's `random`, NumPy and PyTorch each keep their own generator. NumPy's
# global generator is what `np.random.choice` uses to draw the training subset,
# so it must be seeded as well or the subset changes on every run.
manualSeed = 999
# manualSeed = random.randint(1, 10000)  # uncomment to get new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)



In [None]:
# Mount Google Drive into the Colab runtime at /content/drive
# (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Extract the dataset archive into the current working directory.
# NOTE(review): presumably this produces the `head_cirumference` folder that
# `data_path` points at — verify against the archive layout.
!unzip '../head_cirumference.zip'

## Hyperparameters

In [None]:
# NOTE(review): a `!` shell command runs in its own subshell, so this `cd` does
# not change the notebook's working directory; relative paths used later (e.g.
# checkpoint_path) are still resolved from the original directory. Use the
# `%cd` magic instead if a directory change is actually intended.
!cd '../DC-GANS/results'

In [None]:
# ---------------------------------------------------------------- data ----
data_path = '/content/head_cirumference'  # root folder passed to datasets.ImageFolder
subset_size = 100                         # number of images randomly drawn for training
batch_size = 10                           # images per mini-batch
image_size = 128                          # images are resized to image_size x image_size
nc = 1                                    # image channels (grayscale)
degrees_of_rot = 10                       # max angle (degrees) for RandomRotation augmentation

# --------------------------------------------------------------- model ----
nz = 128    # length of the latent vector z fed to the generator
ngf = 128   # base feature-map width of the generator
ndf = 128   # base feature-map width of the discriminator
ngpu = 1    # number of GPUs to use; 0 selects CPU mode

# ------------------------------------------------------------ training ----
num_epochs = 1000
# NOTE: `lr` is not passed to the optimizers; they are built with the
# separate gen_lr / dis_lr values defined in the initialisation cell.
lr = 2e-4
checkpoint_path = "../results/DCGAN128/100_train.pt"

## Defining transformations

In [None]:

# Preprocessing pipeline.
# Geometric transforms (resize / flip / rotate) run on the PIL image *before*
# ToTensor + Normalize. The pixels that RandomRotation exposes are filled with 0
# (black), and Normalize then maps them to -1, the bottom of the generator's Tanh
# range. If rotation ran after Normalize, those pixels would be filled with 0 in
# normalised space, i.e. mid-grey corners in every augmented image.
train_dataset = datasets.ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Grayscale(),                          # single channel (nc = 1)
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees_of_rot),
        transforms.ToTensor(),                           # -> [0, 1]
        transforms.Normalize((0.5,), (0.5,)),            # -> [-1, 1]
    ]),
)

# Train on `subset_size` distinct images drawn at random (without replacement).
subset_indices = np.random.choice(len(train_dataset), subset_size, replace=False)
train_data_subset = torch.utils.data.Subset(train_dataset, subset_indices)

dataloader = torch.utils.data.DataLoader(
    train_data_subset,
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
# Number of mini-batches per epoch (no drop_last is set, so the last batch
# may be smaller than batch_size).
num_batches = len(dataloader)
print("Number of batches: ",num_batches)

## Display image

In [None]:
%matplotlib inline 
from matplotlib import pyplot as plt

for x,_ in dataloader:
    plt.imshow(x.numpy()[0][0], cmap='gray')
    print(x.numpy()[0][0].shape)
    print(x.shape)
    break

## Generator

In [None]:
def weights_init(model):
    """Initialise one module's parameters in place, DCGAN-style.

    Intended to be passed to ``nn.Module.apply``, which calls it once for every
    submodule:

    * layers whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02)
    * layers whose class name contains ``BatchNorm``: weight ~ N(1, 0.02),
      bias = 0
    * all other modules are left untouched.

    Args:
        model: a single ``nn.Module`` (one submodule of the network).
    """
    # 0.02 is the standard deviation used by the DCGAN reference initialisation
    # (and stated where this function is applied); 0.002 made the initial
    # weights ten times too small.
    std = 0.02
    classname = model.__class__.__name__
    weight = getattr(model, 'weight', None)
    bias = getattr(model, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, std)
    elif 'BatchNorm' in classname:
        # affine=False BatchNorm layers have no weight/bias to initialise.
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, std)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
class Generator(nn.Module):
    """DCGAN generator mapping latent vectors to images.

    For an input of shape (N, nz, 1, 1) the output has shape
    (N, nc, 128, 128). The first transposed convolution projects the latent
    vector to a 4x4 map with ngf*16 channels. Each following stride-2 stage
    doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64 -> 128) while halving the
    channel width down to ngf. A final Tanh bounds pixel values to [-1, 1].

    The module-level globals ``nz``, ``ngf`` and ``nc`` are read when an
    instance is constructed.

    Args:
        ngpu: number of GPUs; stored on the instance but not used by forward.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Channel width after each stage: ngf*16, ngf*8, ngf*4, ngf*2, ngf.
        widths = [ngf * mult for mult in (16, 8, 4, 2, 1)]

        # 1x1 latent -> 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # 64x64 -> 128x128 image with nc channels, squashed into [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        images = self.main(input)
        return images

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator scoring images as real or fake.

    For an input of shape (N, nc, 128, 128) the output has shape (N, 1, 1, 1)
    and holds Sigmoid probabilities. Five stride-2 convolutions halve the
    spatial size (128 -> 64 -> 32 -> 16 -> 8 -> 4) while widening the channels
    from ndf to ndf*16. A final 4x4 convolution then reduces the map to a
    single score. The first stage has no BatchNorm.

    The module-level globals ``nc`` and ``ndf`` are read when an instance is
    constructed.

    Args:
        ngpu: number of GPUs; stored on the instance but not used by forward.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Channel width after each downsampling stage: ndf, ndf*2, ..., ndf*16.
        widths = [ndf * mult for mult in (1, 2, 4, 8, 16)]

        # 128x128 -> 64x64, no normalisation on the first stage.
        layers = [
            nn.Conv2d(nc, widths[0], 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # 4x4 -> 1x1 probability.
        layers += [
            nn.Conv2d(widths[-1], 1, 4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        scores = self.main(input)
        return scores

## Initialise network

In [None]:
def save_checkpoint(state, output_file='../results/DCGAN128/100_training.pt'):
    """Save a checkpoint dict containing the model state to ``output_file``.

    Args:
        state: dict with at least an ``'epoch'`` key. The training loop stores
            the next epoch to run (``epoch + 1``), so ``epoch - 1`` is the
            epoch that just finished, and that is the value printed.
        output_file: destination path. Missing parent directories are created
            first, because ``torch.save`` fails if the directory does not exist.
    """
    print("Saving checkpoint at epoch : ", state['epoch']-1)
    parent_dir = os.path.dirname(output_file)
    if parent_dir:  # a bare filename has no directory to create
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, output_file)

In [None]:
# Create the generator and the discriminator
netG = Generator(ngpu).to(device)
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply weights_init to every submodule (random normal init of the
# convolution and BatchNorm layers).
netG.apply(weights_init)
netD.apply(weights_init)

# Print the model
print(netG)
print(netD)

# Binary cross-entropy on the discriminator's Sigmoid outputs
criterion = nn.BCELoss()

# Fixed batch of 64 latent vectors, reused to visualise the progression
# of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Training targets: real images use a smoothed label (0.9 instead of 1),
# fake images use 0
real_label = 0.9
fake_label = 0.

# Optimiser hyperparameters (separate learning rate and weight decay per network)
gen_lr = 2e-3
dis_lr = 2e-4

wd_gen = 1e-2
wd_dis = 1e-1

beta1 = 0.5

# Setup Adam optimizers for both G and D, each with its own weight decay.
# (The discriminator previously received wd_gen, leaving wd_dis unused.)
optimizerG = optim.Adam(netG.parameters(), lr=gen_lr, betas=(beta1, 0.999), weight_decay=wd_gen)
optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999), weight_decay=wd_dis)



## Training

In [None]:
# Training Loop
# Set `load_checkpoint = True` to resume training from `checkpoint_path`.
last_epoch = 0
load_checkpoint = False

# Lists to keep track of progress (per-epoch means; img_list holds image grids)
img_list = []
G_losses = []
D_losses = []
D_acc_real = []
D_acc_gen = []


def _binary_accuracy(probs, is_real):
    """Return the fraction of discriminator outputs classified correctly.

    An output counts as a "real" prediction when it is above 0.5. ``is_real``
    is the hard ground truth for the whole batch (True for real images, False
    for generated ones). This is computed directly with torch, so it does not
    depend on the torchmetrics ``Accuracy`` constructor, which requires a
    ``task`` argument in newer releases.
    """
    return ((probs > 0.5) == is_real).float().mean().item()


if load_checkpoint:
    print(f"loading checkpoint {checkpoint_path}")
    # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    last_epoch = checkpoint['epoch']
    netG.load_state_dict(checkpoint['modelG_state_dict'])
    netD.load_state_dict(checkpoint['modelD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    G_losses = checkpoint['G_losses']
    D_losses = checkpoint['D_loss']
    D_acc_real = checkpoint['D_acc_real']
    D_acc_gen = checkpoint['D_acc_gen']

netG.train()
netD.train()

# Sample images and milestone checkpoints are written next to the rolling checkpoint.
results_dir = os.path.dirname(checkpoint_path) or '.'

iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(last_epoch,num_epochs):
    G_losses_tmp = []
    D_losses_tmp = []
    d_acc_real_tmp = []
    d_acc_gen_tmp = []

    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ######################
        # Update discriminator
        ######################
        # Set all gradients to zero
        netD.zero_grad()
        # moving real data to the training device
        real_cpu = data[0].to(device)
        # size of current batch (the last batch may be smaller)
        b_size = real_cpu.size(0)
        # smoothed "real" targets (real_label), used for the BCE loss only
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # forward pass real data through D
        output = netD(real_cpu).view(-1)
        # calculate loss on real batch
        errD_real = criterion(output, label)
        # Accuracy is measured against the hard target "real". Casting the
        # smoothed label to int (0.9 -> 0) would treat real images as fake
        # and report the complement of the true accuracy.
        d_acc_real_tmp.append(_binary_accuracy(output.detach(), True))
        # backpropagate
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # get fake images
        fake = netG(noise)
        # get fake labels
        label.fill_(fake_label)
        # forward pass with fake data (detached: this step only updates D)
        output = netD(fake.detach()).view(-1)
        # calculate loss on fake batch
        errD_fake = criterion(output, label)
        # backpropagate
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # calculate total error of discriminator
        errD = errD_real + errD_fake
        # compute accuracy on fake images (hard target "fake")
        d_acc_gen_tmp.append(_binary_accuracy(output.detach(), False))
        # update D
        optimizerD.step()

        ##################
        # Update generator
        ##################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of fake batch through D
        output = netD(fake).view(-1)
        # calculate generator loss
        errG = criterion(output, label)
        # backpropagate
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            with torch.no_grad():
                # Generator output is (b_size, nc, image_size, image_size);
                # show the first channel of the first sample.
                sample = netG(noise)[0, 0].cpu().numpy()
            plt.imshow(sample, interpolation='nearest', cmap='gray')
            # Save before show(): show() releases the figure, so a savefig
            # afterwards would write an empty image.
            os.makedirs(results_dir, exist_ok=True)
            plt.savefig(os.path.join(results_dir, f'sample_epoch{epoch}_iter{i}.png'))
            plt.show()

        # Save losses for plotting later
        G_losses_tmp.append(errG.item())
        D_losses_tmp.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

    # Save per-epoch mean accuracy of the discriminator
    D_acc_real.append(torch.mean(torch.Tensor(d_acc_real_tmp)))
    D_acc_gen.append(torch.mean(torch.Tensor(d_acc_gen_tmp)))

    # Save per-epoch mean losses for plotting later
    G_losses.append(torch.mean(torch.Tensor(G_losses_tmp)))
    D_losses.append(torch.mean(torch.Tensor(D_losses_tmp)))

    print(f'[{epoch}/{num_epochs}] G_loss: {G_losses[-1]}, D_loss: {D_losses[-1]}')

    # save checkpoint every 5 epochs; 'epoch' stores the next epoch to run
    if (epoch+1) % 5 == 0:
        checkpoint = {
          'epoch': epoch+1,
          'modelG_state_dict': netG.state_dict(),
          'modelD_state_dict': netD.state_dict(),
          'optimizerG_state_dict': optimizerG.state_dict(),
          'optimizerD_state_dict': optimizerD.state_dict(),
          'G_losses': G_losses,
          'D_loss': D_losses,
          'D_acc_real': D_acc_real,
          'D_acc_gen': D_acc_gen
        }
        save_checkpoint(checkpoint, checkpoint_path)
        if epoch+1 in [300, 500, 800, 1000]:
            # Milestone snapshots go to this model's results folder and are
            # named after the actual training-subset size.
            checkpoint_filename = os.path.join(
                results_dir, f'{subset_size}_train_subset_{epoch+1}_epochs.pt')
            save_checkpoint(checkpoint, checkpoint_filename)


## Plotting figures

In [None]:
# Plot the per-epoch mean losses recorded by the training loop
# (one point per epoch, not per iteration).
plt.figure(figsize=(9,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="Generator")
plt.plot(D_losses,label="Discriminator")
plt.xlabel("Epoch")
plt.ylabel("BCE Loss")
plt.legend()
plt.show()

In [None]:
# Plot the discriminator's per-epoch mean accuracy on real and generated
# images (one point per epoch, not per iteration).
plt.figure(figsize=(9,5))
plt.title("Accuracy of Discriminator")
plt.plot(D_acc_gen,label="Accuracy on generated images")
plt.plot(D_acc_real,label="Accuracy on real images")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()