<a href="https://colab.research.google.com/github/kaiju8/GANs-Implemented/blob/main/DCGANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader

from torchsummary import summary

import numpy as np

import matplotlib.pyplot as plt

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 4x4 convolutions (each halving the spatial size) are
    followed by a final 4x4 convolution without padding and a sigmoid.
    For 64x64 inputs this yields one value in (0, 1) per sample, shaped
    ``(N, 1, 1, 1)``.

    Args:
        channels: Number of channels of the input images.
    """

    def __init__(self, channels):
        super().__init__()

        # First stage: no batch norm directly on the input image.
        stack = [
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three identical down-sampling stages: conv -> batch norm -> LeakyReLU.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stack.extend(
                [
                    nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )

        # Collapse the remaining 4x4 feature map to a single probability.
        stack.extend(
            [
                nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
                nn.Sigmoid(),
            ]
        )

        self.disc = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return its output."""
        return self.disc(x)

In [None]:
class Generator(nn.Module):
    """DCGAN-style transposed-convolution generator.

    A ``(N, noise_dim, 1, 1)`` latent tensor is first expanded to a 4x4
    feature map and then doubled in size four times (4 -> 8 -> 16 -> 32 ->
    64). The final ``Tanh`` bounds the output to [-1, 1].

    Args:
        noise_dim: Number of channels of the latent input.
        channels: Number of channels of the generated images.
    """

    def __init__(self, noise_dim, channels):
        super().__init__()

        # Project the 1x1 latent to a 4x4 map with 1024 features.
        stack = [
            nn.ConvTranspose2d(noise_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
        ]

        # Three identical up-sampling stages: deconv -> batch norm -> ReLU.
        for in_ch, out_ch in ((1024, 512), (512, 256), (256, 128)):
            stack.extend(
                [
                    nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(True),
                ]
            )

        # Last up-sampling step maps to image channels; no batch norm here.
        stack.extend(
            [
                nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Tanh(),
            ]
        )

        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the transposed-convolution stack."""
        return self.gen(x)

In [None]:
def initialize_weights(model):
    """Apply DCGAN weight initialisation to ``model`` in place.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift is set
      to 0. Centring the scale at 1 (rather than 0) keeps the normalised
      activations from being multiplied by ~0 at the start of training.

    Batch-norm layers built with ``affine=False`` have no learnable
    parameters and are skipped.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Training hyperparameters shared by the cells below.
lr = 2e-4          # learning rate passed to both Adam optimisers
z_dim = 100        # number of channels of the latent noise fed to the generator
img_dim = 64       # images are resized to this size before training
channels_img = 1   # number of image channels
batch_size = 128   # mini-batch size used by the DataLoader
num_epochs = 10    # number of passes over the dataset

In [None]:
# Build both networks and move them to the selected device.
disc = Discriminator(channels=channels_img).to(device)
gen = Generator(noise_dim=z_dim, channels=channels_img).to(device)

In [None]:
# Layer-by-layer overview of both networks for the input shapes used here.
print(
    summary(
        disc,
        input_size=(channels_img, img_dim, img_dim),
        batch_size=-42,
    )
)
print(
    summary(
        gen,
        input_size=(z_dim, 1, 1),
        batch_size=-42,
    )
)

In [None]:
# Pre-processing pipeline: resize to img_dim, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
# NOTE(review): this assignment rebinds the name `transforms`, which until now
# referred to the imported `torchvision.transforms` module. After this cell
# runs the module is no longer reachable under that name, so re-running the
# cell fails. The name is kept because the dataset cell reads `transforms`.
transforms = transforms.Compose([
    transforms.Resize(img_dim),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * channels_img, [0.5] * channels_img),
])

In [None]:
# MNIST is downloaded into ./dataset if needed; every image goes through
# the pre-processing pipeline defined above.
dataset = datasets.MNIST(
    root="dataset",
    download=True,
    transform=transforms,
)
# Reshuffled mini-batches of `batch_size` samples.
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Re-initialise both networks before any optimiser touches them.
initialize_weights(disc)
initialize_weights(gen)

# One Adam optimiser per network, with the same learning rate and
# betas=(0.5, 0.999) for both.
opt_disc = optim.Adam(
    disc.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    gen.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output, averaged over
# the batch (the default reduction, spelled out here).
criterion = nn.BCELoss(reduction="mean")

In [None]:
# 32 latent vectors drawn once and reused, so the preview grids produced at
# different points of training come from the same inputs.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)

In [None]:
def generate_img(generator, fixed_noise, channels, img_dim):
    """Render ``fixed_noise`` through ``generator`` as a single image grid.

    The forward pass runs under ``torch.no_grad()``: the result is only used
    for visualisation, so no autograd graph is built even when the caller is
    not itself inside a ``no_grad`` block.

    Args:
        generator: Module mapping the noise batch to images.
        fixed_noise: Latent batch fed to ``generator``.
        channels: Channel count the generator output is reshaped to.
        img_dim: Height and width the generator output is reshaped to.

    Returns:
        The tensor produced by ``torchvision.utils.make_grid`` with
        ``normalize=True``.
    """
    with torch.no_grad():
        fake = generator(fixed_noise).reshape(-1, channels, img_dim, img_dim)
    return torchvision.utils.make_grid(fake, normalize=True)

In [None]:
def generate_interpolation(generator, z_dim, channels, img_dim):
    """Render a 16-step linear walk between two random latent points.

    Two latent vectors are sampled, 16 evenly spaced points on the segment
    between them (both end points included) are generated with
    ``torch.lerp``, and the resulting images are assembled into one grid.

    The images are produced by the ``generator`` argument and reshaped with
    the ``channels`` argument, not by module-level globals. The latents are
    placed on the module-level ``device``.

    Args:
        generator: Module mapping a ``(N, z_dim, 1, 1)`` batch to images.
        z_dim: Number of channels of the latent vectors.
        channels: Channel count the generator output is reshaped to.
        img_dim: Height and width the generator output is reshaped to.

    Returns:
        The tensor produced by ``torchvision.utils.make_grid`` with
        ``normalize=True``.
    """
    steps = 16

    point_1 = torch.randn((1, z_dim, 1, 1)).to(device)
    point_2 = torch.randn((1, z_dim, 1, 1)).to(device)

    # Weight 0 reproduces point_1 and weight 1 reproduces point_2; build the
    # whole batch with a single concatenation instead of growing it per step.
    interpolated = torch.cat(
        [torch.lerp(point_1, point_2, i / (steps - 1)) for i in range(steps)],
        dim=0,
    )

    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        imgs = generator(interpolated).reshape(-1, channels, img_dim, img_dim)
    return torchvision.utils.make_grid(imgs, normalize=True)

In [None]:
def show_grid(imgs):
    """Display one image tensor, or a list of them, side by side.

    Per the original author's note, this helper is taken from the
    documentation. A new matplotlib figure with one column per image is
    created on every call; axis ticks and tick labels are hidden.

    Args:
        imgs: A single image tensor or a list of image tensors accepted by
            ``to_pil_image``.
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for col, tensor in enumerate(images):
        axis = axes[0, col]
        # Detach first so tensors that still track gradients can be converted.
        pil_image = F.to_pil_image(tensor.detach())
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
def train_disc(real_img, fake_img, optim):
    """Run one discriminator update and return its loss.

    The discriminator objective is to maximise
    ``log(D(real)) + log(1 - D(G(z)))``; this is implemented as the mean of
    the criterion on real images (target 1) and on fake images (target 0).

    Uses the module-level ``disc`` network and ``criterion``. Both
    ``zero_grad()`` and ``step()`` are called on the optimiser passed in, so
    the function always updates the optimiser it was given.

    Args:
        real_img: Batch of real images.
        fake_img: Batch of generated images. It is detached here, so no
            gradient reaches the generator even if the caller did not detach.
        optim: Optimiser over the discriminator parameters. Inside this
            function the name shadows the ``torch.optim`` module.

    Returns:
        The scalar loss tensor of this step.
    """
    optim.zero_grad()

    # Real images should be classified as 1.
    disc_real = disc(real_img).reshape(-1)
    loss_real = criterion(disc_real, torch.ones_like(disc_real))

    # Generated images should be classified as 0.
    disc_fake = disc(fake_img.detach()).reshape(-1)
    loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    loss_D = (loss_real + loss_fake) / 2

    loss_D.backward()
    optim.step()

    return loss_D

def train_gen(fake_img, optim):
    """Run one generator update and return its loss.

    Instead of minimising ``log(1 - D(G(z)))`` the generator maximises
    ``log(D(G(z)))``: the discriminator output on the fake batch is scored
    against a target of 1.

    Uses the module-level ``disc`` network and ``criterion``. Both
    ``zero_grad()`` and ``step()`` are called on the optimiser passed in, so
    the function always updates the optimiser it was given.

    Args:
        fake_img: Batch of generated images, still attached to the
            generator's graph so gradients can flow back into it.
        optim: Optimiser over the generator parameters. Inside this function
            the name shadows the ``torch.optim`` module.

    Returns:
        The scalar loss tensor of this step.
    """
    optim.zero_grad()

    output = disc(fake_img).reshape(-1)
    loss_G = criterion(output, torch.ones_like(output))

    loss_G.backward()
    optim.step()

    return loss_G


In [None]:
# Per-step losses, stored as plain floats so they can be plotted directly.
losses_g = []
losses_d = []

# Progress is printed and a preview grid drawn every `log_every` batches.
# Doing both on every batch floods the output and opens one matplotlib
# figure per step.
log_every = 100

for epoch in range(num_epochs):

    for batch_idx, (real, _) in enumerate(loader):

        real = real.to(device)
        # The last batch may be smaller; use a local name so the global
        # `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        # Discriminator step: the fake batch is detached so this step does
        # not back-propagate into the generator.
        noise = torch.randn((cur_batch_size, z_dim, 1, 1)).to(device)
        fake = gen(noise).detach()

        loss_d = train_disc(real, fake, opt_disc)
        losses_d.append(loss_d.item())

        # Generator step: fresh noise, graph kept so gradients reach `gen`.
        noise = torch.randn((cur_batch_size, z_dim, 1, 1)).to(device)
        fake = gen(noise)

        loss_g = train_gen(fake, opt_gen)
        losses_g.append(loss_g.item())

        if batch_idx % log_every == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} Discriminator loss: {loss_d:.4f}, Generator loss: {loss_g:.4f}")

            with torch.no_grad():
                show_grid(generate_img(gen, fixed_noise, channels_img, img_dim))

# Final preview after training; no gradients are needed for it.
with torch.no_grad():
    show_grid(generate_img(gen, fixed_noise, channels_img, img_dim))

In [None]:
# Render the fixed-noise samples, display them and write the grid to disk.
sample_img = generate_img(
    generator=gen,
    fixed_noise=fixed_noise,
    channels=channels_img,
    img_dim=img_dim,
)
show_grid(sample_img)
save_image(sample_img, "result.png")

In [None]:
# Render a latent-space interpolation, display it and write the grid to disk.
interpolation_img = generate_interpolation(
    generator=gen,
    z_dim=z_dim,
    channels=channels_img,
    img_dim=img_dim,
)
show_grid(interpolation_img)
save_image(interpolation_img, "interpolation.png")

In [None]:
# Plot both per-step loss curves on one figure, save it, then display it.
plt.figure()
for curve, curve_label in ((losses_g, "Generator loss"), (losses_d, "Discriminator Loss")):
    plt.plot(curve, label=curve_label)
plt.legend()
plt.savefig("loss.png")
plt.show()