# DCGAN
Para este notebook se van a crear dos DCGAN, una entrenada con imágenes benignas
y otra DCGAN con imágenes malignas.

In [None]:
%pip install -r requirements.txt --quiet
%pip install -r dev-requirements.txt --quiet

In [1]:
import file_utils

In [2]:
# Root folder of the BreaKHis histology slides; later cells pass it to the dataset class.
path = "./BreaKHis_v1/histology_slides/"
# Image paths as returned by the project helper (return structure is not visible from this notebook).
images_path = file_utils.get_image_path(path)

In [5]:
from torchvision import datasets, transforms
from breakhis import BreastHistology

# Target size (in pixels) used by the resize and crop steps below.
image_size = 64
# Number of GPUs requested; later cells use it to pick the device and to enable DataParallel.
ngpu = 1

# Preprocessing pipeline applied to every training image.
data_transform = transforms.Compose([
    # Resize first, since some images have a lower resolution than others
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # Per-channel normalisation with mean 0.5 and std 0.5 for the three channels
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Exportar las imágenes originales
Solo correr la siguiente celda en caso de querer generar las imágenes en
un directorio nuevo, esto para posteriormente generar las estadísticas necesarias
para calcular el FID.

In [None]:
import torchvision.utils as vutils
import os

# Export pipeline: same resize as the training transform (some images have a
# lower resolution), but without the centre crop or the normalisation.
save_image_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

breast_histology_slides = BreastHistology(
    root_dir=path,
    transform=save_image_transform,
)

os.makedirs("original_images", exist_ok=True)

# Write every dataset sample to disk as original_images/image_<index>.png.
# Element 0 of each sample is the one saved as the image.
for index, image in enumerate(breast_histology_slides):
    filename = f"image_{index}.png"
    target = os.path.join("original_images", filename)
    vutils.save_image(image[0], target)

# Calcular FID score e IS
A continuación se calculan ambas métricas. Se supone que ambas métricas ayudan a calificar
el desempeño que tiene una GAN. [paper](https://arxiv.org/abs/1911.07023)

## FID Score


## Inception Score

In [None]:
!python3 score_infinity.py --path "./original_images/" --out_path "./breakhis_statics_b.npz"

In [None]:
# Dataset used for training: the histology slides with the training preprocessing applied.
breast_histology_slides = BreastHistology(root_dir=path,
                                          transform=data_transform)

# Display the number of samples in the dataset.
len(breast_histology_slides)

In [None]:
from torch.utils.data import DataLoader

# Number of images per mini-batch.
batch_size = 128

# Shuffled mini-batch loader over the training dataset, using 2 worker processes.
breast_data = DataLoader(dataset=breast_histology_slides,
                         batch_size=batch_size,
                         num_workers=2,
                         shuffle=True)

In [None]:
import torch

# Use the first CUDA device when CUDA is available and at least one GPU was
# requested through ngpu; otherwise fall back to the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Imported here as well: the only other import of vutils lives in the optional
# image-export cell, so a run that skips that cell would otherwise fail with a
# NameError on the make_grid call below.
import torchvision.utils as vutils

# Plot some training images: take one batch and show its first 64 images as a grid
real_batch = next(iter(breast_data))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns a channels-first tensor; transpose to (H, W, C) for imshow
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()

In [None]:
import torch.nn as nn

# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialise one layer in place; intended to be passed to ``Module.apply``.

    The layer type is matched on its class name:

    * name contains ``Conv``: weights are drawn from a normal distribution
      with mean 0.0 and std 0.02 (the bias is not touched);
    * otherwise, name contains ``BatchNorm``: weights are drawn from a normal
      distribution with mean 1.0 and std 0.02, and the bias is set to 0;
    * any other layer is left unchanged.
    """
    layer_name = type(m).__name__

    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input).
# The value in use is 512; the original comment mentioned 100, presumably an
# earlier or reference setting.
nz = 512

# Size of feature maps in generator
ngf = 64

In [None]:
from generator import Generator

# Create the generator from the project module and move it to the selected device
netG = Generator(ngpu, nc, nz, ngf).to(device)

# Handle multi-gpu if desired (only wraps the model when running on CUDA with ngpu > 1)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

In [1]:
from discriminator import Discriminator

# Size of feature maps in discriminator
ndf = 64

# Create the Discriminator from the project module and move it to the selected device.
# This cell needs ngpu, nc, device, nn and weights_init from the earlier cells.
netD = Discriminator(ngpu, ndf, nc).to(device)

# Handle multi-gpu if desired (only wraps the model when running on CUDA with ngpu > 1)
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)

NameError: name 'ngpu' is not defined

In [None]:
import torch.optim as optim

# Learning rate for optimizers (shared by both networks)
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Initialize BCELoss function; both the discriminator and generator losses use it
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator (64 vectors of size nz)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D, each over its own network's parameters
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Training cell: alternates one discriminator update and one generator update
# per mini-batch, recording both losses and periodic image grids produced from
# fixed_noise.

# Lists to keep track of progress
img_list = []   # image grids generated from fixed_noise
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss, one entry per iteration
iters = 0       # global iteration counter across all epochs

num_epochs = 300         

# Put both networks in training mode
netD.train()
netG.train()

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(breast_data, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch (element 0 of each loader item is the image tensor used here)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # One label per sample; this tensor is refilled in place for the fake and generator passes
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        # Mean discriminator output on the real batch (reported as D(x))
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() keeps this backward pass out of G
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        # Mean discriminator output on the fake batch before D's update
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        # Mean discriminator output on the fake batch after D's update
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        
        # Output training stats every 50 batches
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(breast_data),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations and on the very last batch of the last epoch)
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(breast_data)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1

In [None]:
# Loss curves of both networks, one point per recorded training iteration.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
for losses, tag in ((G_losses, "G"), (D_losses, "D")):
    plt.plot(losses, label=tag)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Animate the grids stored in img_list during training, save the animation as
# a GIF and embed it in the notebook.
import matplotlib
import matplotlib.animation as animation

# Very large embed limit so the embedded animation below is not cut off by the
# size cap. NOTE(review): value looks arbitrary; confirm the unit and behaviour
# in the matplotlib rcParams documentation.
matplotlib.rcParams['animation.embed_limit'] = 2**128
from IPython.display import HTML


fig = plt.figure(figsize=(8,8))
plt.axis("off")
# One frame per stored grid; each grid is transposed to (H, W, C) for imshow
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Written to the current working directory
ani.save("animation.gif")

HTML(ani.to_jshtml())

In [None]:
# Side-by-side comparison: a batch of real images next to the last grid of
# generated images stored during training.
real_batch = next(iter(breast_data))
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu()

plt.figure(figsize=(15,15))

# Left panel: first 64 real images
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid,(1,2,0)))

# Right panel: most recent entry of img_list
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))

plt.show()

In [None]:
# Persist the trained weights (state dicts only) of both networks in the working directory.
torch.save(netD.state_dict(), "breakhis_GAN_Discriminator_512dim_benign.pt")
torch.save(netG.state_dict(), "breakhis_GAN_Generator_512dim_benign.pt")

In [None]:
# Sample a fresh batch of latent vectors z
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# Run the noise through the generator. This is inference only, so gradient
# tracking is disabled, as the training cell does for its progress snapshots.
with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()

# Show the first 64 generated images as a grid
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Imágenes generadas")
# make_grid returns a channels-first tensor; transpose to (H, W, C) for imshow
plt.imshow(np.transpose(vutils.make_grid(fake[:64], padding=2, normalize=True).cpu(),(1,2,0)))

In [None]:
# Imported locally: the only other import of os is in the optional image-export cell.
import os

# Folder where the generated images are written
dirname = "benign_fake_images"
os.makedirs(dirname, exist_ok=True)

# Not used in this cell; kept in case other cells reference it.
images = []

num_batches = 200       # number of generator batches to sample
images_per_batch = 64   # images saved from each batch (the first ones)

for i in range(num_batches):
    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()

    # Slicing also copes with a batch smaller than images_per_batch
    for j, image in enumerate(fake[:images_per_batch]):
        # The separator keeps (i, j) pairs distinct: without it, i=1, j=11 and
        # i=11, j=1 both produced "image_111.png" and overwrote each other.
        # Files written by earlier runs under the old naming are not removed.
        filename = f"image_{i}_{j}.png"

        vutils.save_image(image, os.path.join(dirname, filename), normalize=True)

In [None]:
from score_infinity import calculate_FID_infinity_path, calculate_IS_infinity_path

# FID: compares the generated images in dirname against the statistics file
# written by the score_infinity.py command earlier in the notebook.
FID_infinity = calculate_FID_infinity_path('breakhis_statics_b.npz', dirname, batch_size)
# IS: computed from the generated images in dirname only.
# NOTE(review): neither value is displayed by this cell; confirm whether the helpers print their results.
IS_infinity = calculate_IS_infinity_path(dirname, batch_size)