In [None]:
import torch
import torch.nn as nn
from torchvision import transforms,datasets
from torchvision.utils import make_grid as grid_creation
from torch.utils.data import DataLoader

from matplotlib import pyplot as plt
import numpy as np

import os

In [None]:
# --- Data / model geometry -------------------------------------------------
IMAGE_DIM = 64    # training images are resized to IMAGE_DIM x IMAGE_DIM pixels
NOISE_DIM = 100   # number of channels of the (N, NOISE_DIM, 1, 1) latent tensor
BATCH_SIZE = 64   # samples per DataLoader batch and per noise batch

# --- Optimisation ----------------------------------------------------------
NUM_EPOCHS = 400
ALPHA = 5e-5          # learning rate passed to both RMSprop optimizers
CLAMP_VALUE = 0.01    # critic weights are clamped to [-CLAMP_VALUE, CLAMP_VALUE]

# --- Schedules -------------------------------------------------------------
# NOTE: despite the name, this one is compared against the *batch* index in the
# training loop: the generator is updated once every N critic batches.
NUM_EPOCH_PER_GENERATOR_UPDATE = 10
# A grid of sample images is rendered every N epochs.
NUM_EPOCH_PER_GENERATOR_TEST = 40

In [None]:
# Base directory (kept as a string prefix, hence the trailing slash) used for the
# CIFAR-10 download, the generator checkpoint, the loss plot and the sample images.
PATH = "./"

In [None]:
# Preprocessing: upscale to IMAGE_DIM x IMAGE_DIM, convert to a tensor and map
# every channel from [0, 1] to [-1, 1] so real images share the Tanh output range
# of the generator. CenterCrop is a no-op here because Resize already yields
# exactly IMAGE_DIM x IMAGE_DIM.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_DIM, IMAGE_DIM)),
        transforms.CenterCrop(IMAGE_DIM),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split (downloaded into PATH on first use), served in
# shuffled mini-batches. The held-out split (train=False) is not used here.
trainset = datasets.CIFAR10(root=PATH, train=True, download=True, transform=transform)
trainloader = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> 3-channel image in [-1, 1].

    The network is a stack of transposed convolutions. For an input of shape
    (N, NOISE_DIM, 1, 1) the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64,
    so the output has shape (N, 3, 64, 64). The final Tanh bounds every value
    to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Channel counts of the hidden stages; each consecutive pair is one
        # ConvTranspose2d -> BatchNorm2d -> ReLU group.
        stage_channels = (NOISE_DIM, 512, 256, 128, 64)

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(stage_channels[:-1], stage_channels[1:])):
            # The first stage expands the 1x1 latent to 4x4 (stride 1, no padding);
            # every later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride,
                                   padding=padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # Output head: last doubling step down to 3 channels (with bias), then Tanh.
        layers.append(nn.ConvTranspose2d(stage_channels[-1], 3, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())

        # Same module order as a hand-written Sequential, so state_dict keys
        # ("model.0.weight", "model.1.weight", ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the latent batch through the transposed-convolution stack."""
        return self.model(input)

In [None]:
class Critic_Discriminator(nn.Module):
    """WGAN critic: maps each 3-channel 64x64 image to one unbounded score.

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
    and a final 4x4 convolution collapses the 4x4 map to a single value. There
    is no sigmoid: the output is a raw score, not a probability.

    The architecture is tied to 64x64 inputs by that final 4x4 convolution.
    """

    def __init__(self):
        super().__init__()

        # Number of feature maps of the first layer. This is a channel count,
        # not a spatial size, so it is deliberately a literal and not IMAGE_DIM.
        base_width = 64

        self.model = nn.Sequential(
            # No BatchNorm on the first layer.
            nn.Conv2d(3, base_width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(base_width, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            # 4x4 feature map -> 1x1 score.
            nn.Conv2d(512, 1, 4, 1, 0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Score a batch of images.

        Args:
            input: image batch of shape (N, 3, 64, 64).

        Returns:
            1-D tensor of N critic scores (the (N, 1, 1, 1) map flattened).
        """
        out = torch.flatten(self.model(input))
        return out

In [None]:
# Location of the generator checkpoint that the training cell writes when it finishes.
generator_ckpt_path = PATH+'generator_cifar10_model.pth'

generator = Generator()
if os.path.exists(generator_ckpt_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved from a CUDA run can still be loaded on a CPU-only machine.
    generator.load_state_dict(torch.load(generator_ckpt_path, map_location=device))
    print("loaded last model!")

# Only the generator is checkpointed; the critic always starts from fresh weights.
critic_discriminator = Critic_Discriminator()

generator.to(device)
critic_discriminator.to(device)

In [None]:
# One RMSprop optimizer per network, both with learning rate ALPHA:
# g_optim drives the generator, c_optim drives the critic.
g_optim, c_optim = (
    torch.optim.RMSprop(params=net.parameters(), lr=ALPHA)
    for net in (generator, critic_discriminator)
)

In [None]:
# Fixed latent batch, reused at every checkpoint so that the rendered sample
# grids are directly comparable across epochs.
test_noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)

generator.train()
critic_discriminator.train()

g_loss_lst = []           # per-epoch sum of generator losses (Python floats)
c_loss_lst = []           # per-epoch sum of negated critic losses (Python floats)
output_noise_images = []  # one image grid per checkpoint, values in [-1, 1]

for epoch_num in range(1, NUM_EPOCHS + 1):
    g_epoch_error = 0.0
    c_epoch_error = 0.0
    for batch_idx, data in enumerate(trainloader):
        batch_imgs, _ = data  # labels are not used
        real_data = batch_imgs.to(device)

        # ---- critic step ------------------------------------------------
        train_noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        # detach(): the critic update must not backpropagate into the
        # generator; this avoids a useless backward pass through it.
        fake_data = generator(train_noise).detach()

        c_optim.zero_grad()
        # Minimising mean(score(fake)) - mean(score(real)) widens the gap
        # between the scores of real and generated images.
        total_critic_error = critic_discriminator(fake_data).mean() - critic_discriminator(real_data).mean()
        total_critic_error.backward()
        c_optim.step()
        # Weight clipping: keep every critic parameter in
        # [-CLAMP_VALUE, CLAMP_VALUE] (the Lipschitz constraint of WGAN).
        for p in critic_discriminator.parameters():
            p.data.clamp_(-CLAMP_VALUE, CLAMP_VALUE)
        # .item() stores a plain float, so no autograd graph is kept alive.
        # The sign is flipped so that the logged value is positive.
        c_epoch_error += -total_critic_error.item()

        # ---- generator step: once every N critic batches ----------------
        if (batch_idx + 1) % NUM_EPOCH_PER_GENERATOR_UPDATE == 0:
            fake_data = generator(torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device))
            g_optim.zero_grad()
            # The generator tries to raise the critic score of its samples.
            error_generator = -critic_discriminator(fake_data).mean()
            error_generator.backward()
            g_optim.step()
            g_epoch_error += error_generator.item()

    print('Epoch {}: Critic_loss: {:.3f} Generator_loss: {:.3f}'.format(epoch_num, c_epoch_error, g_epoch_error))
    # Both accumulators are floats, so this also works for an epoch in which
    # no generator step happened (fewer batches than the update interval).
    g_loss_lst.append(g_epoch_error)
    c_loss_lst.append(c_epoch_error)

    if epoch_num % NUM_EPOCH_PER_GENERATOR_TEST == 0:
        # Sampling only: no gradients are needed for the checkpoint images.
        with torch.no_grad():
            fake_img = generator(test_noise).cpu()
        output_noise_images.append(grid_creation(fake_img))


torch.save(generator.state_dict(), PATH+'generator_cifar10_model.pth')
print('Saved generator_cifar10_model')

In [None]:
# Loss curves: one point per epoch, each point being the sum of the per-step
# losses recorded during that epoch (critic values are sign-flipped).
# An explicit figure/axes pair is used so the plot does not depend on whatever
# pyplot considers the "current" figure.
fig, ax = plt.subplots()
ax.plot(c_loss_lst, label='Discriminator Losses')
ax.plot(g_loss_lst, label='Generator Losses')
ax.set_xlabel('Epoch (0-based index)')
ax.set_ylabel('Loss summed over the epoch')
ax.set_title('WGAN training losses')
ax.legend()
fig.savefig(PATH+'total_loss.png')

In [None]:
from PIL import Image
# Show every checkpoint grid and save it as PATH + 'image_<idx>.jpg'.
for idx, img in enumerate(output_noise_images):
    # The grids are built from Tanh outputs, i.e. values in [-1, 1]. Map them
    # back to [0, 1] first; scaling [-1, 1] data by 255 and casting to uint8
    # would wrap the negative half and corrupt the colours.
    img01 = img.mul(0.5).add(0.5).clamp(0.0, 1.0)

    # (C, H, W) -> (H, W, C) for matplotlib; (1, 2, 0) keeps rows and columns
    # in place, whereas (2, 1, 0) would transpose the picture.
    hwc_uint8 = (img01.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)

    # One figure per grid so that every checkpoint is displayed, not only the last.
    plt.figure()
    plt.imshow(hwc_uint8)
    plt.axis('off')
    plt.show()

    pil_img = transforms.ToPILImage()(img01)
    pil_img.save(PATH+'image_'+str(idx)+'.jpg')