In [None]:
!mkdir results

In [None]:
!mkdir images

In [None]:
!mkdir train_data

In [None]:
# Copy the Monet JPEGs into train_data/monet_jpg. ImageFolder (used on ./train_data/ below)
# expects one sub-directory per class, so all images live in this single "class" folder.
!cp -r ./gan-getting-started/monet_jpg/ ./train_data/

In [None]:
# Sanity check: the copied images should be listed here.
!ls train_data/monet_jpg/

In [None]:
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as T
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
from PIL import Image

# Seed Python's, NumPy's and PyTorch's global RNGs so a full re-run is reproducible
# (cuDNN kernels on GPU may still introduce some nondeterminism).
manualSeed = 1000

print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed);  # trailing ';' suppresses the returned Generator's repr

In [None]:
# Train on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# --- Data loading ---
DATAROOT = "./train_data/"  # intended ImageFolder root (one sub-directory per class)
WORKERS = 2                 # DataLoader worker processes
BATCH_SIZE = 32

# --- Image geometry ---
OUT_IMG_SIZE = 256          # images are resized and center-cropped to this square size
IN_CHANNELS = 3
OUT_CHANNELS = 3            # not referenced by the visible cells

# --- Network widths ---
ENCODING_SIZE = 100         # length of the latent noise vector fed to the generator
GEN_FEATURE_SIZE = 64
DISC_FEATURE_SIZE = 64

# --- Optimisation ---
NUM_EPOCHS = 100
lr = 2e-4
BETA1 = 0.5                 # Adam beta1

In [None]:
# Resize/crop to OUT_IMG_SIZE x OUT_IMG_SIZE and map pixel values from [0, 1] to [-1, 1],
# the same range as the generator's Tanh output.
transform = T.Compose([
    T.Resize(OUT_IMG_SIZE),
    T.CenterCrop(OUT_IMG_SIZE),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Use the configured root instead of repeating the path literal.
train_dataset = datasets.ImageFolder(root=DATAROOT, transform=transform)

In [None]:
# Shuffled mini-batches of (image, class_index) pairs for training.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

In [None]:
# Display one batch of training images as a grid to sanity-check the data pipeline.
# The grid is built on the CPU directly; moving the batch to `device` only to copy it back was wasted work.
real_images, _ = next(iter(train_loader))
grid = vutils.make_grid(real_images[:64], padding=2, normalize=True)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.set_title("Training images")
ax.imshow(grid.permute(1, 2, 0));  # (C, H, W) -> (H, W, C) for matplotlib

In [None]:
# Custom weight initialisation, applied to the generator and discriminator via Module.apply.
def weights_init(m):
    """Re-initialise a single module in place.

    Modules whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d) get weights
    drawn from N(0, 0.02). Modules whose class name contains "BatchNorm" get weights from
    N(1, 0.02) and a zero bias. Every other module is left unchanged.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# defining the generator

class Generator(nn.Module):
    """DCGAN generator: latent noise (N, nz, 1, 1) -> images (N, nc, 256, 256) in [-1, 1].

    Args:
        nz: latent vector length; defaults to ENCODING_SIZE.
        ngf: base feature-map width; defaults to GEN_FEATURE_SIZE.
        nc: output image channels; defaults to IN_CHANNELS so it matches the
            discriminator's input.

    Defaults are looked up when the generator is constructed, so ``Generator()``
    builds exactly the network configured in the config cell. The layer order is
    fixed, so saved state_dicts stay compatible.
    """

    def __init__(self, nz=None, ngf=None, nc=None):
        super(Generator, self).__init__()
        nz = ENCODING_SIZE if nz is None else nz
        ngf = GEN_FEATURE_SIZE if ngf is None else ngf
        nc = IN_CHANNELS if nc is None else nc
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*8, 16, 16): 16x16 kernel, stride 1, no padding
            nn.ConvTranspose2d(nz, ngf * 8, 16, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # each (kernel 4, stride 2, padding 1) layer doubles H and W
            # -> (N, ngf*4, 32, 32)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 64, 64)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 128, 128)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 256, 256), squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map a noise batch of shape (N, nz, 1, 1) to images of shape (N, nc, 256, 256)."""
        return self.main(input)

In [None]:
# Build the generator on the target device, then re-initialise its conv/BatchNorm weights.
generator = Generator().to(device).apply(weights_init)
print(generator)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, nc, 256, 256) -> scores (N, 1, 1, 1) in [0, 1].

    Args:
        nc: input image channels; defaults to IN_CHANNELS.
        ndf: base feature-map width; defaults to DISC_FEATURE_SIZE.

    Defaults are looked up when the discriminator is constructed, so
    ``Discriminator()`` builds exactly the network configured in the config cell.
    The layer order is fixed, so saved state_dicts stay compatible.
    """

    def __init__(self, nc=None, ndf=None):
        super(Discriminator, self).__init__()
        nc = IN_CHANNELS if nc is None else nc
        ndf = DISC_FEATURE_SIZE if ndf is None else ndf
        self.main = nn.Sequential(
            # each (kernel 4, stride 2, padding 1) conv halves H and W
            # (N, nc, 256, 256) -> (N, ndf, 128, 128); no BatchNorm on the first layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*2, 64, 64)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*4, 32, 32)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*8, 16, 16)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # a 16x16 kernel over the 16x16 map -> (N, 1, 1, 1)
            nn.Conv2d(ndf * 8, 1, 16, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images; returns a tensor of shape (N, 1, 1, 1)."""
        return self.main(input)

In [None]:
# Build the discriminator on the target device, then re-initialise its conv/BatchNorm weights.
discriminator = Discriminator().to(device).apply(weights_init)
print(discriminator)

In [None]:
# Binary cross-entropy against these target values: 1.0 for real images, 0.0 for generated ones.
real_label, fake_label = 1., 0.
criterion = nn.BCELoss()

# One Adam optimiser per network, sharing the learning rate and (beta1, beta2).
adam_betas = (BETA1, 0.999)
disc_optim = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
gen_optim = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Train the DCGAN in NUM_ROUNDS rounds of NUM_EPOCHS epochs each. After every
# CHECKPOINT_EVERY-th round (and always after the last round), a grid of the most
# recent fake batch is appended to `image_list` and both models are saved to ./results.
NUM_ROUNDS = 100
CHECKPOINT_EVERY = 10


def _train_step(real_image):
    """Run one DCGAN update on a batch of real images that is already on `device`.

    The discriminator is updated first: once on the real batch (target `real_label`) and
    once on a freshly generated fake batch (target `fake_label`). The generator is then
    updated so that the just-updated discriminator labels that same fake batch as real.

    Returns:
        (disc_loss, gen_loss, fake_images): the summed discriminator loss and the
        generator loss (0-dim tensors), plus the generated batch.
    """
    b_size = real_image.size(0)
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

    # --- Discriminator on the real batch ---
    discriminator.zero_grad()
    output = discriminator(real_image).view(-1)
    disc_real_loss = criterion(output, label)
    disc_real_loss.backward()

    # --- Discriminator on a fake batch ---
    in_vec = torch.randn(b_size, ENCODING_SIZE, 1, 1, device=device)
    fake_images = generator(in_vec)
    label.fill_(fake_label)
    # detach(): this loss must only produce gradients for the discriminator
    output = discriminator(fake_images.detach()).view(-1)
    disc_fake_loss = criterion(output, label)
    disc_fake_loss.backward()  # accumulates onto the real-batch gradients

    disc_loss = disc_real_loss + disc_fake_loss
    disc_optim.step()

    # --- Generator: wants its fakes classified as real ---
    generator.zero_grad()
    label.fill_(real_label)
    output = discriminator(fake_images).view(-1)
    gen_loss = criterion(output, label)
    # This also writes gradients into the discriminator; they are cleared by its
    # zero_grad() at the start of the next step.
    gen_loss.backward()
    gen_optim.step()

    return disc_loss, gen_loss, fake_images


image_list = []
fake_images = None  # stays None only if train_loader yields no batches

for j in range(NUM_ROUNDS):
    print("------------------------- Starting Training Loop -------------------")

    # train_loader (shuffle=True) reshuffles on every pass, so it is reused rather than rebuilt.
    for epoch in range(1, NUM_EPOCHS + 1):
        for batch_idx, (data, _) in enumerate(train_loader):
            disc_loss, gen_loss, fake_images = _train_step(data.to(device))

            if batch_idx % 10 == 0:
                print(f"epoch : {j} : {epoch/NUM_EPOCHS}\tDisc_loss : {disc_loss.item()}\tGen_loss : {gen_loss.item()}")

    # Checkpoint on every CHECKPOINT_EVERY-th round and on the final round.
    if (j + 1) % CHECKPOINT_EVERY == 0 or j == NUM_ROUNDS - 1:
        if fake_images is not None:
            image_list.append(vutils.make_grid(fake_images.detach().cpu(), padding=2, normalize=True))
        torch.save(generator.state_dict(), f"./results/generator_{j}.pt")
        torch.save(discriminator.state_dict(), f"./results/discriminator_{j}.pt")

In [None]:
# Show the most recent checkpoint grid of generated images.
fig, ax = plt.subplots(figsize=(8, 8))
ax.axis("off")
ax.imshow(np.transpose(image_list[-1], (1, 2, 0)))

In [None]:
# Sample the trained generator and write every image to ./images as its own JPEG file.
NUM_SAMPLE_BATCHES = 100  # 100 batches x 80 images = 8,000 images in total
SAMPLE_BATCH_SIZE = 80

# Inference only: no autograd graph is needed. The generator is still in train mode,
# so BatchNorm normalizes with per-batch statistics, as it did during training.
with torch.no_grad():
    for i in range(NUM_SAMPLE_BATCHES):
        in_vec = torch.randn(SAMPLE_BATCH_SIZE, ENCODING_SIZE, 1, 1, device=device)
        generated_images = generator(in_vec)
        for j, image in enumerate(generated_images):
            # (C, H, W) in [-1, 1] (Tanh output) -> (H, W, C) uint8 in [0, 255]
            im = image.permute(1, 2, 0).cpu().numpy()
            im = np.clip(im * 127.5 + 127.5, 0, 255).astype(np.uint8)
            # PIL infers the format from the extension. The per-image index j keeps
            # images from the same batch from overwriting one another.
            Image.fromarray(im).save(f"./images/dcgan_{i}_{j}.jpg")

In [None]:
# Bundle ./images into ./images.zip. Paths are relative to the notebook's working
# directory (the same one the sampling cell writes into) rather than a hard-coded home path.
import shutil
shutil.make_archive("images", "zip", root_dir="images")

In [None]:
# Remove the copied training images; the originals under ./gan-getting-started/monet_jpg are untouched.
!rm -r train_data/