# Spectral Normalization:

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import time

import torch
import torchvision
import torch.nn as nn
# import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, models, transforms

In [2]:
! git clone https://github.com/nanopiero/ML_S5

fatal: destination path 'ML_S5' already exists and is not an empty directory.


Let's first define an image generation problem. The following function samples the random image $X$ and the random vector $Z$:

In [3]:
! ls

ML_S5  sample_data


In [2]:
! cp ML_S5/practicals/P5/* .
from utils_P5 import gen_DCGAN, voir_batch2D

In [None]:
# Rectangle proportion in the image (0.0 here; a later cell uses 0.00025):
lambda_rec = 0.0

# Draw a batch of 6 (x, z) pairs with the course helper.
x, z = gen_DCGAN(6, lambda_rec=lambda_rec)

# Clean versions (individual cells): the x images, shown in figure 1.
# `fig1` is reused by later cells.
fig1 = plt.figure(1, figsize=(36, 6))
voir_batch2D(x, 6, fig1, k=0, min_scale=0, max_scale=1)

# The matching z tensors, shown in figure 3.
fig3 = plt.figure(3, figsize=(36, 6))
voir_batch2D(z, 6, fig3, k=0, min_scale=0, max_scale=1)

**Q1** Instantiate a UNet.

In [5]:
# Constructor arguments of the course UNet (one input channel, one output
# channel; the meaning of `size` is defined in utils_P5 -- TODO confirm).
n_channels, n_classes, size = 1, 1, 16
from utils_P5 import UNet
# Generator network, moved to the GPU.
netG = UNet(n_channels, n_classes, size).cuda()

**Q2** The `Discriminator` class below encodes the discriminator. Instantiate it and use the *weights_init* function to initialize the network's weights. What type of network do you obtain in this way?

In [None]:
# Base number of discriminator feature maps and number of image channels.
# Both are read when a Discriminator is constructed.
ndf = 32
nc = 1


class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator with a sigmoid output.

    The stack is four 4x4 stride-2 convolutions (each followed by a
    LeakyReLU(0.2)), then one 4x4 stride-1 unpadded convolution and a sigmoid.
    Only the second convolution is followed by batch normalisation. None of
    the convolutions has a bias.
    """

    def __init__(self):
        super().__init__()
        channels = [nc, ndf, ndf * 2, ndf * 4, ndf * 8]

        stack = []
        for depth, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2,
                                   padding=1, bias=False))
            if depth == 1:
                # Batch norm is only present after the second convolution.
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Final unpadded convolution collapses to one score map, then sigmoid.
        stack.append(nn.Conv2d(channels[-1], 1, kernel_size=4, stride=1,
                               padding=0, bias=False))
        stack.append(nn.Sigmoid())

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return the sigmoid score(s) produced by the convolutional stack."""
        scores = self.main(input)
        return scores

def weights_init(m):
    """DCGAN weight initialisation, intended for ``net.apply(weights_init)``.

    Modules are selected by class name:

    * names containing ``'Conv'``: weights drawn from N(0, 0.02);
    * names containing ``'BatchNorm'``: weights drawn from N(1, 0.02) and
      biases set to 0.

    Layers wrapped with ``torch.nn.utils.spectral_norm`` keep their trainable
    tensor in ``weight_orig`` and recompute ``weight`` from it at each
    forward pass, so ``weight_orig`` is initialised when it exists.
    Modules whose name matches but that carry no weight tensor (for example a
    container class whose name contains 'Conv') are left untouched instead of
    raising ``AttributeError``.

    Args:
        m: any ``nn.Module``; modules that match neither pattern are ignored.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Prefer the underlying parameter of a spectral-norm wrapped layer.
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = getattr(m, 'weight', None)
        if weight is None:
            return
        nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # With affine=False both tensors are None: nothing to initialise.
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0)

# Build the discriminator on the GPU, then run weights_init on each of its
# sub-modules (Module.apply visits them recursively).
netD = Discriminator().cuda()
netD.apply(weights_init)

Let's now specify some training parameters (most of them are standard for GANs):

In [7]:
# Fixing the seed (to reproduce results).
# NOTE(review): netG and netD are built in earlier cells, before this seed is set.
manualSeed = 1
torch.manual_seed(manualSeed)

# Number of parallel processes (not referenced by the cells visible here):
workers = 2

# Image size (not referenced by the cells visible here)
image_size = 64

# Number of channels
nc = 1

# Batch size
batch_size = 64

# Number of batches per epoch, and number of epochs
num_batches = 200
num_epochs = 64

# Learning rate
lr = 0.0002

# Beta1 hyperparameter for Adam
beta1 = 0.5  # Sometimes simply 0.

# Number of GPUs (not referenced by the cells visible here)
ngpu = 1

# Binary cross-entropy on probabilities (the discriminator ends with a sigmoid)
criterion = nn.BCELoss()

# Labels for real and fake images
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [8]:
# To observe how G(z) evolves with z fixed along the training:
# only the z output of the sampler is kept, and it is moved to the GPU.
_ , fixed_z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
fixed_z = fixed_z.cuda()

**Q3** Run the code below, then comment on it:

In [11]:
# to see the convergence :
def _radial_profile(spectrum):
    """Average a centred spectrum over rings of integer radius.

    Args:
        spectrum: tensor indexed as ``[channel, row, col]``; only channel 0 is
            used. The zero frequency is expected at ``(rows // 2, cols // 2)``,
            i.e. after ``fftshift``.

    Returns:
        list of floats: entry ``r`` is the mean of the values whose distance
        ``d`` to the centre satisfies ``r <= d < r + 1``, for
        ``r < min(rows // 2, cols // 2)``.
    """
    rows, cols = spectrum.shape[1:]
    center_row, center_col = rows // 2, cols // 2

    # Distance of every pixel to the centre, built by broadcasting instead of
    # a Python double loop.
    dy = torch.arange(rows, dtype=torch.float64, device=spectrum.device) - center_row
    dx = torch.arange(cols, dtype=torch.float64, device=spectrum.device) - center_col
    distances = torch.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2)

    profile = []
    for r in range(min(center_row, center_col)):
        ring = (distances >= r) & (distances < r + 1)
        profile.append(torch.mean(spectrum[0, ring]).item())
    return profile


def plot_PSD_and_Frechet_PSD(list_magnitude_spectrum):
    """Plot radially averaged mean and std magnitude spectra on a log scale.

    Args:
        list_magnitude_spectrum: sequence of pairs ``(mean, std)`` of centred
            magnitude spectra indexed as ``[channel, row, col]``. The first
            entry is drawn in green (mean) and yellow (std) and labelled
            "fake"; the following entries are drawn in shades of blue (mean)
            and red (std), and the first of them is labelled "real". This
            matches the order used by the training loop, which passes the
            current generated spectrum followed by the stored real spectra.
    """
    for i, magnitude_spectrum in enumerate(list_magnitude_spectrum):
        if i != 0:
            shade = 1 - 1 / (i + 2)
            mean_color, std_color = [0, 0, shade], [shade, 0, 0]
        else:
            mean_color, std_color = 'green', 'yellow'

        # Only the first two entries get a legend label (same legend as when
        # the four labels were passed positionally).
        if i == 0:
            mean_label, std_label = 'fake (av.)', 'fake (std)'
        elif i == 1:
            mean_label, std_label = 'real (av.)', 'real (std)'
        else:
            mean_label, std_label = None, None

        plt.plot(_radial_profile(magnitude_spectrum[0]),
                 color=mean_color, label=mean_label)
        plt.plot(_radial_profile(magnitude_spectrum[1]),
                 color=std_color, label=std_label)

    plt.xlabel('Radial Distance from Center')
    plt.ylabel('Average Magnitude')
    plt.legend()
    plt.yscale('log')
    plt.title('Radially Averaged Fourier Spectrum (PSD)')
    plt.show()

In [None]:
# Lists for stats
img_list = []
G_losses = []
D_losses = []
iters = 0
list_mean_magnitude_spectrum_real = []
list_std_magnitude_spectrum_real = []

# Spectral statistics are averaged over windows of `spectrum_window` epochs.
# The running sums of |FFT| and |FFT|^2 are reset explicitly at the start of
# each window. The previous version relied on a NameError caught by a bare
# `except`, whose branch also forgot the absolute value and therefore made
# the accumulators complex.
spectrum_window = 4
spectrum_count = 0
sum_mag_real = sum_mag_fake = None
sum_sq_real = sum_sq_fake = None

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i in range(num_batches):

        # Sampling X (real images) and Z (generator input), moved to the GPU
        x, z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
        x = x.cuda()
        z = z.cuda()

        # STEP 1: Discriminator optimization
        netD.zero_grad()

        # Real batch: loss against the "real" label
        D_real = netD(x).view(-1)
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float).cuda()
        errD_real = criterion(D_real, label)
        errD_real.backward()

        # Mean discriminator output on real images (for display only)
        D_real = D_real.mean().item()

        # Generated batch. .detach() -> no gradient flows to netG at this step
        fake = netG(z)
        D_fake = netD(fake.detach()).view(-1)
        label.fill_(fake_label)
        errD_fake = criterion(D_fake, label)
        errD_fake.backward()

        # Overall discriminator loss and weight update
        errD = errD_real + errD_fake
        optimizerD.step()

        # Mean discriminator output on generated images, before the D update
        D_fake = D_fake.mean().item()

        # STEP 2: Generator optimization
        netG.zero_grad()

        # Score the same generated batch again, this time keeping the graph
        # through netG; the generator wants it classified as real.
        D_fake2 = netD(fake).view(-1)
        label.fill_(real_label)
        errG = criterion(D_fake2, label)
        errG.backward()

        # Mean discriminator output on generated images, after the D update
        D_fake2 = D_fake2.mean().item()
        optimizerG.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     errD.item(), errG.item(), D_real, D_fake, D_fake2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Store generated images from "fixed_z" every hundred iterations
        # (and at the very last iteration)
        if (iters % 100 == 0) or ((epoch == num_epochs - 1) and (i == num_batches - 1)):
            with torch.no_grad():
                stored_fake = netG(fixed_z).detach().cpu()
            img_list.append(stored_fake)
        iters += 1

        # Accumulate per-frequency sums of |FFT| and |FFT|^2 over the window
        with torch.no_grad():
            fftx = torch.abs(torch.fft.fft2(x))
            fftfake = torch.abs(torch.fft.fft2(fake))
            if spectrum_count == 0:
                sum_mag_real = fftx.sum(dim=0)
                sum_mag_fake = fftfake.sum(dim=0)
                sum_sq_real = (fftx ** 2).sum(dim=0)
                sum_sq_fake = (fftfake ** 2).sum(dim=0)
            else:
                sum_mag_real += fftx.sum(dim=0)
                sum_mag_fake += fftfake.sum(dim=0)
                sum_sq_real += (fftx ** 2).sum(dim=0)
                sum_sq_fake += (fftfake ** 2).sum(dim=0)
            spectrum_count += x.size(0)

    # Last epoch of a window: turn the sums into mean / std and plot them
    if epoch % spectrum_window == spectrum_window - 1:
        with torch.no_grad():
            mean_real = sum_mag_real / spectrum_count
            mean_fake = sum_mag_fake / spectrum_count
            # Var = E[m^2] - E[m]^2; clamp tiny negative values caused by
            # rounding so that sqrt never returns NaN.
            std_real = torch.sqrt((sum_sq_real / spectrum_count - mean_real ** 2).clamp(min=0))
            std_fake = torch.sqrt((sum_sq_fake / spectrum_count - mean_fake ** 2).clamp(min=0))

            # Shift zero frequency to the center (spatial axes only)
            mean_real = torch.fft.fftshift(mean_real, dim=(-2, -1))
            mean_fake = torch.fft.fftshift(mean_fake, dim=(-2, -1))
            std_real = torch.fft.fftshift(std_real, dim=(-2, -1))
            std_fake = torch.fft.fftshift(std_fake, dim=(-2, -1))

            # Keep every real-image window; plot the current fake window first
            list_mean_magnitude_spectrum_real.append((mean_real.cpu(), std_real.cpu()))
            plot_PSD_and_Frechet_PSD(
                [(mean_fake.cpu(), std_fake.cpu())] + list_mean_magnitude_spectrum_real)

        # Start a fresh window
        spectrum_count = 0
        sum_mag_real = sum_mag_fake = None
        sum_sq_real = sum_sq_fake = None

**Q4** Plot the evolution of the cost functions for the generator and discriminator. Visualize the successive images.

In [None]:
# One point per training iteration: G_losses / D_losses are appended to
# once per batch in the training loop above.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
print(len(img_list))
# Draw into a figure created in this cell rather than the `fig1` handle left
# over from an earlier cell. That handle may be missing on a partial run, and
# with the inline backend the old figure is closed once its cell has rendered.
fig1 = plt.figure(figsize=(12, 12))
voir_batch2D(img_list[-1], 8, fig1, k=0, min_scale=0,max_scale=1)

In [None]:
# Testing the Generator on 6 freshly sampled inputs:
_ , z = gen_DCGAN(6, lambda_rec = lambda_rec)
z = z.cuda()
with torch.no_grad():  # inference only: no autograd graph is needed
    fake = netG(z)

# Draw into the figure created here. It used to be created as `fig` and then
# ignored in favour of the stale `fig1` from an earlier cell.
fig = plt.figure(1, figsize=(12, 4))
# NOTE(review): 14 is passed as voir_batch2D's second argument for a batch of
# 6 images; earlier cells pass 6 for such a batch -- confirm this is intended.
voir_batch2D(fake.cpu(), 14, fig, k=0, min_scale=0,max_scale=1)

The training is not yet perfect (improvement could be achieved with more epochs), but the generator manages to sample images that are roughly close to the original images. It has started to reproduce intersections between cells. \
One could verify this quantitatively by comparing classical statistics (mean per pixel, standard deviation, etc.) or even spectral densities.

**Q5** Restart training with additional rectangles on the image. Visualize and comment on the results.

In [None]:
# Rectangle proportion in the image (non-zero from here on):
lambda_rec = 0.00025

# Draw a batch of 6 (x, z) pairs with the new setting.
x , z = gen_DCGAN(6,lambda_rec = lambda_rec)

# The x images, shown in figure 1 (`fig1` is reused by later cells)
fig1 = plt.figure(1, figsize=(36, 6))
voir_batch2D(x, 6, fig1, k=0, min_scale=0,max_scale=1)


# The matching z tensors, shown in figure 3
fig3 = plt.figure(3, figsize=(36, 6))
voir_batch2D(z, 6, fig3, k=0, min_scale=0,max_scale=1)

In [None]:
manualSeed = 1
num_epochs = 20
torch.manual_seed(manualSeed)

# Fresh networks and optimizers for this run
netD = Discriminator().cuda()
netD.apply(weights_init)
netG = UNet(n_channels, n_classes, size).cuda()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Lists for stats
img_list = []
G_losses = []
D_losses = []
iters = 0
list_mean_magnitude_spectrum_real = []

# Spectral statistics are averaged over windows of `spectrum_window` epochs
# and plotted with plot_PSD_and_Frechet_PSD, as in the first training loop.
# This cell used to call `plot_mean_magnitude_spectrum_and_PSD`, which is not
# defined anywhere in the notebook. It also accumulated FFTs of `fake` with
# autograd enabled, which keeps extending the autograd graph.
spectrum_window = 4
spectrum_count = 0
sum_mag_real = sum_mag_fake = None
sum_sq_real = sum_sq_fake = None

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i in range(num_batches):
        # Sampling X (real images) and Z (generator input), moved to the GPU
        x, z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
        x = x.cuda()
        z = z.cuda()

        # STEP 1: Discriminator optimization
        netD.zero_grad()

        # Real batch: loss against the "real" label
        D_real = netD(x).view(-1)
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float).cuda()
        errD_real = criterion(D_real, label)
        errD_real.backward()

        # Mean discriminator output on real images (for display only)
        D_real = D_real.mean().item()

        # Generated batch. .detach() -> no gradient flows to netG at this step
        fake = netG(z)
        D_fake = netD(fake.detach()).view(-1)
        label.fill_(fake_label)
        errD_fake = criterion(D_fake, label)
        errD_fake.backward()

        # Overall discriminator loss and weight update
        errD = errD_real + errD_fake
        optimizerD.step()

        # Mean discriminator output on generated images, before the D update
        D_fake = D_fake.mean().item()

        # STEP 2: Generator optimization
        netG.zero_grad()
        label.fill_(real_label)

        # Score the same generated batch again, keeping the graph through netG
        D_fake2 = netD(fake).view(-1)
        errG = criterion(D_fake2, label)
        errG.backward()

        # Mean discriminator output on generated images, after the D update
        D_fake2 = D_fake2.mean().item()
        optimizerG.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     errD.item(), errG.item(), D_real, D_fake, D_fake2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Accumulate per-frequency sums of |FFT| and |FFT|^2 over the window
        with torch.no_grad():
            fftx = torch.abs(torch.fft.fft2(x))
            fftfake = torch.abs(torch.fft.fft2(fake))
            if spectrum_count == 0:
                sum_mag_real = fftx.sum(dim=0)
                sum_mag_fake = fftfake.sum(dim=0)
                sum_sq_real = (fftx ** 2).sum(dim=0)
                sum_sq_fake = (fftfake ** 2).sum(dim=0)
            else:
                sum_mag_real += fftx.sum(dim=0)
                sum_mag_fake += fftfake.sum(dim=0)
                sum_sq_real += (fftx ** 2).sum(dim=0)
                sum_sq_fake += (fftfake ** 2).sum(dim=0)
            spectrum_count += x.size(0)

        # Store generated images from "fixed_z" every hundred iterations
        # (and at the very last iteration). A separate name is used so the
        # training batch `fake` is not overwritten.
        if (iters % 100 == 0) or ((epoch == num_epochs - 1) and (i == num_batches - 1)):
            with torch.no_grad():
                stored_fake = netG(fixed_z.cuda()).detach().cpu()
            img_list.append(stored_fake)
        iters += 1

    # Last epoch of a window: turn the sums into mean / std and plot them
    if epoch % spectrum_window == spectrum_window - 1:
        with torch.no_grad():
            mean_real = sum_mag_real / spectrum_count
            mean_fake = sum_mag_fake / spectrum_count
            # Var = E[m^2] - E[m]^2, clamped at 0 against rounding errors
            std_real = torch.sqrt((sum_sq_real / spectrum_count - mean_real ** 2).clamp(min=0))
            std_fake = torch.sqrt((sum_sq_fake / spectrum_count - mean_fake ** 2).clamp(min=0))

            # Shift zero frequency to the center (spatial axes only)
            mean_real = torch.fft.fftshift(mean_real, dim=(-2, -1))
            mean_fake = torch.fft.fftshift(mean_fake, dim=(-2, -1))
            std_real = torch.fft.fftshift(std_real, dim=(-2, -1))
            std_fake = torch.fft.fftshift(std_fake, dim=(-2, -1))

            # Keep every real-image window; plot the current fake window first
            list_mean_magnitude_spectrum_real.append((mean_real.cpu(), std_real.cpu()))
            plot_PSD_and_Frechet_PSD(
                [(mean_fake.cpu(), std_fake.cpu())] + list_mean_magnitude_spectrum_real)

        # Start a fresh window
        spectrum_count = 0
        sum_mag_real = sum_mag_fake = None
        sum_sq_real = sum_sq_fake = None

In [None]:
print(len(img_list))
# Draw into a figure created in this cell rather than the `fig1` handle left
# over from an earlier cell (stale or missing on a partial run).
fig1 = plt.figure(figsize=(12, 12))
# Third stored batch from the end.
voir_batch2D(img_list[-3], 8, fig1, k=0, min_scale=0,max_scale=1)

On the generated images, a clear issue arises: the same rectangles reappear in multiple images of the batch. This problem is called 'mode collapse'. It often occurs with GANs and can complicate their training.

**Exercise n°2** Wasserstein-GANs.

To facilitate the convergence of GANs, several approaches have been explored. In particular:
- Giving the discriminator more time to converge at each step.
- Enforcing a Lipschitz constraint on the discriminator. This option is rooted in an interesting theoretical approach (see the supplementary exercise sheet). It can be done:

  * by constraining the weights of the discriminator to remain within a given interval (see the paper introducing WGANs, [Wasserstein-GANs](https://arxiv.org/abs/1701.07875)).

  * by [gradient penalization](https://arxiv.org/pdf/1704.00028.pdf)   

**Q2** In the following cells, these three approaches are coded. Say where.

In [None]:
# Number of image channels and base number of feature maps.
# Both are read when a Discriminator is constructed.
nc = 1
ndf = 32


class Discriminator(nn.Module):
    """Convolutional critic: same layout as the earlier Discriminator, but
    with no batch normalisation and no final sigmoid (raw scores are returned).

    This definition replaces the earlier class of the same name.
    """

    def __init__(self):
        super().__init__()
        channels = [nc, ndf, ndf * 2, ndf * 4, ndf * 8]

        stack = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # 4x4 stride-2 convolution followed by a leaky ReLU
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2,
                                   padding=1, bias=False))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Final unpadded 4x4 convolution producing a single score map
        stack.append(nn.Conv2d(channels[-1], 1, kernel_size=4, stride=1,
                               padding=0, bias=False))

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return the unbounded score(s) produced by the convolutional stack."""
        scores = self.main(input)
        return scores

In [None]:
def calculate_gradient_penalty(model, real_images, fake_images):
    """Gradient penalty of WGAN-GP: mean of (||grad_x model(x)||_2 - 1)^2.

    The gradient is evaluated at random points on the segments joining each
    real image to the matching generated image.

    Args:
        model: the critic; called once on a batch of interpolated images.
        real_images: batch of real images, indexed ``[batch, ...]`` (4-D).
        fake_images: batch of generated images, same shape, device and dtype.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so that it can
        be back-propagated to the critic's parameters.
    """
    # One interpolation coefficient per sample, drawn uniformly in [0, 1).
    # (A normal draw, as used before, samples points outside the segment.)
    # It is created on the inputs' device instead of a hard-coded .cuda().
    alpha = torch.rand((real_images.size(0), 1, 1, 1),
                       device=real_images.device, dtype=real_images.dtype)
    interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)

    model_interpolates = model(interpolates)
    grad_outputs = torch.ones_like(model_interpolates)

    # Gradient of the critic output w.r.t. the interpolated images
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient, penalised away from 1
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
    return gradient_penalty

In [None]:
# Fresh generator (course UNet) on the GPU.
n_channels, n_classes,size = 1, 1, 16
netG = UNet(n_channels, n_classes, size).cuda()

# Fresh critic: weights_init is applied first, then the network is moved to
# the GPU.
netD = Discriminator()
netD.apply(weights_init)
netD = netD.cuda()

In [None]:
# Proportion of rectangle in the image:
lambda_rec = 0.00025

# Fixing the seed for reproducibility:
# NOTE(review): netG and netD are built in the previous cell, before this seed is set.
manualSeed = 1
torch.manual_seed(manualSeed)

# Number of parallel processes (not referenced by the cells visible here):
workers = 2

# Image size (not referenced by the cells visible here):
image_size = 64

# Number of channels:
nc = 1

# Batch size:
batch_size = 64

# Number of generator updates per epoch, and number of epochs:
num_batches_generator = 200
num_epochs = 30

# Learning rate:
lr = 0.0001

# Beta1 hyperparameter for Adam:
beta1 = 0.  # In the paper introducing gradient penalty

# Number of GPUs (not referenced by the cells visible here):
ngpu = 1

# Cross-entropy & label conventions
# (the WGAN training loop below does not use them):
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.

# Gradient penalty (gp) or classic WGAN with weight clipping:
add_gp = True

# Setup Adam optimizers for both G and D:
# If gradient penalty:
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# If not:
# optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
# optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

# Schedulers: multiply the learning rate by `gamma` every `step_size` calls.
# The training loop calls them once per epoch, so with step_size = 31 and
# num_epochs = 30 the learning rate is never actually reduced in this run.
step_size = 31
gamma = 0.2
schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=step_size, gamma=gamma)
schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=step_size, gamma=gamma)

In [None]:
# To observe how G(z) evolves with fixed z during training:
# only the z output of the sampler is kept, and it is moved to the GPU.
_ ,  fixed_z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
fixed_z = fixed_z.cuda()

In [None]:
# Stats collected during training (one loss value per generator update;
# one batch of images per epoch).
img_list = []
G_losses = []
D_losses = []

# Number of critic updates per generator update, and clipping bound used
# for the critic weights when the gradient penalty is disabled.
n_critic = 5
clip = 0.01

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i in range(num_batches_generator):
        # netG is switched to eval mode at the end of every epoch (see below)
        netG.train()
        # (1) Update the critic n_critic times, each time on a fresh batch
        for j in range(n_critic):
            x , z = gen_DCGAN(batch_size, lambda_rec = lambda_rec)


            netD.zero_grad()
            real = x.cuda()
            output_real = netD(real)
            fake = netG(z.cuda())
            # .detach(): no gradient flows to the generator during this step
            output_fake = netD(fake.detach())

            # Critic loss: mean score on fakes minus mean score on reals,
            # plus 10 x the gradient penalty when add_gp is set
            if add_gp:
                gradient_penalty = calculate_gradient_penalty(netD,
                                                   real.data, fake.data)
                errD = output_fake.mean() - output_real.mean() + 10 * gradient_penalty

            else :
                errD = output_fake.mean() - output_real.mean()

            errD.backward()

            # Update D
            optimizerD.step()

            # Classic WGAN: clamp every critic weight into [-clip, clip]
            if not add_gp:
                for p in netD.parameters():
                    p.data.clamp_(-clip, clip)

        ############################
        # (2) Update G network: minimize -mean(D(G(z)))
        ###########################
        netG.zero_grad()
        # Since we just updated D, perform another forward pass of an all-fake
        # batch through D (z is the input of the last critic iteration)
        fake = netG(z.cuda())
        output_fake = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = - output_fake.mean()
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()

        # Output training stats (errD is the loss of the last critic update)
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f'
                  % (epoch+1, num_epochs, i, num_batches_generator,
                     errD.item()))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

    # Once per epoch: check how the generator is doing by saving its output
    # on fixed_z (eval mode, no gradients)
    with torch.no_grad():
        netG.eval()
        fake = netG(fixed_z.cuda()).detach().cpu()
    img_list.append(fake)


    # One scheduler step per epoch
    schedulerD.step()
    schedulerG.step()


**Q3** Can we still observe mode collapse in these images?

In [None]:
# One point per generator update: G_losses holds -mean(D(G(z))) and D_losses
# holds the critic loss of the last critic update before each generator update.
plt.figure(figsize=(15,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")

plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Fresh figure for this cell; img_list holds one generated batch per epoch.
fig1 = plt.figure(1)
print(len(img_list))
# Display the batch generated from fixed_z after the last epoch.
voir_batch2D(img_list[-1], 8, fig1, k=0, min_scale=0,max_scale=1)



Indeed, the outputs of the generator, while not yet perfect, show no signs of mode collapse.

**Q4** Let's finally see the results of training over several hours. Load the UNet trained for 600 epochs (*Ex2_netG_600ep_WGP_lr0001.pt*) and visualize the generated images.

In [None]:
n_channels, n_classes,size = 1, 1, 16
netG_600ep = UNet(n_channels, n_classes, size)
path_netG = "Ex2_netG_600ep_WGP_lr0001.pt"

# Load the checkpoint on the CPU first, so that loading does not depend on
# the device it was saved from, then move the model to the GPU once.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(path_netG, map_location="cpu")
netG_600ep.load_state_dict(checkpoint['model_state_dict'])
netG_600ep = netG_600ep.cuda()

In [None]:
# Inference mode for the pre-trained generator
netG_600ep.eval()


# One batch of 6 real images; it is shown first, before the generated ones
x , z = gen_DCGAN(6, lambda_rec = lambda_rec)

# Generate n batches of 6 fake images with G, each from a freshly sampled z


real_and_fakes = [x]
n = 4
for i in range(n):
    _ ,  z = gen_DCGAN(6, lambda_rec = lambda_rec)
    z = z.cuda()
    with torch.no_grad():
        fake = netG_600ep(z).cpu()
    real_and_fakes.append(fake)

# Concatenate along the batch dimension: 6 real images, then 4 x 6 generated
real_and_fakes = torch.cat(real_and_fakes,dim=0)
fig1 = plt.figure(4, figsize=(36, 6))
voir_batch2D(real_and_fakes, 6, fig1, k=0, min_scale=0, max_scale=1)

**Exercise n°3** Spectral normalization


In [28]:
import torch.nn.utils.spectral_norm as spectral_norm
# Number of image channels and base number of feature maps.
# Both are read when an SDiscriminator is constructed.
nc = 1
ndf = 32


class SDiscriminator(nn.Module):
    """Discriminator whose five convolutions are wrapped in spectral
    normalisation (eps=1e-4), with LeakyReLU(0.2) activations in between and
    a final sigmoid. There is no batch normalisation and no bias.
    """

    def __init__(self):
        super().__init__()

        def sn_conv(c_in, c_out, stride, padding):
            # 4x4 bias-free convolution under spectral normalisation
            conv = nn.Conv2d(c_in, c_out, 4, stride, padding, bias=False)
            return spectral_norm(conv, eps=1e-4)

        channels = [nc, ndf, ndf * 2, ndf * 4, ndf * 8]

        stack = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stack.append(sn_conv(c_in, c_out, 2, 1))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Final unpadded stride-1 convolution to a single score map, then sigmoid
        stack.append(sn_conv(channels[-1], 1, 1, 0))
        stack.append(nn.Sigmoid())

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return the sigmoid score(s) produced by the convolutional stack."""
        scores = self.main(input)
        return scores

In [None]:
from utils_P5 import gen_DCGAN, voir_batch2D
# Rectangle proportion in the image :
lambda_rec = 0.00025

# Draw a batch of 6 (x, z) pairs.
x, z = gen_DCGAN(6,lambda_rec = lambda_rec)

# The x images, shown in figure 1 (`fig1` is reused by later cells)
fig1 = plt.figure(1, figsize=(36, 6))
voir_batch2D(x, 6, fig1, k=0, min_scale=0,max_scale=1)


# The matching z tensors, shown in figure 3
fig3 = plt.figure(3, figsize=(36, 6))
voir_batch2D(z, 6, fig3, k=0, min_scale=0,max_scale=1)

In [31]:
# To observe how G(z) evolves with z fixed along the training:
# only the z output of the sampler is kept, and it is moved to the GPU.
# NOTE(review): batch_size comes from an earlier configuration cell.
_ , fixed_z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
fixed_z = fixed_z.cuda()

In [None]:
manualSeed = 1
num_epochs = 64
torch.manual_seed(manualSeed)

# Fresh networks and optimizers for this run. The discriminator is
# initialised before being moved to the GPU, in the same order as the WGAN cell.
# NOTE(review): lr, beta1, criterion and the labels come from the most
# recently executed configuration cell -- confirm these are the intended values.
netD = SDiscriminator()
netD.apply(weights_init)
netD = netD.cuda()
netG = UNet(n_channels, n_classes, size).cuda()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Lists for stats
img_list = []
G_losses = []
D_losses = []
iters = 0
list_mean_magnitude_spectrum_real = []

# Spectral statistics are averaged over windows of `spectrum_window` epochs
# and plotted with plot_PSD_and_Frechet_PSD, as in the first training loop.
# This cell used to call `plot_mean_magnitude_spectrum_and_PSD`, which is not
# defined anywhere in the notebook. It also accumulated FFTs of `fake` with
# autograd enabled, which keeps extending the autograd graph.
spectrum_window = 4
spectrum_count = 0
sum_mag_real = sum_mag_fake = None
sum_sq_real = sum_sq_fake = None

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i in range(num_batches):
        # Sampling X (real images) and Z (generator input), moved to the GPU
        x, z = gen_DCGAN(batch_size, lambda_rec=lambda_rec)
        x = x.cuda()
        z = z.cuda()

        # STEP 1: Discriminator optimization
        netD.zero_grad()

        # Real batch: loss against the "real" label
        D_real = netD(x).view(-1)
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float).cuda()
        errD_real = criterion(D_real, label)
        errD_real.backward()

        # Mean discriminator output on real images (for display only)
        D_real = D_real.mean().item()

        # Generated batch. .detach() -> no gradient flows to netG at this step
        fake = netG(z)
        D_fake = netD(fake.detach()).view(-1)
        label.fill_(fake_label)
        errD_fake = criterion(D_fake, label)
        errD_fake.backward()

        # Overall discriminator loss and weight update
        errD = errD_real + errD_fake
        optimizerD.step()

        # Mean discriminator output on generated images, before the D update
        D_fake = D_fake.mean().item()

        # STEP 2: Generator optimization
        netG.zero_grad()
        label.fill_(real_label)

        # Score the same generated batch again, keeping the graph through netG
        D_fake2 = netD(fake).view(-1)
        errG = criterion(D_fake2, label)
        errG.backward()

        # Mean discriminator output on generated images, after the D update
        D_fake2 = D_fake2.mean().item()
        optimizerG.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     errD.item(), errG.item(), D_real, D_fake, D_fake2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Accumulate per-frequency sums of |FFT| and |FFT|^2 over the window
        with torch.no_grad():
            fftx = torch.abs(torch.fft.fft2(x))
            fftfake = torch.abs(torch.fft.fft2(fake))
            if spectrum_count == 0:
                sum_mag_real = fftx.sum(dim=0)
                sum_mag_fake = fftfake.sum(dim=0)
                sum_sq_real = (fftx ** 2).sum(dim=0)
                sum_sq_fake = (fftfake ** 2).sum(dim=0)
            else:
                sum_mag_real += fftx.sum(dim=0)
                sum_mag_fake += fftfake.sum(dim=0)
                sum_sq_real += (fftx ** 2).sum(dim=0)
                sum_sq_fake += (fftfake ** 2).sum(dim=0)
            spectrum_count += x.size(0)

        # Store generated images from "fixed_z" every hundred iterations
        # (and at the very last iteration). A separate name is used so the
        # training batch `fake` is not overwritten.
        if (iters % 100 == 0) or ((epoch == num_epochs - 1) and (i == num_batches - 1)):
            with torch.no_grad():
                stored_fake = netG(fixed_z.cuda()).detach().cpu()
            img_list.append(stored_fake)
        iters += 1

    # Last epoch of a window: turn the sums into mean / std and plot them
    if epoch % spectrum_window == spectrum_window - 1:
        with torch.no_grad():
            mean_real = sum_mag_real / spectrum_count
            mean_fake = sum_mag_fake / spectrum_count
            # Var = E[m^2] - E[m]^2, clamped at 0 against rounding errors
            std_real = torch.sqrt((sum_sq_real / spectrum_count - mean_real ** 2).clamp(min=0))
            std_fake = torch.sqrt((sum_sq_fake / spectrum_count - mean_fake ** 2).clamp(min=0))

            # Shift zero frequency to the center (spatial axes only)
            mean_real = torch.fft.fftshift(mean_real, dim=(-2, -1))
            mean_fake = torch.fft.fftshift(mean_fake, dim=(-2, -1))
            std_real = torch.fft.fftshift(std_real, dim=(-2, -1))
            std_fake = torch.fft.fftshift(std_fake, dim=(-2, -1))

            # Keep every real-image window; plot the current fake window first
            list_mean_magnitude_spectrum_real.append((mean_real.cpu(), std_real.cpu()))
            plot_PSD_and_Frechet_PSD(
                [(mean_fake.cpu(), std_fake.cpu())] + list_mean_magnitude_spectrum_real)

        # Start a fresh window
        spectrum_count = 0
        sum_mag_real = sum_mag_fake = None
        sum_sq_real = sum_sq_fake = None

In [None]:
print(len(img_list))
# Draw into a figure created in this cell rather than the `fig1` handle left
# over from an earlier cell (stale or missing on a partial run).
fig1 = plt.figure(figsize=(12, 12))
# Third stored batch from the end.
voir_batch2D(img_list[-3], 8, fig1, k=0, min_scale=0,max_scale=1)

# Building blocks for NowcastNet and DGMR

In [None]:
import re
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm

class GenBlock(nn.Module):
    """Residual block whose normalisation layers are SPADE modules driven by
    a conditioning tensor ``evo``.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object; ``opt.evo_ic`` gives the number of channels of
            the conditioning tensor. The object is also stored as ``self.opt``.
        use_se: accepted but not used by this block.
        dilation: dilation of the 3x3 convolutions; also the width of the
            reflection padding applied before each of them.
        double_conv: when True, a second SPADE/activation/convolution stage is
            applied in ``forward``.

    All convolutions are wrapped in spectral normalisation. When
    ``fin != fout`` the skip connection goes through a SPADE layer and a
    bias-free 1x1 convolution.
    """

    def __init__(self, fin, fout, opt, use_se=False, dilation=1, double_conv=False):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        self.opt = opt
        self.double_conv = double_conv

        # Reflection padding is applied explicitly, hence padding=0 below
        self.pad = nn.ReflectionPad2d(dilation)
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=0, dilation=dilation)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=0, dilation=dilation)
        wrapped = ['conv_0', 'conv_1']
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
            wrapped.append('conv_s')

        # Wrap the convolutions in spectral normalisation, in creation order
        for attr in wrapped:
            setattr(self, attr, spectral_norm(getattr(self, attr)))

        # SPADE layers conditioned on a tensor with opt.evo_ic channels
        ic = opt.evo_ic
        self.norm_0 = SPADE(fin, ic)
        self.norm_1 = SPADE(fmiddle, ic)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, ic)

    def forward(self, x, evo):
        """Return ``shortcut(x, evo)`` plus the output of the residual branch."""
        skip = self.shortcut(x, evo)

        branch = self.norm_0(x, evo)
        branch = self.conv_0(self.pad(self.actvn(branch)))
        if self.double_conv:
            branch = self.norm_1(branch, evo)
            branch = self.conv_1(self.pad(self.actvn(branch)))

        return skip + branch

    def shortcut(self, x, evo):
        """Skip path: identity, or SPADE followed by the 1x1 convolution."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, evo))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)


class SPADE(nn.Module):
    """Spatially-adaptive normalisation conditioned on a feature map `evo`.

    The input is instance-normalised without learned affine parameters, then
    modulated per pixel: ``out = norm(x) * (1 + gamma) + beta`` where ``gamma``
    and ``beta`` are produced by small convolutional heads applied to `evo`.

    Args:
        norm_nc: number of channels of the tensor being normalised.
        label_nc: number of channels of the conditioning tensor `evo`.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()

        kernel, hidden = 3, 64
        border = kernel // 2  # reflection padding that preserves the spatial size

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.mlp_shared = nn.Sequential(
            nn.ReflectionPad2d(border),
            nn.Conv2d(label_nc, hidden, kernel_size=kernel, padding=0),
            nn.ReLU(),
        )
        self.pad = nn.ReflectionPad2d(border)
        self.mlp_gamma = nn.Conv2d(hidden, norm_nc, kernel_size=kernel, padding=0)
        self.mlp_beta = nn.Conv2d(hidden, norm_nc, kernel_size=kernel, padding=0)

    def forward(self, x, evo):
        """Normalise `x` and modulate it with scale/shift maps computed from `evo`."""
        normalized = self.param_free_norm(x)

        # Bring the conditioning map to the spatial size of x before the heads.
        target_hw = x.size()[2:]
        features = self.mlp_shared(F.adaptive_avg_pool2d(evo, output_size=target_hw))

        padded = self.pad(features)
        scale = self.mlp_gamma(padded)
        shift = self.mlp_beta(padded)

        return normalized * (1 + scale) + shift

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class DoubleConv(nn.Module):
    """Pre-activation residual block built from spectrally normalised convolutions.

    Main branch: ``BN -> ReLU -> conv -> BN -> ReLU -> conv``.
    Shortcut branch: ``BN -> conv``. The two outputs are summed.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: square kernel size; padding is ``kernel // 2``.
        mid_channels: channels between the two main-branch convolutions;
            falls back to ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, kernel=3, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        same_pad = kernel // 2

        def sn_conv(c_in, c_out):
            # Convolution wrapped in spectral normalisation.
            return spectral_norm(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=same_pad)
            )

        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            sn_conv(in_channels, mid_channels),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            sn_conv(mid_channels, out_channels),
        )
        self.single_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            sn_conv(in_channels, out_channels),
        )

    def forward(self, x):
        """Return the sum of the two-convolution branch and the shortcut branch."""
        skip = self.single_conv(x)
        main = self.double_conv(x)
        return main + skip

class Down(nn.Module):
    """Downsampling stage: 2x2 max-pooling followed by a `DoubleConv`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: kernel size forwarded to `DoubleConv`.
    """

    def __init__(self, in_channels, out_channels, kernel=3):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, kernel),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Pool then convolve."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsampling stage with a skip connection.

    `x1` is upsampled by a factor of 2, zero-padded to the spatial size of
    `x2`, concatenated after `x2` along the channel axis and passed through a
    `DoubleConv`.

    Args:
        in_channels: channels of the concatenated tensor fed to the convolution.
        out_channels: number of output channels.
        bilinear: use bilinear interpolation (True) or a transposed convolution.
        kernel: kernel size forwarded to `DoubleConv`.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, kernel=3):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, kernel)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            half = in_channels // 2
            self.conv = DoubleConv(in_channels, out_channels, kernel=kernel, mid_channels=half)

    def forward(self, x1, x2):
        """Upsample `x1`, align it with `x2`, concatenate and convolve."""
        x1 = self.up(x1)

        # Tensors are indexed as (N, C, H, W): axis 2 is height, axis 3 is width.
        gap_h = x2.size(2) - x1.size(2)
        gap_w = x2.size(3) - x1.size(3)
        top, left = gap_h // 2, gap_w // 2

        # F.pad expects [left, right, top, bottom]; any odd remainder goes right/bottom.
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)


class Up_S(nn.Module):
    """Upsampling stage without skip connection: upsample by 2, then `DoubleConv`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        bilinear: use bilinear interpolation (True) or a transposed convolution.
        kernel: kernel size forwarded to `DoubleConv`.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, kernel=3):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, kernel)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(
                in_channels, out_channels, kernel=kernel, mid_channels=in_channels
            )

    def forward(self, x):
        """Upsample then convolve."""
        upsampled = self.up(x)
        return self.conv(upsampled)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping `in_channels` to `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution."""
        projected = self.conv(x)
        return projected



In [None]:
#code from https://codeocean.com/capsule/3935105/tree/v1
class Generative_Encoder(nn.Module):
    """Encoder: one `DoubleConv` stem followed by three `Down` stages.

    The channel width goes ``base_c -> 2*base_c -> 4*base_c -> 8*base_c`` and
    each `Down` stage halves the spatial resolution (2x2 max-pooling).

    Args:
        n_channels: number of input channels.
        base_c: channel width after the stem.
    """

    def __init__(self, n_channels, base_c=64):
        super().__init__()
        widths = [base_c * factor for factor in (1, 2, 4, 8)]
        self.inc = DoubleConv(n_channels, widths[0], kernel=3)
        self.down1 = Down(widths[0], widths[1], 3)
        self.down2 = Down(widths[1], widths[2], 3)
        self.down3 = Down(widths[2], widths[3], 3)

    def forward(self, x):
        """Run the stem and the three downsampling stages in sequence."""
        for stage in (self.inc, self.down1, self.down2, self.down3):
            x = stage(x)
        return x

class Generative_Decoder(nn.Module):
    """Decoder made of `GenBlock`s conditioned on `evo`, with three 2x upsamplings.

    Attributes read from `opt`:
        ngf: base channel width ``nf``.
        ic_feature: channels of the input feature map.
        gen_oc: channels of the generated output.
    (`GenBlock` additionally reads ``opt.evo_ic``.)

    Channel widths go ``8nf -> 8nf -> 4nf -> 4nf -> 2nf -> nf -> nf`` before the
    final 3x3 convolution.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf
        feature_channels = opt.ic_feature

        self.fc = nn.Conv2d(feature_channels, 8 * nf, 3, padding=1)

        self.head_0 = GenBlock(8 * nf, 8 * nf, opt)

        self.G_middle_0 = GenBlock(8 * nf, 4 * nf, opt, double_conv=True)
        self.G_middle_1 = GenBlock(4 * nf, 4 * nf, opt, double_conv=True)

        self.up_0 = GenBlock(4 * nf, 2 * nf, opt)

        self.up_1 = GenBlock(2 * nf, 1 * nf, opt, double_conv=True)
        self.up_2 = GenBlock(1 * nf, 1 * nf, opt, double_conv=True)

        self.conv_img = nn.Conv2d(nf * 1, self.opt.gen_oc, 3, padding=1)
        # The same parameter-free upsampler is reused between stages.
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, evo):
        """Decode features `x` under condition `evo`; output is 8x larger spatially."""
        x = self.head_0(self.fc(x), evo)

        x = self.up(x)
        for block in (self.G_middle_0, self.G_middle_1):
            x = block(x, evo)

        x = self.up(x)
        x = self.up_0(x, evo)

        x = self.up(x)
        for block in (self.up_1, self.up_2):
            x = block(x, evo)

        activated = F.leaky_relu(x, 2e-1)
        return self.conv_img(activated)

# Discriminator
class DBlock(nn.Module):
    """Channel-expanding residual block with spectrally normalised convolutions.

    The shortcut keeps the input and appends ``out_channel - in_channel`` new
    channels produced by a 1x1 convolution; the main branch is
    ``conv3x3 -> ReLU -> conv3x3``. Both have ``out_channel`` channels and are
    summed. Spatial size is unchanged.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels; the shortcut's 1x1 convolution
            is built with ``out_channel - in_channel`` output channels.

    Note:
        A second class named ``DBlock`` (DGMR version) is defined later in the
        same cell and shadows this one once the cell has run.
    """

    def __init__(self, in_channel, out_channel):
        # The original called super(ProjBlock, self), a name that does not
        # exist here and raised NameError on construction.
        super().__init__()
        extra_channels = out_channel - in_channel
        self.one_conv = spectral_norm(
            nn.Conv2d(in_channel, extra_channels, kernel_size=1, padding=0)
        )
        self.double_conv = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)),
            nn.ReLU(),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1))
        )

    def forward(self, x):
        """Return ``cat(x, one_conv(x)) + double_conv(x)``."""
        x1 = torch.cat([x, self.one_conv(x)], dim=1)
        x2 = self.double_conv(x)
        output = x1 + x2
        return output




# Excerpt from DGMR (same principle, with 3D convolutions in addition)
# Ref : https://github.com/hyungting/DGMR-pytorch/blob/master/DGMR/dgmr_layers/DBlock.py

"""
Skilful precipitation nowcasting using deep generative models of radar, from DeepMind
https://arxiv.org/abs/2104.00954
"""

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

from .utils import Identity

class DBlock(nn.Module):
    """
    2-D residual discriminator block (D Block of https://arxiv.org/abs/2104.00954).

    Shortcut: spectrally normalised 1x1 convolution, then optional 2x average
    pooling. Main branch: optional LeakyReLU(0.2), 3x3 convolution
    (``in -> in``), LeakyReLU(0.2), 3x3 convolution (``in -> out``), then the
    same optional pooling. The two branches are summed.

    Args:
        in_channels: int, number of channels of input tensor.
        out_channels: int, number of channels of output tensor.
        relu: bool, whether the main branch starts with an activation.
        downsample: bool, whether both branches end with 2x average pooling.
    Return:
        torch.tensor
    """
    def __init__(
        self,
        in_channels:int=None,
        out_channels:int=None,
        relu:bool=True,
        downsample:bool=True
        ):
        super().__init__()
        # Parameter-free modules; the pooling instance is shared by both branches.
        pool = nn.AvgPool2d(2, 2) if downsample else Identity()
        pre_activation = nn.LeakyReLU(0.2) if relu else Identity()

        def sn(conv):
            # Spectral normalisation with the eps used throughout this block.
            return spectral_norm(conv, eps=1e-4)

        self.conv1x1 = nn.Sequential(
            sn(nn.Conv2d(in_channels, out_channels, 1, 1, 0)),
            pool,
        )
        self.conv3x3 = nn.Sequential(
            pre_activation,
            sn(nn.Conv2d(in_channels, in_channels, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            sn(nn.Conv2d(in_channels, out_channels, 3, 1, 1)),
            pool,
        )

    def forward(self, x):
        """Sum of the 1x1 shortcut and the double 3x3 branch."""
        shortcut = self.conv1x1(x)
        main = self.conv3x3(x)
        return shortcut + main

class D3Block(nn.Module):
    """
    3-D residual discriminator block (D Block of https://arxiv.org/abs/2104.00954).

    Same layout as the 2-D block, with `Conv3d` ("same" padding) and
    `AvgPool3d`, so pooling also halves the third-from-last axis.

    Args:
        in_channels: int, number of channels of input tensor.
        out_channels: int, number of channels of output tensor.
        relu: bool, whether the main branch starts with an activation.
        downsample: bool, whether both branches end with 2x average pooling.
    Return:
        torch.tensor
    """
    def __init__(
        self,
        in_channels:int=None,
        out_channels:int=None,
        relu:bool=True,
        downsample:bool=True
        ):
        super().__init__()
        # Parameter-free modules; the pooling instance is shared by both branches.
        pool = nn.AvgPool3d(2, 2) if downsample else Identity()
        pre_activation = nn.LeakyReLU(0.2) if relu else Identity()

        def sn(conv):
            # Spectral normalisation with the eps used throughout this block.
            return spectral_norm(conv, eps=1e-4)

        self.conv1x1 = nn.Sequential(
            sn(nn.Conv3d(in_channels, out_channels, 1, 1, "same")),
            pool,
        )
        self.conv3x3 = nn.Sequential(
            pre_activation,
            sn(nn.Conv3d(in_channels, in_channels, 3, 1, "same")),
            nn.LeakyReLU(0.2),
            sn(nn.Conv3d(in_channels, out_channels, 3, 1, "same")),
            pool,
        )

    def forward(self, x):
        """Sum of the 1x1x1 shortcut and the double 3x3x3 branch."""
        shortcut = self.conv1x1(x)
        main = self.conv3x3(x)
        return shortcut + main



# DGMR excerpt: discriminators
class SpatialDiscriminator(nn.Module):
    """Frame-wise discriminator: scores `n_frame` randomly chosen frames and sums them.

    Args:
        n_frame: number of frames sampled (without replacement) from each sequence.
        debug: when True, print the tensor shape after every stage.

    Forward input:
        x: 4-D tensor whose axis 1 indexes frames (a singleton channel axis is
           inserted at position 2 inside `forward`).
    Forward output:
        Tensor of shape ``(B, 1)``: per-sequence sum of the frame scores.

    Raises:
        ValueError: from `random.sample` when the sequence has fewer than
            `n_frame` frames.

    Note:
        Frame selection uses Python's `random` module, so reproducibility
        requires `random.seed(...)` in addition to the torch seed.
    """
    def __init__(
        self,
        n_frame: int=10,
        debug: bool=False
        ):
        super().__init__()
        self.n_frame = n_frame
        self.debug = debug

        self.avgpooling = nn.AvgPool2d(2)
        self.d_blocks = nn.ModuleList([
                DBlock(4, 48, relu=False, downsample=True), # 4 -> (3 * 4) * 4 = 48
                DBlock(48, 96, downsample=True), # 48 -> (6 * 4) * 4 = 96
                DBlock(96, 192, downsample=True), # 96 -> (12 * 4) * 4 = 192
                DBlock(192, 384, downsample=True), # 192 -> (24 * 4) * 4 = 384
                DBlock(384, 768, downsample=True), # 384 -> (48 * 4) * 4 = 768
                DBlock(768, 768, downsample=False) # 768 -> 768, no downsample no * 4
                ])
        self.linear = nn.Sequential(
                nn.BatchNorm1d(768),
                spectral_norm(nn.Linear(768, 1))
                )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        # Function-scope import: the frame sampling below must not depend on a
        # notebook-level `import random` (none is visible in this notebook).
        import random

        x = x.unsqueeze(2)
        B, N, C, H, W = x.shape # batch_size, total_frames, channel=1, height, width
        indices = random.sample(range(N), self.n_frame)
        x = x[:, indices, :, :, :]
        if self.debug: print(f"Picked x: {x.shape}")
        # Frames are folded into the batch axis so that 2-D blocks can be used.
        x = x.view(B*self.n_frame, C, H, W)
        if self.debug: print(f"Reshaped: {x.shape}")
        x = self.avgpooling(x)
        if self.debug: print(f"Avg pool: {x.shape}")
        x = space2depth(x)
        if self.debug: print(f"S2Dshape: {x.shape}")

        for i, block in enumerate(self.d_blocks):
            x = block(x)
            if self.debug: print(f"D block{i}: {x.shape}")

        # sum pooling
        x = self.relu(x)
        x = torch.sum(x, dim=(-1, -2))
        if self.debug: print(f"Sum pool: {x.shape}")

        x = self.linear(x)
        if self.debug: print(f"Linear : {x.shape}")

        # Unfold the frame axis again and add up the per-frame scores.
        x = x.view(B, self.n_frame, -1)
        if self.debug: print(f"Reshaped: {x.shape}")

        x = torch.sum(x, dim=1)
        if self.debug: print(f"Sum up : {x.shape}")

        return x

class TemporalDiscriminator(nn.Module):
    """Sequence discriminator: random spatial crop, two 3-D blocks, then 2-D blocks.

    Args:
        crop_size: side of the square crop taken by `random_crop`.
        debug: when True, print the tensor shape after every stage.

    Forward input:
        x: 4-D tensor whose axis 1 indexes frames (a singleton channel axis is
           inserted at position 1 inside `forward`).
    Forward output:
        Tensor of shape ``(B, 1)``: per-sequence sum over the remaining time steps.

    Note:
        ``self.avgpooling`` is created but not used in `forward`.
    """
    def __init__(
        self,
        crop_size: int=256,
        debug: bool=False
        ):
        super().__init__()
        self.crop_size = crop_size
        self.debug = debug

        self.avgpooling = nn.AvgPool3d(2)

        blocks_3d = [
            D3Block(4, 48, relu=False, downsample=True),  # C: 4 -> 48, T -> T/2
            D3Block(48, 96, downsample=True),  # C: 48 -> 96, T/2 -> T/4 (not exactly the same as DGMR)
        ]
        self.d3_blocks = nn.ModuleList(blocks_3d)

        # (in_channels, out_channels, downsample) of the 2-D stages.
        plan_2d = ((96, 192, True), (192, 384, True), (384, 768, True), (768, 768, False))
        blocks_2d = [DBlock(c_in, c_out, downsample=down) for c_in, c_out, down in plan_2d]
        self.d_blocks = nn.ModuleList(blocks_2d)

        self.linear = nn.Sequential(
            nn.BatchNorm1d(768),
            spectral_norm(nn.Linear(768, 1)),
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = random_crop(x, size=self.crop_size).to(x.device)
        x = x.unsqueeze(1)
        batch, chans, frames, height, width = x.shape

        # (B, C, T, H, W) -> (B*T, C, H, W) so that space2depth sees 4-D input.
        x = x.permute(0, 2, 1, 3, 4)
        x = x.view(batch * frames, chans, height, width)
        if self.debug:
            print(f"Cropped : {x.shape}")

        x = space2depth(x)
        # Back to (B, C, T, H, W) for the 3-D blocks.
        x = x.view(batch, frames, -1, x.shape[-2], x.shape[-1])
        x = x.permute(0, 2, 1, 3, 4)
        if self.debug:
            print(f"S2Dshape: {x.shape}")

        for block3d in self.d3_blocks:
            x = block3d(x)
            if self.debug:
                print(f"3D block: {x.shape}")

        # The 3-D blocks changed C, T, H and W: re-read them before folding time
        # into the batch axis for the 2-D blocks.
        batch, chans, frames, height, width = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, chans, height, width)
        if self.debug:
            print(f"Reshaped: {x.shape}")

        for i, block in enumerate(self.d_blocks):
            x = block(x)
            if self.debug:
                print(f"D block{i}: {x.shape}")

        # sum pooling over the spatial axes
        x = self.relu(x)
        x = torch.sum(x, dim=(-1, -2))
        if self.debug:
            print(f"Sum pool: {x.shape}")

        x = self.linear(x)
        if self.debug:
            print(f"Linear : {x.shape}")

        x = x.view(batch, frames, -1)
        if self.debug:
            print(f"Reshaped: {x.shape}")

        x = torch.sum(x, dim=1)
        if self.debug:
            print(f"Sum up : {x.shape}")

        return x

class DGMRDiscriminators(nn.Module):
    """Bundle of the spatial and temporal discriminators.

    Args:
        n_frame: forwarded to `SpatialDiscriminator`.
        crop_size: forwarded to `TemporalDiscriminator`.

    Forward:
        `x` and `y` are concatenated along axis 1, scored by both
        discriminators, and the two score tensors are stacked along axis 0
        (spatial scores first).
    """
    def __init__(
        self,
        n_frame: int=10,
        crop_size: int=128
        ):
        super().__init__()
        self.spatial_discriminator = SpatialDiscriminator(n_frame=n_frame)
        self.temporal_discriminator = TemporalDiscriminator(crop_size=crop_size)

    def forward(self, x, y):
        sequence = torch.cat((x, y), dim=1)
        scores = (
            self.spatial_discriminator(sequence),
            self.temporal_discriminator(sequence),
        )
        return torch.cat(scores, dim=0)



def space2depth(
    x: torch.tensor=None,
    factor: int=2
    ):
    """
    Move each non-overlapping ``factor x factor`` spatial patch into the channel axis.

    Args:
        x: torch.tensor of shape (B, C, H, W).
        factor: int, side of the patches; H and W are divided by it.
    Return:
        Tensor of shape (B, C * factor**2, H // factor, W // factor).
    """
    batch, channels, height, width = x.shape
    # One column per patch, one row per (channel, position-in-patch) pair.
    patches = F.unfold(x, kernel_size=factor, stride=factor)
    out_shape = (batch, channels * factor ** 2, height // factor, width // factor)
    return patches.view(*out_shape)

def depth2space(
    x: torch.tensor=None,
    factor: int=2
    ):
    """
    Spread channel groups back onto ``factor x factor`` spatial patches.

    Args:
        x: torch.tensor of shape (B, C, H, W).
        factor: int, side of the patches; H and W are multiplied by it.
    Return:
        Tensor of shape (B, C // factor**2, H * factor, W * factor).
    """
    batch, channels, height, width = x.shape
    columns = x.view(batch, channels, height * width)
    target_size = (height * factor, width * factor)
    return F.fold(columns, target_size, kernel_size=(factor, factor), stride=factor)

class Space2Depth(nn.Module):
    """
    Module wrapper around `space2depth` for 4-D and 5-D inputs.

    4-D input (B, C, H, W) is passed to `space2depth` directly. 5-D input
    (B, C, T, H, W) is processed frame by frame and returned as
    (B, C', T, H', W'). For any other rank `forward` returns None.
    Positional constructor arguments are accepted and ignored.
    """
    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        rank = len(x.shape)
        if rank == 4:
            return space2depth(x)
        if rank == 5:
            batch, channels, frames, height, width = x.shape
            # Fold time into the batch axis: (B, C, T, H, W) -> (B*T, C, H, W).
            flat = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
            flat = space2depth(flat)
            # Restore the time axis, then put channels back in front of it.
            restored = flat.view(batch, frames, -1, flat.shape[-2], flat.shape[-1])
            return restored.permute(0, 2, 1, 3, 4)
        return None

class Identity(nn.Module):
    """No-op module: `forward` returns its input unchanged.

    Keyword arguments passed to the constructor are accepted and ignored.
    """
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x):
        return x

def random_crop(x, size=128, padding=False):
    """Return a random ``size x size`` spatial window of a 4-D tensor.

    Args:
        x: tensor with four axes; the crop is taken on the last two.
        size: side of the square crop.
        padding: when True the input is returned unchanged (padding mode is
            not implemented yet).

    Returns:
        ``x[:, :, h:h+size, w:w+size]`` for uniformly drawn offsets, or `x`
        itself when `padding` is True.

    Raises:
        ValueError: from `random.randint` when `size` exceeds H or W.

    Note:
        Offsets come from Python's `random` module, so reproducibility requires
        `random.seed(...)` in addition to the torch seed.
    """
    # Function-scope import: no notebook-level `import random` is visible, and
    # without it the calls below raise NameError.
    import random

    B, C, H, W = x.shape
    if padding:
        # TODO: add padding=True method
        return x

    h_idx = random.randint(0, H - size)
    w_idx = random.randint(0, W - size)
    return x[:, :, h_idx:h_idx + size, w_idx:w_idx + size]

**Exercise 4** A conditional GAN.  




In this exercise, the goal is to implement a conditional Wasserstein-GAN. Once again, theoretical aspects are set aside; the objective is solely to construct the training loop.
The context is as follows: we have a set of images representing a domain 𝒟.
A traditional GAN generates new images from 𝒟. In this exercise, we will instead generate images from 𝒟 that are compatible with a list of pixel values given a priori.

The following cells allow visualization of the available dataset.


In [None]:
# Proportion of pixels preserved in yi:
obs_density = 0.005

x, y, z = gen_condDCGAN(6, obs_density)

# The notebook imports and uses `voir_batch2D` as its display helper; the three
# calls below use it for consistency with the other cells.

# Full images xi
fig1 = plt.figure(1, figsize=(36, 6))
voir_batch2D(x, 6, fig1, k=0, min_scale=0, max_scale=1)

# Fragmentary images yi: a few pixels randomly sampled from xi
fig2 = plt.figure(2, figsize=(36, 6))
voir_batch2D(y, 6, fig2, k=0, min_scale=-0.2, max_scale=1)

# zi: sample from a centered reduced Gaussian vector
fig3 = plt.figure(3, figsize=(36, 6))
voir_batch2D(z, 6, fig3, k=0, min_scale=0, max_scale=1)

**Q1** Drawing inspiration from the previous exercise, complete the training loop and run it for ten epochs:

In [None]:
# Mini-batch setup
batch_size = 128
num_batches_generator = 200  # iterations of the inner training loop per epoch
num_epochs = 10


# Optimizer Parameters
lr = 0.0005
beta1 = 0.  # first coefficient of Adam's `betas` (see the optimizers below)

# nn setup
ndf = 32
n_channels = 2  # passed as first argument to both UNet and Discriminator
n_classes = 1
size = 16

netG = UNet(n_channels, n_classes, size).cuda()
# NOTE(review): the Discriminator class defined at the top of the notebook takes
# no constructor argument — confirm that a version accepting the number of
# input channels is defined before this cell runs.
netD = Discriminator(n_channels).cuda()

# optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Target values for real and generated samples (only needed by a label-based loss).
real_label = 1.
fake_label = 0.

In [None]:
# To keep track of generated images from a fixed sample of (y_i, z_i):
# `obs_density` is the name under which the observation density is defined in
# this exercise (the previous `dens_obs` is not defined in it).
fixed_x, fixed_y, fixed_z = gen_condDCGAN(8, p=obs_density)

# Fixed input for the generator: condition and noise stacked along channels.
fixed_yz = torch.cat((fixed_y, fixed_z), dim=1).cuda()

# Lists
img_list = []
G_losses = []
D_losses = []

# Other hyperparameters
n_critic = 5   # discriminator updates per generator update
clip = 0.01

In [None]:
# Conditional Wasserstein GAN with gradient penalty.
#   critic loss    : D(G(y, z), y).mean() - D(x, y).mean() + gradient penalty
#   generator loss : -D(G(y, z), y).mean()
# Image and condition (resp. condition and noise) are stacked along the channel
# axis (dim=1), like `fixed_yz`.
print("Starting Training Loop...")

# For each epoch
for epoch in range(num_epochs):
    # For each generator update
    for i in range(num_batches_generator):

        ############################
        # (1) Critic update: minimize D(G(y, z), y) - D(x, y) + gradient penalty
        ###########################
        netG.train()
        # Here, we perform multiple (n_critic) optimization steps for the discriminator.
        for j in range(n_critic):
            # Gradients accumulate across backward() calls: reset them at each step.
            netD.zero_grad()

            x, y, z = gen_condDCGAN(batch_size, p=obs_density)

            # Move to GPU
            x = x.cuda()
            y = y.cuda()
            z = z.cuda()

            # Channel-wise concatenations:
            xy = torch.cat((x, y), dim=1)
            yz = torch.cat((y, z), dim=1)

            output_xy = netD(xy).view(-1)

            # The generator is frozen during the critic step.
            fake = netG(yz).detach()

            fakey = torch.cat((fake, y), dim=1)
            output_fakey = netD(fakey).view(-1)

            # Regularization by gradient penalty
            # NOTE(review): the penalty is added as returned — confirm that
            # calculate_gradient_penalty already applies its weighting coefficient.
            gradient_penalty = calculate_gradient_penalty(netD, xy.data, fakey.data)

            # Critic error and parameter update:
            errD = output_fakey.mean() - output_xy.mean() + gradient_penalty
            errD.backward()

            optimizerD.step()

        ############################
        # (2) Generator update: minimize -D(G(y, z), y)
        ###########################
        netG.zero_grad()

        # Re-generate from the last (y, z) batch, this time keeping the graph.
        fake = netG(yz)
        fakey = torch.cat((fake, y), dim=1)

        output_fakey = netD(fakey).view(-1)

        errG = -output_fakey.mean()
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f'
                  % (epoch + 1, num_epochs, i, num_batches_generator,
                     errD.item()))

        # Record losses
        G_losses.append(errG.item())
        D_losses.append(-errD.item())

    # Snapshot of the generator on the fixed inputs at the end of each epoch.
    with torch.no_grad():
        netG.eval()
        fake = netG(fixed_yz.cuda()).detach().cpu()

    img_list.append(fake)

**Q2** Visualize some images and comment on them.

In [None]:
# Number of image batches stored in img_list.
print(len(img_list))
# Display the most recent stored batch on the already existing figure `fig1`.
# NOTE(review): `fig1` must have been created by an earlier cell — confirm it is still the intended figure.
voir_batch2D(img_list[-1], 8, fig1, k=0, min_scale=0,max_scale=1)

**Q3** To obtain a GAN that takes into account the condition contained in $y_i$, it is necessary to push the training further. The file *netG_180ep_WGP_scheduler75_lr005.pt* contains the weights obtained after training for 300 epochs. Load these weights and visualize several images for the same inputs $x_i$ and $z_i$. Check the coherence and draw conclusions.

In [None]:
# Load the checkpoint and restore the generator weights stored under 'model_state_dict'.
# NOTE(review): the file is expected in the working directory — confirm it is fetched beforehand.
weights = torch.load('netG_180ep_WGP_scheduler75_lr005.pt')
# Re-create the generator with the same constructor arguments as in the training setup (2, 1, 16).
netG = UNet(2, 1, 16).cuda()
netG.load_state_dict(weights['model_state_dict'])

In [None]:
netG.eval()

# One batch of real images x with their conditions y.
x, y, z = gen_condDCGAN(6, p=obs_density)

# Stack images and conditions along the batch axis for display:
# first the 6 real images, then the 6 conditions.
xy = torch.cat((x, y), dim=0)
real_and_fakes = [xy]

# The condition stays fixed: only the noise z is re-sampled, so that every
# generated image can be compared with the x / y shown in the same position.
y_gpu = y.cuda()
n = 6
for i in range(n):
    _, _, z = gen_condDCGAN(6, p=obs_density)
    yz = torch.cat((y_gpu, z.cuda()), dim=1)
    with torch.no_grad():
        fake = netG(yz).cpu()
    real_and_fakes.append(fake)

real_and_fakes = torch.cat(real_and_fakes, dim=0)
fig1 = plt.figure(4, figsize=(36, 6))
voir_batch2D(real_and_fakes, 6, fig1, k=0, min_scale=0, max_scale=1)