In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
import copy
from google.colab import files
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.activation import LeakyReLU
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from tqdm import tqdm_notebook as tqdm
# Seed every RNG this notebook draws from. DataLoader shuffling, weight init
# and torch.nn.init.normal_ use the torch RNG, not Python's `random`, so all
# three must be seeded for runs to be repeatable (GPU kernels may still add
# some non-determinism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
import warnings
# NOTE(review): this hides *every* warning, including deprecation notices
# (e.g. for tqdm_notebook); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

In [None]:
# Use the GPU whenever PyTorch can see one; `cuda` is a plain bool flag.
cuda = torch.cuda.is_available()

# Output folders: model checkpoints and periodic sample grids.
os.makedirs("saved_models", exist_ok=True)
os.makedirs("images", exist_ok=True)

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize with mean=std=0.5 then rescales
# every channel to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
train_set = torchvision.datasets.CIFAR10(
    root=".",
    train=True,
    download=True,
    transform=transform,
)

Files already downloaded and verified


In [None]:
img_class = 1   # 1 = Cars (CIFAR-10 "automobile")

# Select samples of the chosen class from the dataset's label list. Indexing
# train_set[i] instead would decode and transform every training image just
# to read its label.
idx = [i for i, label in enumerate(train_set.targets) if label == img_class]

# Materialise the selected (already normalised) images into one float64
# array of shape (num_selected, 3, 32, 32).
subset = np.zeros([len(idx), 3, 32, 32])
for row, sample_idx in enumerate(idx):
    subset[row] = train_set[sample_idx][0]

In [None]:
# Hold out the last 10% of the class subset for visual sampling. Deriving the
# split from len(subset) keeps test_loader non-empty for any subset size;
# assuming 5,000 images per class this is the same 4500/500 split as before.
n_train = len(subset) * 9 // 10
loader = torch.utils.data.DataLoader(
    subset[:n_train], batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    subset[n_train:], batch_size=12, shuffle=True)

In [None]:
def CutOut(images, cutOutSize=(8, 8)):
    """Cut a centred patch out of every image in a batch.

    Args:
        images: torch tensor of shape (N, C, H, W). It may live on any device
            and may require grad; it is not modified.
        cutOutSize: (height, width) of the centred patch to remove. Any
            two-element sequence works (a tuple default avoids the shared
            mutable-list default).

    Returns:
        [cutImages, cutOuts]
            cutImages -- CPU copy of ``images`` (same dtype) with the patch set
                to 1, i.e. white for images normalised to [-1, 1].
            cutOuts -- the removed patches, shape (N, C, height, width), float64.

    Raises:
        ValueError: if the patch is larger than the images.
    """
    images = images.detach().cpu().numpy()          # works for CPU and CUDA tensors
    _, _, height, width = images.shape
    patch_h, patch_w = int(cutOutSize[0]), int(cutOutSize[1])
    if patch_h > height or patch_w > width:
        raise ValueError(
            f"cutOutSize {tuple(cutOutSize)} exceeds image size {(height, width)}")

    # Corner of the centred patch; int() truncates exactly as before.
    xmin = int(height / 2 - patch_h / 2)
    xmax = xmin + patch_h
    ymin = int(width / 2 - patch_w / 2)
    ymax = ymin + patch_w

    # Vectorised slicing replaces the per-image loop. astype() always copies,
    # so the patches never alias the input or cutImages.
    cutOuts = images[:, :, xmin:xmax, ymin:ymax].astype(np.float64)
    cutImages = images.copy()
    cutImages[:, :, xmin:xmax, ymin:ymax] = 1

    return [torch.from_numpy(cutImages), torch.from_numpy(cutOuts)]

In [None]:
class Generator(nn.Module):
    """Encoder-decoder that predicts the missing centre patch of an image.

    Input:  (N, 3, 32, 32) masked image.
    Output: (N, 3, 8, 8) patch in [-1, 1] (Tanh).

    Args:
        bn_eps: ``eps`` given to every BatchNorm2d layer. The original code
            wrote ``nn.BatchNorm2d(C, 0.8)``; BatchNorm2d's second positional
            parameter is ``eps`` (not ``momentum``), so 0.8 was the epsilon.
            The default keeps that behaviour so existing runs reproduce.
            NOTE(review): 0.8 was probably meant as momentum; PyTorch's
            default eps is 1e-5.

    Layer order is unchanged, so state_dict keys match older checkpoints.
    """

    def __init__(self, bn_eps=0.8):
        super().__init__()
        self.model = nn.Sequential(
            # ---- encoder: 32x32x3 -> 4x4x256 ----
            nn.Conv2d(3, 64, 3, stride=2, padding=1),               # 16x16x64
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),             # 8x8x128
            nn.BatchNorm2d(128, eps=bn_eps),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),            # 4x4x256
            nn.BatchNorm2d(256, eps=bn_eps),
            nn.LeakyReLU(0.2),
            # ---- bottleneck: 4x4x256 -> 1x1x4096 (no norm / activation) ----
            nn.Conv2d(256, 4096, 4, stride=1),
            # ---- decoder: 1x1x4096 -> 8x8x3 ----
            nn.ConvTranspose2d(4096, 256, 4, stride=1, padding=0),  # 4x4x256
            nn.BatchNorm2d(256, eps=bn_eps),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),   # 8x8x128
            nn.BatchNorm2d(128, eps=bn_eps),
            nn.ReLU(),
            nn.Conv2d(128, 3, 3, stride=1, padding=1),              # 8x8x3
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the predicted (N, 3, 8, 8) centre patch for masked images x."""
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Patch critic: maps an (N, 3, 8, 8) patch batch to (N, 1, 1, 1) raw scores.

    Spatial flow: 8x8x3 -> 6x6x64 -> 3x3x128 -> 1x1x1. There is no final
    sigmoid; in this notebook the score is regressed onto 1 (real) / 0
    (generated) with an MSE loss.
    """

    def __init__(self):
        super().__init__()
        # Keep this layer order: state_dict keys (model.0 / model.2 / model.5)
        # must stay compatible with saved checkpoints.
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0),    # -> 6x6x64
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> 3x3x128
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=0),   # -> 1x1x1
        )

    def forward(self, x):
        """Score a batch of patches; output shape is (N, 1, 1, 1) for 8x8 input."""
        return self.model(x)

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * Any class whose name contains "Conv" (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02). Its bias is left untouched.
    * Any class whose name contains "BatchNorm2d": weight ~ N(1, 0.02), bias = 0.
    * Every other module is ignored.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)

def save_sample(batch):
    """Inpaint one held-out batch and save a comparison grid to images/<batch>.png.

    Each image in the grid stacks, top to bottom: the masked input, the masked
    input with the generator's patch pasted in, and the original image.

    Sampling runs under ``torch.no_grad()`` with the generator in eval mode, so
    no autograd graph is built and held-out data never updates the BatchNorm
    running statistics. The generator's previous train/eval mode is restored
    afterwards, even if generation raises.
    """
    samples = next(iter(test_loader))                 # fresh held-out batch
    masked_samples, _ = CutOut(samples)               # only the masked images are needed
    samples = samples.type(Tensor)
    masked_samples = masked_samples.type(Tensor)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            generated_cutout = generator(masked_samples)
    finally:
        generator.train(was_training)

    # Paste the prediction back where CutOut removed it. The centred offset
    # uses CutOut's int(H/2 - h/2) arithmetic instead of a hard-coded 12:20.
    height, width = masked_samples.shape[-2:]
    patch_h, patch_w = generated_cutout.shape[-2:]
    top = int(height / 2 - patch_h / 2)
    left = int(width / 2 - patch_w / 2)
    filled_samples = masked_samples.clone()
    filled_samples[:, :, top:top + patch_h, left:left + patch_w] = generated_cutout

    # dim -2 is height: masked / filled / original stacked vertically per image.
    sample = torch.cat((masked_samples, filled_samples, samples), -2)
    save_image(sample, "images/%d.png" % batch, nrow=6, normalize=True)

    

# --- Losses ---------------------------------------------------------------
adversarial_loss = nn.MSELoss()  # discriminator scores vs. 1 (real) / 0 (fake) targets
pixelwise_loss = nn.L1Loss()     # reconstruction error on the missing patch

# --- Networks (Generator built first, then Discriminator) -----------------
generator, discriminator = Generator(), Discriminator()

# Move everything to the GPU *before* weight init, so the init below draws
# from the same device the networks will train on.
if cuda:
    for module in (generator, discriminator, adversarial_loss, pixelwise_loss):
        module.cuda()

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# --- Optimisers: identical Adam settings for generator and discriminator ---
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=0.00008, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Tensor type used to cast batches onto the active device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
# Epoch loop
n_epochs = 100
for epoch in range(n_epochs):

    # Running sums of per-batch losses, averaged for the progress bar.
    gen_adv_loss, gen_pixel_loss, disc_loss = 0, 0, 0
    tqdm_bar = tqdm(loader, desc=f'Training Epoch {epoch} ', total=len(loader))
    for i, imgs in enumerate(tqdm_bar):

        # Mask the centre of each image; the removed patches are the targets.
        masked_imgs, masked_parts = CutOut(imgs)
        masked_imgs = masked_imgs.type(Tensor)
        masked_parts = masked_parts.type(Tensor)

        # ---------------- Generator step ----------------
        optimizer_G.zero_grad()

        # One forward pass per batch. This previously sat inside
        # `for j in range(len(masked_imgs))`, recomputing the same output
        # batch-size times and updating BatchNorm running stats on each pass.
        gen_parts = generator(masked_imgs)

        # Targets shaped like the discriminator's (N, 1, 1, 1) output.
        real = Tensor(imgs.shape[0], 1, 1, 1).fill_(1.0)
        generator_adv = adversarial_loss(discriminator(gen_parts), real)
        generator_pixel = pixelwise_loss(gen_parts, masked_parts)
        # Mostly reconstruction-driven; the adversarial term sharpens detail.
        generator_loss = 0.01 * generator_adv + 0.99 * generator_pixel

        generator_loss.backward()
        optimizer_G.step()

        # -------------- Discriminator step --------------
        # zero_grad also discards the gradients the generator step left
        # in the discriminator's parameters.
        optimizer_D.zero_grad()

        fake = Tensor(imgs.shape[0], 1, 1, 1).fill_(0.0)
        discriminator_real = adversarial_loss(discriminator(masked_parts), real)
        # detach(): the discriminator update must not backprop into the generator.
        discriminator_fake = adversarial_loss(discriminator(gen_parts.detach()), fake)
        discriminator_loss = 0.5 * (discriminator_real + discriminator_fake)

        discriminator_loss.backward()
        optimizer_D.step()

        # Progress bar shows epoch-to-date mean losses.
        gen_adv_loss += generator_adv.item()
        gen_pixel_loss += generator_pixel.item()
        disc_loss += discriminator_loss.item()
        tqdm_bar.set_postfix(gen_adv_loss=gen_adv_loss/(i+1), gen_pixel_loss=gen_pixel_loss/(i+1), disc_loss=disc_loss/(i+1))

        # Save a sample grid every 50th global batch.
        batch = epoch * len(loader) + i
        if batch % 50 == 0:
            save_sample(batch)

        # Checkpoint every 50th batch of the epoch (files are overwritten).
        if i % 50 == 0:
            torch.save(generator.state_dict(), "saved_models/generator.pth")
            torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")


In [None]:
# For loading a network: restore the checkpoints written by the training loop.
def _checkpoint_path(filename):
    """Return saved_models/<filename> if it exists, otherwise <filename> in the cwd.

    The training loop writes to saved_models/; the cwd fallback keeps working
    for checkpoints placed there by hand (e.g. uploaded in Colab).
    """
    saved_path = os.path.join("saved_models", filename)
    return saved_path if os.path.exists(saved_path) else filename


# map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only runtime.
# torch.load unpickles the file, so only load checkpoints you trust.
_map_location = None if cuda else "cpu"
generator.load_state_dict(
    torch.load(_checkpoint_path("generator.pth"), map_location=_map_location))
discriminator.load_state_dict(
    torch.load(_checkpoint_path("discriminator.pth"), map_location=_map_location))