<a href="https://colab.research.google.com/github/M4mbo/Generative_Adversarial_Network_on_Simpsons_Faces/blob/main/MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def weights_init(m):
    """Reinitialize the weights of a single module in place.

    Modules are matched by class name:

    * names containing ``'Conv'``: weights are drawn from a normal
      distribution with mean 0.0 and standard deviation 0.02;
    * names containing ``'BatchNorm'``: the scale is drawn from a normal
      distribution with mean 1.0 and standard deviation 0.02, and the bias
      is set to 0.

    Any other module is left untouched. BatchNorm layers created with
    ``affine=False`` have no learnable scale/bias and are skipped instead
    of raising.

    Args:
        m: the module to (possibly) reinitialize.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # With affine=False both attributes are None, so guard each access.
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

In [None]:
class Reshape(nn.Module):
  """Layer that views its input with a fixed target shape.

  Lets a reshape step sit inside an ``nn.Sequential`` pipeline.
  """

  def __init__(self, shape):
    """Store ``shape``, the sequence of sizes passed to ``Tensor.view``."""
    super().__init__()
    self.shape = shape

  def forward(self, x):
    """Return ``x`` viewed as ``self.shape`` (no data is copied)."""
    reshaped = x.view(*self.shape)
    return reshaped

In [None]:
class Generator(nn.Module):
  """Generator network: latent noise in, 3-channel image out.

  A linear projection turns a ``Z``-dimensional noise vector into a
  1024-channel 8x8 feature map. Four stride-2 transposed convolutions
  then upsample it (8 -> 17 -> 33 -> 65 -> 130) and a final 5x5
  convolution trims it to a 3-channel 128x128 output, squashed to
  [-1, 1] by ``tanh``.
  """

  def __init__(self, Z):
    """Build the network for a latent vector of size ``Z``."""
    super().__init__()
    self.Z = Z

    # Projection of the noise vector into an 8x8 feature map.
    layers = [
        nn.Linear(Z, 1024 * 8 * 8),
        nn.BatchNorm1d(1024 * 8 * 8),
        nn.LeakyReLU(0.2),
        Reshape((-1, 1024, 8, 8)),
    ]

    # (in_channels, out_channels, padding, output_padding) for each
    # 5x5, stride-2 transposed-convolution upsampling stage.
    upsampling_stages = (
        (1024, 512, 1, 0),
        (512, 256, 2, 0),
        (256, 128, 2, 0),
        (128, 64, 2, 1),
    )
    for in_ch, out_ch, pad, out_pad in upsampling_stages:
      layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2,
                                       padding=pad, output_padding=out_pad))
      layers.append(nn.BatchNorm2d(out_ch))
      layers.append(nn.LeakyReLU(0.2))

    # Final projection to RGB; the tanh is applied in forward().
    layers.append(nn.Conv2d(64, 3, kernel_size=5, stride=1, padding=1))

    self.gen_model = nn.Sequential(*layers)

  def forward(self, noise):
    """Map a ``(N, Z)`` noise batch to images with values in [-1, 1]."""
    return F.tanh(self.gen_model(noise))

In [None]:
class Discriminator(nn.Module):
  """Discriminator network: image batch in, one probability per image out.

  Five 5x5 convolution stages (each followed by batch norm and
  LeakyReLU) reduce a 3-channel input to a 1024-channel feature map.
  The linear head expects that map to be 8x8, which is what a 128x128
  input produces (128 -> 63 -> 31 -> 15 -> 15 -> 8).
  """

  def __init__(self):
    """Create the convolutional trunk, the linear head and the sigmoid."""
    super().__init__()

    # (in_channels, out_channels, stride, padding) per 5x5 conv stage.
    conv_stages = (
        (3, 64, 2, 1),
        (64, 128, 2, 1),
        (128, 256, 2, 1),
        (256, 512, 1, 2),
        (512, 1024, 2, 2),
    )
    trunk = []
    for in_ch, out_ch, stride, pad in conv_stages:
      trunk.append(nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                             kernel_size=5, stride=stride, padding=pad))
      trunk.append(nn.BatchNorm2d(out_ch))
      trunk.append(nn.LeakyReLU(0.2))
    self.disc_model = nn.Sequential(*trunk)

    # Flatten everything except the batch dimension, then score.
    self.linearization = nn.Sequential(
        nn.Flatten(1, -1),
        nn.Linear(1024 * 8 * 8, 1),
    )

    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return a ``(N, 1)`` tensor of values in [0, 1] for image batch ``x``."""
    features = self.disc_model(x)
    logits = self.linearization(features)
    return self.sigmoid(logits)

In [None]:
def plot_loss(generator_losses, discriminator_losses):
    """Plot the discriminator and generator loss curves on one figure.

    Args:
        generator_losses: sequence of generator loss values (drawn in red).
        discriminator_losses: sequence of discriminator loss values
            (drawn in blue).
    """
    # Discriminator first, then generator, so the draw order is unchanged.
    curves = (
        (discriminator_losses, 'Discriminator Training Loss', 'blue'),
        (generator_losses, 'Generator Training Loss', 'red'),
    )
    for values, curve_label, curve_color in curves:
        plt.plot(values, label=curve_label, color=curve_color)

    # Axis labels, title and legend
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()

    # Grid on, then render
    plt.grid(True)
    plt.show()



In [None]:
# Helper function to display results
def display_image_grid(images, num_rows, num_cols, title_text):
  """Show a collection of image tensors in a ``num_rows`` x ``num_cols`` grid.

  Args:
    images: iterable of channel-first ``(C, H, W)`` torch tensors.
      Floating-point images are clipped to [0, 1]; all other dtypes are
      clipped to [0, 255]. Single-channel images use the gray colormap.
      Images beyond the number of grid cells are ignored.
    num_rows: number of grid rows.
    num_cols: number of grid columns.
    title_text: figure title.
  """
  fig = plt.figure(figsize=(num_cols*3., num_rows*3.), )
  grid = ImageGrid(fig, 111, nrows_ncols=(num_rows, num_cols), axes_pad=0.15)

  for ax, im in zip(grid, images):
    # Tensor.numpy() refuses tensors that live on the GPU or require grad,
    # so move a detached copy to host memory first.
    im = im.detach().cpu()

    # Any floating dtype (not just float32/float64) is treated as [0, 1].
    upper = 1 if im.is_floating_point() else 255
    pixels = np.clip(im.permute(1, 2, 0).numpy(), 0, upper)

    if im.size(0) == 1:
      ax.imshow(pixels, cmap = 'gray')
    else:
      ax.imshow(pixels)

    ax.axis("off")

  plt.suptitle(title_text, fontsize=20)
  plt.tight_layout()
  plt.show()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Size of the latent noise vector sampled for the generator.
Z = 100
# Visualization Purposes: a fixed batch of 10 noise vectors, reused so the
# generated samples can be compared from epoch to epoch.
sample_noise = torch.randn(10, Z).to(device)

In [None]:
def train_GAN(EPOCHS, lrg, lrd, discriminator, generator):
  """Train a GAN with alternating discriminator / generator BCE updates.

  Besides its arguments, the function reads these module-level names:
  ``train_loader`` (iterable of image batches), ``device``, ``Z`` (noise
  size), ``sample_noise`` (fixed noise used for the per-epoch preview)
  and ``display_image_grid``.

  Args:
    EPOCHS: number of passes over ``train_loader``.
    lrg: initial Adam learning rate for the generator.
    lrd: initial Adam learning rate for the discriminator.
    discriminator: module mapping an image batch to ``(N, 1)``
      probabilities (its output is compared against ``(N, 1)`` labels
      with ``nn.BCELoss``).
    generator: module mapping ``(N, Z)`` noise to an image batch.

  Returns:
    A tuple ``(generator_losses, discriminator_losses, generated_images)``:
    two lists holding the mean per-sample loss of each epoch, and a list
    with the uint8 preview images produced from ``sample_noise`` after
    each epoch.
  """
  real_label = 1
  fake_label = 0

  discriminator.train()
  generator.train()

  # Nothing is logged through this writer inside this function; it is
  # closed in the ``finally`` below so the event file is not left open.
  writer = SummaryWriter("runs/lr_1")

  # Defining the optimizer and loss function here
  optimizer_gen = torch.optim.Adam(generator.parameters(), lr=lrg, betas=(0.5, 0.999))
  optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=lrd, betas=(0.5, 0.999))

  # NOTE(review): ``verbose`` is reported as deprecated by recent PyTorch
  # releases - confirm against the installed version.
  lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer_disc, gamma=0.99, verbose=True)
  lr_scheduler2 = torch.optim.lr_scheduler.ExponentialLR(optimizer_gen, gamma=0.99, verbose=True)

  loss_fn = nn.BCELoss()

  generator_losses = []
  discriminator_losses = []

  generated_images = []

  try:
    ## Training
    for i in range(1, EPOCHS+1):
      pbar = tqdm(train_loader)

      total_gen_loss = 0.0
      total_disc_loss = 0.0
      num_samples = 0

      for batch in pbar:
        # The batch is used directly as the image tensor (no label to unpack).
        inputs = batch.to(device)

        # Rescale to [-1, 1] to match the generator's tanh output range
        # (assumes the loader yields values in [0, 1] - TODO confirm).
        inputs = (inputs - 0.5) * 2
        batch_size = inputs.shape[0]

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        ## Train with all-real batch
        discriminator.zero_grad()
        # Set up labels
        label = torch.full((batch_size, 1), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output_real = discriminator(inputs)
        # Calculate loss on all-real batch
        errD_real = loss_fn(output_real, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output_real.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(batch_size, Z).to(device)
        # Generate fake image batch with G
        fake = generator(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach so no gradient reaches G here
        output_fake = discriminator(fake.detach())
        # Calculate D's loss on the all-fake batch
        errD_fake = loss_fn(output_fake, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output_fake.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizer_disc.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        generator.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output_fake = discriminator(fake)
        # Calculate G's loss based on this output
        errG = loss_fn(output_fake, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output_fake.mean().item()
        # Update G
        optimizer_gen.step()

        # BCELoss averages over the batch, so weight each batch loss by its
        # size; dividing by num_samples below then gives a true per-sample
        # mean even when the last batch is smaller.
        total_gen_loss += errG.item() * batch_size
        total_disc_loss += errD.item() * batch_size
        num_samples += batch_size

        pbar.set_description(f"Epoch {i}/{EPOCHS}: ")
        pbar.set_postfix({"generator_loss": errG.item(), "discriminator_loss": errD.item(), "D(x)": D_x, "D(G(z1))": D_G_z1, "D(G(z2))": D_G_z2})

      # Guard against an empty loader (records 0.0 instead of dividing by zero).
      denominator = max(num_samples, 1)
      generator_losses.append(total_gen_loss / denominator)
      discriminator_losses.append(total_disc_loss / denominator)

      # Visualization of validation images. No gradients are needed for the
      # preview, so skip building the autograd graph.
      with torch.no_grad():
        generations = generator(sample_noise).cpu()
      generations = (generations + 1) / 2   #[0,1]
      generations = (generations * 255).clamp(0, 255).to(torch.uint8)
      generated_images.append(generations)
      display_image_grid(generations, 1, 10, f"Generated images at epoch {i}")

      lr_scheduler1.step()
      lr_scheduler2.step()
  finally:
    writer.close()

  return generator_losses, discriminator_losses, generated_images
