<a href="https://colab.research.google.com/github/eliavmor/Pytorch-GANs/blob/master/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from google.colab import drive
import os
import os.path as path
drive.mount("/content/drive")
drive_path = "/content/drive/My Drive/Colab Notebooks"
checkpoint_path = os.path.join(drive_path, "checkpoint")
output_path = os.path.join(drive_path, "output_images")
# makedirs(exist_ok=True) also creates missing parent folders (e.g. a Drive
# without "Colab Notebooks" yet) and replaces the isdir-then-mkdir check.
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

In [0]:
class Generator(nn.Module):
    """Map a batch of 100-dim noise vectors to 1x28x28 images in [-1, 1].

    Shape trace: Linear -> (128, 2, 2) -> ConvT k3 -> 4 -> Upsample -> 8
    -> ConvT k3 -> 10 -> ConvT k3 -> 12 -> Upsample -> 24 -> ConvT k5 -> 28,
    followed by Tanh.

    The attribute names and the layer order inside each ``nn.Sequential`` form
    the ``state_dict`` keys, so they must stay stable for saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        def upconv_bn_act(in_channels, out_channels, kernel_size):
            # ConvTranspose -> BatchNorm -> LeakyReLU(0.2)
            # NOTE(review): BatchNorm2d's 2nd positional argument is eps, so the
            # original ``BatchNorm2d(c, 0.8)`` means eps=0.8 (momentum may have
            # been intended). Kept because existing checkpoints were trained so.
            return [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size),
                nn.BatchNorm2d(out_channels, eps=0.8),
                nn.LeakyReLU(negative_slope=0.2),
            ]

        self.linear_process = nn.Sequential(nn.Linear(100, 128 * 2 * 2))
        self.post_linear_process = nn.Sequential(
            *upconv_bn_act(128, 128, 3),
            nn.Upsample(scale_factor=2),
            *upconv_bn_act(128, 64, 3),
            *upconv_bn_act(64, 32, 3),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(32, 1, 5),
            nn.Tanh(),
        )

    def forward(self, noise):
        # Project the noise, then reshape it to a 128-channel 2x2 feature map.
        seed_maps = self.linear_process(noise).view(-1, 128, 2, 2)
        return self.post_linear_process(seed_maps)

In [0]:
class Discriminator(nn.Module):
    """Score 1-channel images as real (close to 1) or fake (close to 0).

    The conv stack reduces a 28x28 input to 4x5x5 = 100 features
    (28 -> 24 -> 20 -> 10 -> 5), which a Linear + Sigmoid turns into one
    probability per image.

    The attribute names and the layer order inside each ``nn.Sequential`` form
    the ``state_dict`` keys, so they must stay stable for saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_act_drop(in_channels, out_channels, kernel_size, **conv_kwargs):
            # Conv -> BatchNorm -> LeakyReLU(0.2) -> Dropout(0.25)
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
                nn.Dropout(p=0.25),
            ]

        self.step1 = nn.Sequential(
            *conv_bn_act_drop(1, 6, 5),
            *conv_bn_act_drop(6, 8, 5),
            *conv_bn_act_drop(8, 16, 3, stride=2, padding=1),
            *conv_bn_act_drop(16, 4, 3, stride=2, padding=1),
        )
        self.step2 = nn.Sequential(nn.Linear(4 * 5 * 5, 1), nn.Sigmoid())

    def forward(self, images):
        features = self.step1(images)
        flat = features.view(-1, 4 * 5 * 5)
        return self.step2(flat)
    

In [0]:
# init weights for conv layers and BatchNorm2D
def init_weights(layer):
    """Apply DCGAN-style initialisation to one module (use via ``model.apply``).

    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight ~ N(0, 0.02).
    - ``nn.BatchNorm2d``: weight ~ N(1, 0.02), bias = 0.
    - Every other module is left untouched.

    Both conv types are matched: the Discriminator uses ``nn.Conv2d`` only, so
    matching just ``ConvTranspose2d`` left its convolutions at default init.
    """
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(layer.weight, 0.0, 0.02)
    elif isinstance(layer, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if layer.weight is not None:
            nn.init.normal_(layer.weight, 1.0, 0.02)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)


In [0]:
def save_images(fake_images, output_path, grid_size=4):
    """Save up to ``grid_size**2`` images of a batch as a grayscale grid.

    Args:
        fake_images: batch indexed as ``[image, channel, row, col]`` (a torch
            tensor, since ``.detach()`` is used); only channel 0 is drawn.
        output_path: destination passed to ``savefig``.
        grid_size: rows and columns of the grid (default 4, i.e. 16 images).
            Batches smaller than ``grid_size**2`` simply leave empty cells.
    """
    n_images = min(grid_size * grid_size, len(fake_images))
    fig = plt.figure(figsize=(15, 15))
    try:
        for i in range(n_images):
            ax = fig.add_subplot(grid_size, grid_size, i + 1)
            ax.imshow(fake_images[i, 0, :, :].detach().cpu().numpy(), cmap="gray")
        fig.savefig(output_path)
    finally:
        # Close even if savefig fails; otherwise repeated calls leak figures.
        plt.close(fig)

In [0]:
from torch.optim import Adam

batch_size = 64
epochs = 100
# Epoch number of the generator checkpoint being resumed. Used to pick the file
# to load and to continue the numbering of everything saved below.
start_epoch = 180

criterion = nn.BCELoss()
D = Discriminator()
G = Generator()

# DCGAN-style init for D. G's weights are replaced by the checkpoint below.
D.apply(init_weights)

run_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if run_gpu else "cpu")

# map_location lets a checkpoint written on a GPU runtime load on a CPU one.
resume_file = os.path.join(checkpoint_path, "mnist_generator_{}.pth".format(start_epoch))
G.load_state_dict(torch.load(resume_file, map_location=device))
G = G.to(device)
D = D.to(device)

d_optimizer = Adam(params=D.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = Adam(params=G.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Fixed noise for progress snapshots, drawn from the same standard normal as the
# training noise (torch.rand would give U[0, 1), which G never sees in training).
# 16 vectors = the 4x4 grid drawn by save_images.
test_fake = torch.randn(16, 100, device=device)

data = torchvision.datasets.MNIST(root="./data", transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ), download=True)

datal = DataLoader(data, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    # Per-batch losses, averaged for the periodic report below.
    D_real_err = []
    D_fake_err = []
    G_err = []
    for images, _ in datal:
        images = images.to(device)
        # The last batch can be smaller than batch_size, so size labels and
        # noise from the actual data.
        cur_batch_size = images.size(0)
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # =======   Train Discriminator    ==========
        d_optimizer.zero_grad()
        real_err = criterion(D(images), real_labels)

        noise = torch.randn(cur_batch_size, 100, device=device)
        fake_images = G(noise)
        # detach() so the discriminator loss does not propagate gradients into G.
        fake_err = criterion(D(fake_images.detach()), fake_labels)

        err_d = (fake_err + real_err) / 2
        err_d.backward()
        d_optimizer.step()
        # =======   End Train Discriminator    ==========

        # =======   Train Generator   ==========
        g_optimizer.zero_grad()
        # Reuse fake_images. Only its detached copy was backpropagated, so G's
        # graph is intact and G's parameters have not changed since it was built.
        g_fake_err = criterion(D(fake_images), real_labels)
        g_fake_err.backward()
        g_optimizer.step()
        # =======   End Train Generator    ==========

        D_real_err.append(real_err.item())
        D_fake_err.append(fake_err.item())
        G_err.append(g_fake_err.item())

    if epoch % 10 == 0:
        # This runs after epoch `epoch` has finished, hence the +1. It also keeps
        # epoch 0 from overwriting the checkpoint loaded at start_epoch.
        tag = start_epoch + epoch + 1
        print("Epoch [{}/{}]".format(epoch, epochs))
        print("D real Error {:.4f}".format(np.average(D_real_err)))
        print("D fake Error {:.4f}".format(np.average(D_fake_err)))
        print("G Error {:.4f}".format(np.average(G_err)))
        with torch.no_grad():
            save_images(G(test_fake), os.path.join(output_path, "mnist_generator_{}.jpg".format(tag)))
        # Use the same epoch tag for D and G so their checkpoints pair up.
        torch.save(D.state_dict(), os.path.join(checkpoint_path, "mnist_discriminator_{}.pth".format(tag)))
        torch.save(G.state_dict(), os.path.join(checkpoint_path, "mnist_generator_{}.pth".format(tag)))

    

In [0]:
import imageio
def create_animation(images_path):
    """Assemble the ``.jpg`` snapshots in ``images_path`` into ``gan_result.gif``.

    ``os.listdir`` returns entries in arbitrary order, so frames are sorted
    naturally by the numbers in their names (``..._190.jpg`` before
    ``..._1000.jpg``) to play back in epoch order. Each frame is shown for 0.2.

    Returns:
        Path of the written GIF (inside ``images_path``).

    Raises:
        ValueError: if ``images_path`` contains no ``.jpg`` files.
    """
    import re

    def natural_key(name):
        # re.split with a capture group puts digit runs at odd indices.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", name))]

    frame_names = sorted(
        (f for f in os.listdir(images_path) if f.endswith('.jpg')),
        key=natural_key,
    )
    if not frame_names:
        raise ValueError("no .jpg frames found in {}".format(images_path))

    gif_path = os.path.join(images_path, "gan_result.gif")
    with imageio.get_writer(gif_path, mode='I', duration=0.2) as writer:
        for name in frame_names:
            writer.append_data(imageio.imread(os.path.join(images_path, name)))
    return gif_path