<a href="https://colab.research.google.com/github/barakmam/super-resolution/blob/main/cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# imports:
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import imageio
 
# pytorch:
import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torchvision.utils import make_grid

# Seed the NumPy and PyTorch random number generators so that runs are
# reproducible (calling through a loop also keeps the notebook from echoing
# the value returned by torch.manual_seed).
seed = 42
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)

In [None]:
# Use the CUDA device when one is available (reporting the name of GPU 0),
# otherwise fall back to the CPU.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Device: {}".format(torch.cuda.get_device_name(0)))
print("Device type: {}".format(device))

In [None]:
dataset_path = "./data/"
models_path = "./models/"
print_gif = True  # save a GIF of generator samples (one frame per epoch)

# model specific parameters:
lr = 2e-4
betas = (.0001, .9)  # Adam (beta1, beta2)
num_epochs = 50

# other parameters:
batch_size = 256
n_channels = 128              # feature maps in the hidden conv layers
Lambda = 10                   # weight of the gradient-penalty term
image_size = (28, 28)         # (height, width) of the images
latent_dim = 128
discriminator_iterations = 1  # critic steps per generator step

# which models to run: indices into the `models` list built further down.
# Only WGAN (index 0) is registered there; requesting index 1 (DCGAN) would
# fail because no such model is constructed.
model_inds_to_run = [0]

os.makedirs(models_path, exist_ok=True)

In [None]:
# The generator's last layer is Tanh, so generated images lie in [-1, 1].
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1]
# so the critic sees real and fake images on the same scale.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.FashionMNIST('./fashion_data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.FashionMNIST('./fashion_data', train=False, transform=transform)
# Train and test splits are merged: for GAN training both are just extra samples.
dataset = torch.utils.data.ConcatDataset([train_data, test_data])
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

classes = train_data.classes

# Labels are 0-based class indices, so the largest label + 1 is the class count.
num_classes = int(torch.max(train_data.targets) + 1)
assert len(classes) == num_classes

print("Number of classes: {}".format(num_classes))
print("Dataset size: {}".format(len(dataset)))

In [None]:
# Re-seed so the preview grid below shows the same samples on every run.
np.random.seed(seed)
torch.manual_seed(seed)

rows = 5
cols = 5

sample_dataloader = torch.utils.data.DataLoader(dataset, batch_size=rows*cols, shuffle=True)

# Draw one shuffled batch and show each image with its class name as title.
fig = plt.figure(figsize=(8, 8))
samples, labels = next(iter(sample_dataloader))
for panel, (img, lbl) in enumerate(zip(samples, labels), start=1):
    ax = fig.add_subplot(rows, cols, panel)
    pixels = img.data.cpu().numpy().squeeze()
    ax.imshow(pixels, cmap='gray')
    ax.set_title(classes[lbl])
    ax.set_axis_off()

In [None]:
class Generator(nn.Module):
    """Conditional fully-connected (MLP) generator.

    A latent noise vector is concatenated with a learned embedding of its
    class label and mapped through Linear/BatchNorm1d/LeakyReLU layers to an
    image of shape ``(channels, height, width)``. The final ``Tanh`` puts
    pixel values in ``[-1, 1]``.

    Args:
        image_size (int or tuple of int): Output image size. An int gives a
            square image; a ``(height, width)`` tuple gives a rectangular one.
        channels (int): Number of channels of the output image.
        num_classes (int): Number of classes in the dataset (rows of the
            label embedding).
        latent_dim (int): Length of the latent noise vector. Defaults to 100,
            the width that used to be hard-coded.
    """

    def __init__(self, image_size, channels, num_classes, latent_dim=100):
        super(Generator, self).__init__()
        self.image_size = image_size
        self.channels = channels
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        if isinstance(image_size, int):
            self._height, self._width = image_size, image_size
        else:
            self._height, self._width = image_size

        # Embedding of size num_classes, concatenated to the noise vector.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        self.main = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(1024, channels * self._height * self._width),
            nn.Tanh()
        )

        # Initializing all neural network weights.
        self._initialize_weights()

    def forward(self, inputs: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor:
        """Generate images conditioned on class labels.

        Args:
            inputs (tensor): Noise of shape ``(N, latent_dim)``.
            labels (tensor): Integer class indices of shape ``(N,)``. Required.
        Returns:
            A four-dimensional tensor of shape ``(N, channels, height, width)``.
        Raises:
            TypeError: if ``labels`` is not given (the generator is conditional).
        """
        if labels is None:
            raise TypeError("Generator is conditional: `labels` (integer class indices) is required")
        conditional_inputs = torch.cat([inputs, self.label_embedding(labels)], dim=-1)
        out = self.main(conditional_inputs)
        return out.reshape(out.size(0), self.channels, self._height, self._width)

    def sample(self, num_samples, labels=None):
        """Draw ``num_samples`` images from random noise.

        When ``labels`` is None, class labels are drawn uniformly at random.
        Tensors are created on the same device as the module's parameters.
        """
        device = next(self.parameters()).device
        z = torch.randn(num_samples, self.latent_dim, device=device)
        if labels is None:
            labels = torch.randint(0, self.num_classes, (num_samples,), device=device)
        return self.forward(z, labels)

    def _initialize_weights(self) -> None:
        # Kaiming-normal weights scaled by 0.1, zero biases.
        # NOTE(review): this model uses BatchNorm1d, so the BatchNorm2d branch
        # never matches here; BatchNorm layers keep PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                m.weight.data *= 0.1
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                m.weight.data *= 0.1
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.weight.data *= 0.1
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

# class Generator(nn.Module):
#     def __init__(self, im_size, latent_dim, n_channels, device='cpu'):
#         super(Generator, self).__init__()
#         self.im_size = im_size
#         self.latent_dim = latent_dim 
#         self.n_channels = n_channels
#         self.device = device

#         assert ((self.im_size[0] + 4) % 8 == 0) and ((self.im_size[1] + 4) % 8 == 0), "invalid input dimensions"

#         # build nn:
#         self.pre = nn.Sequential(
#             nn.Linear(self.latent_dim, self.n_channels * ((self.im_size[0] + 4) // 8) * ((self.im_size[1] + 4) // 8)),
#             nn.ReLU())
#         # input size: (n_channels, 4, 4)
#         self.conv1 = nn.Sequential(
#             nn.ConvTranspose2d(self.n_channels, self.n_channels, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(self.n_channels),
#             nn.ReLU())
#         # input size: (n_channels, 8, 8)
#         self.conv2 = nn.Sequential(
#             nn.ConvTranspose2d(self.n_channels, self.n_channels, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(self.n_channels),
#             nn.ReLU())
#         # input size: (n_channels, 16, 16)
#         self.out = nn.Sequential(
#             nn.ConvTranspose2d(self.n_channels, 1, kernel_size=2, stride=2, padding=2),
#             nn.Tanh())
#         # output size: (1, 28, 28)

#     def forward(self, input, condition):
#         output = self.pre(input)
#         output = output.view(-1, self.n_channels, ((self.im_size[0] + 4) // 8), ((self.im_size[1] + 4) // 8))
#         output = self.conv1(output)
#         output = self.conv2(output)
#         output = self.out(output)
#         return output

#     def sample(self, num_samples):
#         z = torch.randn((num_samples, self.latent_dim)).to(self.device)
#         gen_images = self.forward(z)
#         return gen_images

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that outputs one unbounded score per image.

    Args:
        im_size: ``(height, width)`` of the input images.
        latent_dim: Accepted for interface compatibility; stored but not used
            by this module.
        n_channels: Number of feature maps in every hidden conv layer.
        in_channels: Number of channels of the input images (default 1).
    """

    def __init__(self, im_size, latent_dim, n_channels, in_channels=1):
        super(Discriminator, self).__init__()
        self.im_size = im_size
        self.latent_dim = latent_dim
        self.n_channels = n_channels
        self.in_channels = in_channels

        # build nn (spatial sizes shown for an H x W input):
        self.conv = nn.Sequential(
            # (in_channels, H, W) -> (n_channels, H//2, W//2)
            nn.Conv2d(in_channels, n_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            # -> (n_channels, H//4, W//4)
            nn.Conv2d(n_channels, n_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            # kernel 3 / stride 1 / padding 1 keeps the spatial size
            nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            # -> (n_channels, H//8, W//8)
            nn.AvgPool2d(kernel_size=2))
        # flattened features -> single score
        self.linear = nn.Linear(n_channels * (im_size[0] // 8) * (im_size[1] // 8), 1)

    def forward(self, input):
        """Return critic scores of shape ``(N, 1)`` for a batch of images."""
        features = self.conv(input)
        flat = features.view(input.size(0), -1)
        return self.linear(flat)

In [None]:
class GAN(nn.Module):
    """Generator + critic pair trained with a WGAN-GP objective.

    The generator is built with the keyword arguments ``im_size``,
    ``latent_dim``, ``n_channels`` and ``device``. It must also provide a
    ``sample(num_samples)`` method returning a batch of images.
    NOTE(review): the currently active ``Generator`` class in this notebook
    has a different (conditional) constructor; the two must be reconciled.

    Args:
        im_size: ``(height, width)`` of the images.
        latent_dim: Size of the generator's latent vector.
        n_channels: Hidden feature maps for both networks.
        Lambda: Weight of the gradient-penalty term.
        lr: Adam learning rate of the critic (and of the generator when
            ``isWgan`` is True).
        betas: Adam ``(beta1, beta2)``.
        clip: Stored but not used by this class.
        device: Device both networks are moved to.
        isWgan: When False, the generator learning rate is ``lr * 50``.
            Defaults to True. Previously this was an undefined global.
    """

    def __init__(self, im_size, latent_dim, n_channels, Lambda, lr, betas, clip=None,
                 device=torch.device("cpu"), isWgan=True):
        super(GAN, self).__init__()
        self.isWgan = isWgan
        self.generator = Generator(im_size=im_size, latent_dim=latent_dim,
                                   n_channels=n_channels, device=device).to(device)
        self.discriminator = Discriminator(im_size=im_size, latent_dim=latent_dim,
                                           n_channels=n_channels).to(device)
        g_lr = lr if isWgan else lr * 50
        self.G_opt = optim.Adam(self.generator.parameters(), lr=g_lr, betas=betas)
        self.D_opt = optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        # Per-iteration loss history as Python floats (plottable, no device memory).
        self.losses = {'G': [], 'D': [], 'GP': []}
        self.latent_dim = latent_dim
        self.Lambda = Lambda
        self.clip = clip
        self.device = device

    def discriminator_iteration(self, data):
        """Run one critic update on a real batch; return the detached critic loss."""
        batch_size = data.size(0)
        real_data = data.to(self.device)
        # The critic step must not backpropagate into the generator, so the
        # fake batch is produced without building an autograd graph.
        with torch.no_grad():
            fake_data = self.generator.sample(batch_size)
        d_fake = self.discriminator(fake_data)
        d_real = self.discriminator(real_data)

        # Wasserstein critic loss plus gradient penalty:
        gradient_penalty = self.calculate_gradient_penalty(real_data, fake_data)
        d_loss = (d_fake - d_real + gradient_penalty).mean()

        self.losses['GP'].append(gradient_penalty.item())
        self.losses['D'].append(d_loss.item())

        # optimize:
        self.D_opt.zero_grad()
        d_loss.backward()
        self.D_opt.step()

        return d_loss.detach()

    def generator_iteration(self, batch_size):
        """Run one generator update on ``batch_size`` samples; return the detached loss."""
        fake_data = self.generator.sample(batch_size)
        # The generator tries to maximize the critic score of its samples.
        g_loss = -self.discriminator(fake_data).mean()
        self.losses['G'].append(g_loss.item())

        # optimize:
        self.G_opt.zero_grad()
        g_loss.backward()
        self.G_opt.step()
        return g_loss.detach()

    def calculate_gradient_penalty(self, real_data, fake_data):
        """Return ``Lambda * mean((||grad D(x_hat)||_2 - 1)^2)``.

        ``x_hat`` are random per-sample interpolations between real and fake
        data. The inputs are detached and the interpolate is explicitly marked
        as requiring grad, so this works whether or not ``fake_data`` carries
        a generator graph.
        """
        batch_size = real_data.size(0)
        # One mixing coefficient per sample, broadcast over the remaining dims.
        alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=self.device)

        interpolated_data = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
        interpolated_data.requires_grad_(True)
        interpolated_out = self.discriminator(interpolated_data)
        # create_graph=True so the penalty itself can be backpropagated.
        gradients = torch_grad(outputs=interpolated_out, inputs=interpolated_data,
                               grad_outputs=torch.ones_like(interpolated_out),
                               create_graph=True, retain_graph=True)[0]
        gradients_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        return ((gradients_norm - 1) ** 2).mean() * self.Lambda

In [None]:
# WGAN (WGAN-GP objective); currently the only registered model.
wgan = GAN(image_size, latent_dim, n_channels, Lambda, lr, betas, device=device)

models = [
    {
        "name": "WGAN",
        "model": wgan,
        "filename": models_path + "wgan.pt",
        "num_epochs": num_epochs,
    },
]

# Fail early with a readable message instead of a bare IndexError when
# `model_inds_to_run` refers to a model that is not registered above.
_unknown_inds = [ind for ind in model_inds_to_run if not -len(models) <= ind < len(models)]
if _unknown_inds:
    raise ValueError("model_inds_to_run contains unregistered model indices {} "
                     "(only {} model(s) defined)".format(_unknown_inds, len(models)))

models_to_run = [models[ind] for ind in model_inds_to_run]

In [None]:
for model_info in models_to_run:
    model = model_info["model"]
    num_epochs = model_info["num_epochs"]
    # A later cell puts models in eval mode; make sure training starts in train mode.
    model.train()

    if print_gif:
        # fix latents to see how image generation improves during training:
        fixed_latents = torch.randn((128, model.generator.latent_dim)).to(device)
        gif_images = []

    # train:
    print("----------------- Training model \"{}\" -----------------".format(model_info["name"]))
    for epoch in range(num_epochs):
        start_time = time.time()
        # NaN placeholders. If no generator step runs in an epoch
        # (discriminator_iterations > batches per epoch), the progress print
        # below would otherwise hit an undefined or stale loss.
        d_loss = g_loss = float("nan")

        # run discriminator and generator:
        for it, (data, _) in enumerate(data_loader):
            d_loss = model.discriminator_iteration(data)
            if (it + 1) % discriminator_iterations == 0:
                g_loss = model.generator_iteration(batch_size)

        # print progress:
        end_time = time.time()
        print('[{}] [Epoch {}/{}] -> G Loss: {:.3f}, D Loss: {:.3f}, Time: {:.3f}'.format(
            model_info["name"], epoch + 1, num_epochs, g_loss, d_loss, end_time - start_time))

        if print_gif:
            # Render the frame in eval mode without autograd, so the fixed
            # latents don't update BatchNorm running statistics or build a
            # graph. Then switch back to training mode.
            model.generator.eval()
            with torch.no_grad():
                img_grid = make_grid(model.generator(fixed_latents).cpu())
            model.generator.train()
            img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
            # Min-max rescale to 0..255; guard against a constant frame (zero range).
            value_range = max(float(img_grid.max() - img_grid.min()), 1e-8)
            img_grid = (255 * (img_grid - img_grid.min()) / value_range).astype(np.uint8)
            gif_images.append(img_grid)

    if print_gif:
        imageio.mimsave(models_path + model_info["name"] + "_results.gif", gif_images)

    # save model checkpoint:
    torch.save(model.state_dict(), model_info["filename"])
    print("-----------------------------------------------------------")
    print()

print("Finished training all models!")

In [None]:
# plot Loss of generator and discriminator:
# squeeze=False keeps ax_array 2-D even when a single model is plotted
# (otherwise plt.subplots returns a bare Axes that cannot be indexed).
fig, ax_array = plt.subplots(len(models_to_run), 1, figsize=(12, 10), squeeze=False)

for ind, model_info in enumerate(models_to_run):
    ax = ax_array[ind, 0]
    losses = model_info["model"].losses
    # float() accepts Python numbers and 0-dim tensors alike; CUDA tensors
    # cannot be converted to NumPy arrays by matplotlib directly.
    g_losses = [float(v) for v in losses["G"]]
    d_losses = [float(v) for v in losses["D"]]
    gp_losses = [float(v) for v in losses["GP"]]
    # One generator step happens every `discriminator_iterations` critic steps.
    ax.plot(discriminator_iterations * np.arange(1, len(g_losses) + 1), g_losses, label='Generator')
    ax.plot(np.arange(1, len(d_losses) + 1), d_losses, label='Discriminator')
    ax.plot(np.arange(1, len(gp_losses) + 1), gp_losses, label='Gradient Penalty')
    ax.set_ylabel('Loss')
    ax.set_xlabel('Iteration')
    ax.set_title("Loss - " + model_info["name"])
    ax.legend()
    ax.grid()
plt.tight_layout()
plt.show()

In [None]:
# Reload each trained model from its checkpoint and switch it to inference mode.
for model_info in models_to_run:
    # map_location lets a checkpoint saved on a GPU load on the current device
    # (e.g. a CPU-only machine).
    state_dict = torch.load(model_info["filename"], map_location=device)
    model_info["model"].load_state_dict(state_dict)
    model_info["model"].to(device)
    model_info["model"].eval()
print('GAN models loaded succesfully')

In [None]:
np.random.seed(234)
torch.manual_seed(234)

n_samples = 5

for model_info in models_to_run:
    # sample GAN directly:
    model = model_info["model"]
    model.eval()
    # Visualization needs no gradients. Reshape with the configured image
    # size instead of a hard-coded 28x28.
    with torch.no_grad():
        gan_samples = model.generator.sample(num_samples=n_samples)
    gan_samples = gan_samples.reshape(n_samples, *image_size).cpu().numpy()
    fig, ax_array = plt.subplots(1, n_samples, figsize=(12, 2))
    for i in range(gan_samples.shape[0]):
        ax_array[i].imshow(gan_samples[i], cmap='gray')
        ax_array[i].set_axis_off()
    plt.suptitle('Sample Examples From {}'.format(model_info["name"]))