<h1 style="font-size: 24px;">Example Code: Least-squares Generative Adversarial Network to generate microbubble signals</h1>
<h1 style="font-size: 14px;">Modified from the PyTorch tutorial 'beginner_source/dcgan_faces_tutorial.py'</h1>




In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import HTML
import matplotlib.animation as animation

In [None]:
# ------ Configuration ------

def initialize(seed=0):
    """Seed every RNG this notebook touches and choose the compute device.

    Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all CUDA
    devices when CUDA is available) are all seeded with the same ``seed`` so a
    run can be repeated.

    Args:
        seed: Integer seed shared by all random number generators.

    Returns:
        ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``.
    """
    # Free any memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    random.seed(seed)
    # NumPy is imported by this notebook, so seed it too for full repeatability.
    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        device = 'cuda:0'
        print("CUDA is available, Running on GPU.")
    else:
        device = 'cpu'
        print("CUDA is not available, Running on CPU")

    return device


device = initialize(seed=0)

In [None]:
dataroot = "data/PSF"  # root directory handed to dset.ImageFolder

# --- data loading ---
workers = 2        # DataLoader worker processes
batch_size = 64
image_size = 64    # size passed to transforms.Resize

# --- model dimensions ---
nc = 3    # channels per image
nz = 12   # length of the latent vector z
ngf = 64  # base feature-map width of the generator
ndf = 64  # base feature-map width of the discriminator

# --- optimization ---
num_epochs = 500
lr = 1e-4
beta1 = 0.5   # Adam beta1
ngpu = 1      # passed to the Generator / Discriminator constructors

In [None]:
# Training images are read from `dataroot` with torchvision's ImageFolder.
# Resize(int) only rescales the shorter edge, so CenterCrop is required to
# guarantee the square image_size x image_size input both networks are built
# for (G emits 64x64; D's final 4x4 convolution assumes a 64x64 input).
# NOTE(review): ToTensor yields values in [0, 1] while the generator ends in
# Tanh ([-1, 1]); the reference DCGAN tutorial adds
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) here -- confirm whether [0, 1]
# data is intentional for downstream use of the generator.
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                           ]))

In [None]:
# ------ Dataloader ------

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Preview up to 64 images from one shuffled training batch.
# make_grid works on CPU tensors, so the batch is not copied to `device` and back.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=0, normalize=True), (1, 2, 0)))

In [None]:
# ------ Custom weight for G and D ------

def weights_init(m):
    """DCGAN-style initializer, intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` (this includes
    ``ConvTranspose2d``) get weights drawn from N(0, 0.02). Modules whose class
    name contains ``BatchNorm`` get weights from N(1, 0.02) and a zero bias.
    Every other module is left unchanged.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias.data)

In [None]:
# ------ Generator ------

class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to 64x64 images.

    Five transposed convolutions grow the spatial size 1 -> 4 -> 8 -> 16 ->
    32 -> 64, with BatchNorm + ReLU in between. A final ``Tanh`` bounds every
    output value to [-1, 1].

    Args:
        ngpu: Number of GPUs; stored as ``self.ngpu`` but not used by ``forward``.
        latent_dim: Length of the latent vector z. ``None`` (default) uses the
            notebook-level ``nz``.
        feature_maps: Base feature-map width; the hidden layers use 8x, 4x, 2x
            and 1x this value. ``None`` (default) uses the notebook-level ``ngf``.
        channels: Number of output image channels. ``None`` (default) uses the
            notebook-level ``nc``.
    """

    def __init__(self, ngpu, latent_dim=None, feature_maps=None, channels=None):
        super().__init__()
        # Resolve defaults when the model is built (not when the class is
        # defined), so a re-run configuration cell is still honoured.
        latent_dim = nz if latent_dim is None else latent_dim
        feature_maps = ngf if feature_maps is None else feature_maps
        channels = nc if channels is None else channels

        self.ngpu = ngpu
        # Keep this layer order: it determines the "main.<idx>" keys in saved
        # state_dicts.
        self.main = nn.Sequential(
            # (N, latent_dim, 1, 1) -> (N, 8*fm, 4, 4)
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # -> (N, 4*fm, 8, 8)
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # -> (N, 2*fm, 16, 16)
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # -> (N, fm, 32, 32)
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # -> (N, channels, 64, 64), values in [-1, 1]
            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate images from a latent batch shaped ``(N, latent_dim, 1, 1)``."""
        return self.main(input)
    
# Build the generator on `device`, apply the custom weights_init, and show its layout.
netG = Generator(ngpu).to(device).apply(weights_init)
print(netG)

In [None]:
# ------ Discriminator ------

class Discriminator(nn.Module):
    """DCGAN-style critic producing one raw score per image.

    Four stride-2 convolutions shrink a 64x64 input to 4x4 (64 -> 32 -> 16 ->
    8 -> 4), with LeakyReLU(0.2) and, after the first layer, BatchNorm. A final
    4x4 convolution then reduces it to a ``(N, 1, 1, 1)`` score. No sigmoid is
    applied, so scores are unbounded; this notebook regresses them onto the
    real/fake labels with ``nn.MSELoss`` (least-squares GAN).

    Args:
        ngpu: Number of GPUs; stored as ``self.ngpu`` but not used by ``forward``.
        channels: Number of input image channels. ``None`` (default) uses the
            notebook-level ``nc``.
        feature_maps: Base feature-map width; layers use 1x, 2x, 4x and 8x this
            value. ``None`` (default) uses the notebook-level ``ndf``.
    """

    def __init__(self, ngpu, channels=None, feature_maps=None):
        super().__init__()
        # Resolve defaults when the model is built, so a re-run configuration
        # cell is still honoured.
        channels = nc if channels is None else channels
        feature_maps = ndf if feature_maps is None else feature_maps

        self.ngpu = ngpu
        # Keep this layer order: it determines the "main.<idx>" keys in saved
        # state_dicts.
        self.main = nn.Sequential(
            # (N, channels, 64, 64) -> (N, fm, 32, 32)
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 2*fm, 16, 16)
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 4*fm, 8, 8)
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 8*fm, 4, 4)
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1): raw least-squares score
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        """Score a batch of images shaped ``(N, channels, 64, 64)``."""
        return self.main(input)
    
# Build the discriminator on `device`, apply the custom weights_init, and show its layout.
netD = Discriminator(ngpu).to(device).apply(weights_init)
print(netD)

In [None]:
# Least-squares GAN: the discriminator's raw scores are regressed onto the
# target labels with mean squared error.
criterion = nn.MSELoss()

# 64 latent vectors held fixed for the whole run, used to visualize G's progress.
fixed_noise = torch.randn((64, nz, 1, 1), device=device)

# Regression targets for real and generated samples.
real_label, fake_label = 1.0, 0.0

# One Adam optimizer per network with identical hyperparameters (D built first, then G).
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netD, netG)
)

In [None]:
# ------ Training ------

img_list = []  # CPU image grids of G(fixed_noise), captured every 1000 iterations and on the final one
G_losses = []  # generator loss, one entry per iteration
D_losses = []  # discriminator loss (real + fake terms), one entry per iteration
iters = 0

# torch.save fails when the parent directory does not exist, which would abort
# the run at the end of epoch 0; create the checkpoint folder up front.
os.makedirs('model', exist_ok=True)


def _save_checkpoint(path):
    """Write both networks' weights and both optimizer states to ``path``."""
    torch.save({
        'generator': netG.state_dict(),
        'discriminator': netD.state_dict(),
        'optimizerG': optimizerG.state_dict(),
        'optimizerD': optimizerD.state_dict(),
    }, path)


for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        # ------ Update Discriminator ------
        # LSGAN objective for D: 0.5*(D(x) - 1)^2 + 0.5*(D(G(z)) - 0)^2

        netD.zero_grad()
        real_cpu = data[0].to(device)  # image batch (on `device`, despite the name)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        # train D with real
        output = netD(real_cpu).view(-1)
        errD_real = 0.5 * criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean raw score on real images

        # train D with fake; detach() keeps this step from backpropagating into G
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)

        output = netD(fake.detach()).view(-1)
        errD_fake = 0.5 * criterion(output, label)
        errD_fake.backward()  # accumulates on top of the real-batch gradients
        D_G_z1 = output.mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()

        # ------ Update Generator ------
        # LSGAN objective for G: 0.5*(D(G(z)) - 1)^2, scored by the just-updated D.
        # errG.backward() also leaves gradients on netD; netD.zero_grad() at the
        # start of the next iteration clears them.

        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = 0.5 * criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if i % 500 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Snapshot G(fixed_noise) every 1000 iterations and on the very last iteration.
        if (iters % 1000 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

    # NOTE(review): with num_epochs = 500 this is only true for epoch 0, so the
    # sole periodic checkpoint is taken after the first epoch -- confirm intent.
    if epoch % 500 == 0:
        _save_checkpoint('model/model_epoch_{}.pth'.format(epoch))

_save_checkpoint('model/model_final.pth')



In [None]:
# Per-iteration loss curves recorded during training.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
for losses, series_name in ((G_losses, "Generator"), (D_losses, "Discriminator")):
    plt.plot(losses, label=series_name)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Batch of real images from the dataloader
real_batch = next(iter(dataloader))

plt.figure(figsize=(15,15))

# Left: up to 64 real images. make_grid works on CPU tensors, so the batch is
# not copied to `device` and back.
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=5, normalize=True), (1, 2, 0)))

# Right: the most recent G(fixed_noise) grid, i.e. the last entry of img_list,
# which the training loop appends on its final iteration.
# (img_list[-2] would show an earlier snapshot, or the untrained one for short runs.)
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()