In [1]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision import transforms
from torchvision.utils import save_image
import torchvision

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from skimage.metrics import peak_signal_noise_ratio
import torch.nn as nn
import torch.nn.functional as F
import torch

import sys
# Directory of the pretrained CIFAR-100 DCGAN: it is put on sys.path for `dcgan`
# and is also the base path the checkpoint files are loaded from further down.
cifar_dir = '../../models/cifar100_dcgan_grayscale/'
sys.path.insert(0,cifar_dir) # So that `dcgan` can be imported from cifar_dir
from dcgan import Discriminator, Generator

sys.path.insert(0,'..') # So that `tools_optim` can be imported from the parent directory
# NOTE(review): star import. plot_img, grad_loop, convergence_loop and plot_loss_curve are
# used below but not defined in this notebook - presumably they come from here; confirm.
from tools_optim import *

%load_ext autoreload
%autoreload 2

## DCGAN Grayscale/CIFAR100

In [2]:
# Output directory for the sample grids that the training loop writes with save_image
os.makedirs("images", exist_ok=True)

class Args(object):
    """Hyper-parameters and paths shared by the cells of this notebook.

    Plays the role of an ``argparse`` namespace. Every option has a default and
    any of them can be overridden by keyword, e.g. ``Args(grayscale=True)``.

    Raises:
        AttributeError: if an override names an option that does not exist
            (guards against silently ignored typos).
    """
    def __init__(self, **overrides):
        # --- optimisation (used by the DCGAN training cells) ---
        self.n_epochs = 200
        self.lr = .002
        self.b1 = 0.5             # Adam beta1
        self.b2 = 0.999           # Adam beta2
        # batch_size used to be assigned twice (64, then 32); only the last
        # assignment ever took effect, so the effective value 32 is kept.
        self.batch_size = 32
        self.n_cpu = 8            # not referenced by the visible cells
        self.sample_interval = 400  # batches between saved sample grids

        # --- model / data geometry ---
        self.latent_dim = 100
        self.img_size = 32        # read by the Generator/Discriminator defined in this notebook
        self.image_size = 32      # read by the CIFAR-100 transform pipeline
        self.channels = 3
        self.grayscale = False    # selects the grayscale vs RGB pretrained generator / transforms

        # --- paths ---
        self.dataroot = 'data/cifar10'

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise AttributeError("Unknown option: %r" % (name,))
            setattr(self, name, value)

# Single shared configuration object; later cells read their settings from `opt`
opt=Args()

# Decide once whether a GPU is usable, then derive the tensor constructor and
# the GPU count from that single flag.
cuda = torch.cuda.is_available()
if cuda:
    Tensor = torch.cuda.FloatTensor
    num_gpu = 1
else:
    Tensor = torch.FloatTensor
    num_gpu = 0

In [None]:
# Load the pretrained DCGAN generator (grayscale or RGB, depending on opt.grayscale).
# The import is repeated here on purpose: the "Training DCGAN" section further down
# defines its own `Generator` class, so re-running this cell afterwards must rebind
# the name to the pretrained architecture from dcgan.py.
from dcgan import Generator

# Number of image channels and checkpoint file of the selected variant
if opt.grayscale:
    n_channels, weights_file = 1, 'netG_epoch_999.pth'
else:
    n_channels, weights_file = 3, 'netG_rgb_epoch_999.pth'
weights_path = os.path.join(cifar_dir, 'models', weights_file)

generator = Generator(ngpu=1, nc=n_channels)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; the model is moved to the GPU below when one is available.
generator.load_state_dict(torch.load(weights_path, map_location='cpu'))

if torch.cuda.is_available():
    generator = generator.cuda()

# Freeze the model: only the latent vector is optimised in the sections below
for param in generator.parameters():
    param.requires_grad = False
generator.eval()

### Test 1: Find an image that is in the latent space

In [None]:
# Follow the notebook-wide CUDA/CPU fallback (`cuda`) instead of assuming a GPU
device = torch.device('cuda' if cuda else 'cpu')

# Initialize z, the target latent vector (created on the CPU, moved to `device` on use)
z = torch.randn(1, opt.latent_dim)
# Initialize x_keep, the initial latent guess (`Tensor` already places it on the GPU when available)
x_keep = Tensor(np.random.normal(0, 1, (1, opt.latent_dim)))

# Generate im, the target image; the latent is reshaped to (1, latent_dim, 1, 1) for the generator
im = generator(z.unsqueeze(2).unsqueeze(2).to(device))
plot_img(im,opt.grayscale)

In [None]:
# Generate im_first, the image produced by the initial guess x_keep.
# The device follows the notebook-wide `cuda` flag so the cell also runs without a GPU.
device = torch.device('cuda' if cuda else 'cpu')
with torch.no_grad():
    im_first = generator(x_keep.unsqueeze(2).unsqueeze(2).to(device))
    plot_img(im_first,opt.grayscale)

In [None]:
# Image reconstructed after optimizing the latent vector.
# `im_g` is only created by the grad_loop cell further down, so on a fresh kernel
# this cell used to raise NameError; guard it so Restart & Run All keeps going.
if 'im_g' in globals():
    with torch.no_grad():
        # opt.grayscale (rather than a hard-coded False) matches the other plot_img calls
        plot_img(im_g,opt.grayscale)
else:
    print("im_g is not defined yet - run the 'Optimize once with grad loop' cells below first.")

#### Optimize once with grad loop

In [None]:
# Optimize once, given optimizer parameters.
# x is a separate, gradient-tracking copy of the initial guess, so x_keep itself
# is never the tensor handed to the optimizer.
x = x_keep.clone().requires_grad_()

In [None]:
# Run a single latent-space optimisation with Adam.
# The device follows the notebook-wide `cuda` flag instead of a hard-coded 'cuda'.
device = torch.device('cuda' if cuda else 'cpu')
optimizer = torch.optim.Adam([x],lr=0.01)
loss_list, psnr_list, im_g, x_out = grad_loop(
    im, x.unsqueeze(2).unsqueeze(2).to(device),
    generator, optimizer, scheduler=None, n_epochs=100000, epsilon=.01)

In [None]:
# Distance between the latent estimate and the target latent z, before and after
# optimisation. The device follows the notebook-wide `cuda` flag.
device = torch.device('cuda' if cuda else 'cpu')
z_on_device = z.to(device)
print("Latent norm initial: ",torch.norm(x_keep-z_on_device))
print("Latent norm final: ",torch.norm(x_out[:,:,0,0]-z_on_device))

#### Plot loss curves for different optimizers

In [None]:
# Optimizers and learning rates to compare
op_list = [torch.optim.SGD,torch.optim.Adam]
lr_list = [1e-5,1e-4,1e-3,1e-2]
# The Test 2 section unpacks three values from convergence_loop, so unpack the same
# way here; binding the whole return value to loss_dict would hand a tuple to
# plot_loss_curve. NOTE(review): confirm against tools_optim.convergence_loop.
# Start from the untouched initial guess x_keep (as Test 2 does), not from x,
# which the optimisation cell above has already been working on.
loss_dict, psnr_dict, im_dict = convergence_loop(op_list,lr_list,generator,x_keep.unsqueeze(2).unsqueeze(2),im)

In [None]:
# Plot one learning curve per optimizer / learning-rate run.
# NOTE(review): the title says "image in training set", but in Test 1 the target
# image is produced by the generator itself - confirm the intended wording.
plot_loss_curve(loss_dict,title='Learning curves for DCGAN on image in training set')

In [None]:
# Compare the loss across optimizers at one fixed position (index 5) of each curve
for run_name, run_losses in loss_dict.items():
    print(run_name,": ",run_losses[5])

### Test 2: Find an image that is not in the latent space

In [3]:
# Load the data: CIFAR-100, optionally converted to grayscale, scaled to [-1, 1].
# NOTE: opt.dataroot is 'data/cifar10' even though CIFAR-100 is what gets stored there.

n_channels = 1 if opt.grayscale else 3

# Build the pipeline once; the grayscale conversion is the only step that differs.
transform_steps = [transforms.Grayscale()] if opt.grayscale else []
transform_steps += [
    transforms.Resize(opt.image_size),
    transforms.ToTensor(),
    # One mean/std entry per channel. The old grayscale branch passed `(0.5)`,
    # which is a bare float rather than a 1-tuple.
    transforms.Normalize((0.5,) * n_channels, (0.5,) * n_channels),
]

dataset = datasets.CIFAR100(root=opt.dataroot, download=True,
                            transform=transforms.Compose(transform_steps))
assert len(dataset) > 0, "CIFAR-100 dataset is empty"
# batch_size=1: each draw from the loader yields a single target image
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         shuffle=True, num_workers=1)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar10/cifar-100-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Follow the notebook-wide CUDA/CPU fallback (`cuda`) instead of assuming a GPU
device = torch.device('cuda' if cuda else 'cpu')

# Initialize x_keep, the initial latent guess (`Tensor` already places it on the GPU when available)
x_keep = Tensor(np.random.normal(0, 1, (1, opt.latent_dim)))

# im is the target image: one random sample drawn from the CIFAR-100 dataloader
im,label = next(iter(dataloader))
im = im.to(device)
plot_img(im,opt.grayscale)

In [None]:
# Generate im_first, the image produced by the initial guess x_keep.
# The device follows the notebook-wide `cuda` flag so the cell also runs without a GPU.
device = torch.device('cuda' if cuda else 'cpu')
with torch.no_grad():
    im_first = generator(x_keep.unsqueeze(2).unsqueeze(2).to(device))
    plot_img(im_first,opt.grayscale)

In [None]:
# Image reconstructed after optimizing the latent vector.
# NOTE(review): the Test 2 `im_g` is only produced by the grad_loop cell at the end
# of this section; when the notebook is run top to bottom, the value plotted here is
# whatever an earlier cell left in `im_g`. Run this cell after that grad_loop cell.
with torch.no_grad():
    plot_img(im_g,opt.grayscale)

In [None]:
# Sanity check: x_keep is created with shape (1, opt.latent_dim)
x_keep.shape

#### Plot loss curves for different optimizers

In [None]:
# Optimizers and learning rates to sweep over
op_list = [
    torch.optim.SGD,
    torch.optim.Adam,
    torch.optim.AdamW,
    torch.optim.ASGD,
]
lr_list = [1e-4, 1e-3, 1e-2]

# Initial guess reshaped from (1, latent_dim) to (1, latent_dim, 1, 1)
x_init = x_keep.unsqueeze(-1).unsqueeze(-1)
loss_dict, psnr_dict, im_dict = convergence_loop(
    op_list, lr_list, generator, x_init, im, n_epochs=10000, epsilon=.1
)

In [None]:
# Plot one learning curve per optimizer / learning-rate run of the sweep above
plot_loss_curve(loss_dict,title='Learning curves for DCGAN on image in training set')

In [None]:
def plot_psnr_curve(loss_dict,title,keys=None,log_interval=100):
    """Plot one PSNR-versus-epoch curve per run on a single figure.

    Args:
        loss_dict: mapping from a run label to a sequence of PSNR values. The
            parameter name is kept for backward compatibility with callers.
        title: figure title.
        keys: optional iterable of labels to plot; defaults to every key of
            ``loss_dict``.
        log_interval: x-axis spacing between consecutive values, in epochs. The
            default of 100 reproduces the previously hard-coded spacing
            (presumably the interval at which the values were recorded - confirm).
    """
    if keys is None:
        keys = loss_dict.keys()
    plt.figure(figsize=[15,9])
    for k in keys:
        values = loss_dict[k]
        epochs = np.arange(len(values)) * log_interval
        plt.plot(epochs,values,label=str(k))
    # 'below' is not a valid matplotlib legend location; anchoring the legend's
    # upper-center point under the axes gives the placement it was aiming for.
    plt.legend(loc='upper center',bbox_to_anchor=(0.5,-0.1),fontsize=18)
    plt.ylabel("PSNR",fontsize=18)
    plt.xlabel("Epoch",fontsize=18)
    plt.title(title,fontsize=18)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    
# NOTE(review): the title says "grayscale", but the curves come from whichever mode
# opt.grayscale selected when the sweep ran - confirm the wording.
plot_psnr_curve(psnr_dict,"PSNR curves for different optimizers on CIFAR100 grayscale images.")

In [None]:
# Side-by-side comparison: target image (left) and one run's reconstruction (right)
plt.subplot(1,2,1)
plot_img(im,opt.grayscale)
plt.title("Original Image")
plt.subplot(1,2,2)
# NOTE(review): the key must match the labels produced by convergence_loop exactly,
# and the second argument is hard-coded to True instead of opt.grayscale - confirm both.
plot_img(im_dict['Adam\'>_0.01'].detach().cpu(),True)
plt.title("Optimizer Image")

In [None]:
# Show channel 0 of the first image stored under the SGD / lr=0.01 label, in grayscale.
# NOTE(review): the key must match the labels produced by convergence_loop exactly - confirm.
plt.imshow(im_dict['.SGD\'>_0.01'].detach().cpu()[0,0],cmap='gray')

In [None]:
# Compare the loss across optimizers at one fixed position (index 5) of each curve
for run_name, run_losses in loss_dict.items():
    print(run_name,": {:2.4f}".format(run_losses[5]))

#### Optimize once with grad loop

In [None]:
# Set up a single optimisation run: a trainable copy of the initial guess,
# Adam on that copy, and a step schedule that halves the learning rate.
x = x_keep.clone().requires_grad_(True)
optimizer = torch.optim.Adam([x], lr=0.01)

lr_milestones = [5000, 10000, 20000, 40000]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=lr_milestones,
    gamma=0.5,
)

In [None]:
# Run the optimisation with the optimizer and scheduler configured above.
# The device follows the notebook-wide `cuda` flag instead of a hard-coded 'cuda'.
device = torch.device('cuda' if cuda else 'cpu')
loss_list, psnr_list, im_g, x_out = grad_loop(
    im, x.unsqueeze(2).unsqueeze(2).to(device),
    generator, optimizer, scheduler=scheduler, n_epochs=10000, epsilon=.01)

In [None]:
# Show im_g, the image returned by the grad_loop run above
plot_img(im_g,opt.grayscale)

## Training DCGAN

In [None]:
def weights_init_normal(m):
    """Re-initialise a module in place; intended for ``model.apply(...)``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm2d" get weights drawn from
    N(1, 0.02) and a zero bias. Every other module is left untouched.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)


class Generator(nn.Module):
    """DCGAN generator: a linear projection followed by two upsample+conv stages.

    Maps a latent batch of shape (N, latent_dim) to images of shape
    (N, channels, img_size, img_size) with values in [-1, 1] (final Tanh).
    The output side equals ``4 * (img_size // 4)``, so img_size should be a
    multiple of 4.

    Args:
        latent_dim: size of the latent vector; defaults to ``opt.latent_dim``.
        img_size: side of the generated image; defaults to ``opt.img_size``.
        channels: number of output channels; defaults to ``opt.channels``.

    Calling ``Generator()`` with no arguments behaves as before and reads all
    sizes from the notebook-level ``opt``.
    """
    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super(Generator, self).__init__()

        # Fall back to the notebook configuration for anything not given explicitly
        latent_dim = opt.latent_dim if latent_dim is None else latent_dim
        img_size = opt.img_size if img_size is None else img_size
        channels = opt.channels if channels is None else channels

        # Two x2 upsampling stages follow, so start at a quarter of the target side
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`, not
        # `momentum`. It is spelled out as eps=0.8 so the numerics (and any saved
        # checkpoint of this model) stay unchanged; confirm whether momentum=0.8
        # was the real intent before changing it.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` of shape (N, latent_dim)."""
        out = self.l1(z)
        # Reshape the flat projection into a (N, 128, init_size, init_size) feature map
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    """DCGAN discriminator: four stride-2 conv blocks and a sigmoid linear head.

    Maps images of shape (N, channels, img_size, img_size) to a (N, 1) tensor
    of probabilities.

    Args:
        img_size: side of the input images; defaults to ``opt.img_size``.
        channels: number of input channels; defaults to ``opt.channels``.

    Calling ``Discriminator()`` with no arguments behaves as before and reads
    the sizes from the notebook-level ``opt``.
    """
    def __init__(self, img_size=None, channels=None):
        super(Discriminator, self).__init__()

        # Fall back to the notebook configuration for anything not given explicitly
        img_size = opt.img_size if img_size is None else img_size
        channels = opt.channels if channels is None else channels

        def discriminator_block(in_filters, out_filters, bn=True):
            # Stride-2 3x3 conv (halves the spatial size), LeakyReLU, spatial dropout
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
                # not `momentum`; spelled out as eps=0.8 to keep the numerics unchanged.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width after the four stride-2 convs (kernel 3, padding 1): each maps
        # n -> (n - 1) // 2 + 1. This equals img_size // 16 for multiples of 16 and,
        # unlike that shortcut, stays correct for other sizes (e.g. 28 -> 2).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size - 1) // 2 + 1
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return the per-image probability, shape (N, 1)."""
        out = self.model(img)
        # Flatten the feature map for the linear head
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [None]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator.
# NOTE: this rebinds `generator`, replacing the frozen pretrained model that the
# latent-optimisation sections above were using.
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader (CIFAR-10 training split).
# The directory created here is now the dataset root itself; it used to be an
# unrelated "data/mnist" folder left over from the MNIST variant of this cell.
cifar10_root = "data/cifar10"
os.makedirs(cifar10_root, exist_ok=True)
# One mean/std entry per image channel, so Normalize does not depend on broadcasting
channel_stats = [0.5] * opt.channels
dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        cifar10_root,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize(channel_stats, channel_stats)]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor constructor matching the selected device
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
# Adversarial training: for every batch, one generator step and then one
# discriminator step. Progress is printed per batch and a 5x5 grid of generated
# samples is saved every opt.sample_interval batches.
n_batches = len(dataloader)

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_len = imgs.shape[0]

        # Target labels: 1.0 for "real", 0.0 for "generated"
        valid = Tensor(batch_len, 1).fill_(1.0)
        fake = Tensor(batch_len, 1).fill_(0.0)

        # Real batch cast to the notebook's tensor type (CUDA when available)
        real_imgs = imgs.type(Tensor)

        # ---- Generator step ----
        optimizer_G.zero_grad()
        z = Tensor(np.random.normal(0, 1, (batch_len, opt.latent_dim)))
        gen_imgs = generator(z)
        # The generator is rewarded when the discriminator labels its images as real
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach(): the discriminator update must not backpropagate into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        progress = (epoch, opt.n_epochs, i, n_batches, d_loss.item(), g_loss.item())
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % progress)

        batches_done = epoch * n_batches + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)


In [None]:
# Shape of the last batch left over from the training loop
imgs.shape

In [None]:
# Undo the Normalize(0.5, 0.5) scaling ([-1, 1] -> [0, 1]) for display.
# The result goes into a new name: reassigning `imgs` made this cell non-idempotent,
# shifting the batch a little further every time it was re-run.
imgs_display = imgs*0.5 + 0.5
plt.imshow(imgs_display[0,:,:,:].detach().to('cpu').permute(1,2,0))

In [None]:
# Sample as many latent vectors as there are images in the last batch, generate,
# rescale from [-1, 1] to [0, 1] and show the first generated image.
n_samples = imgs.shape[0]
z = Tensor(np.random.normal(0, 1, (n_samples, opt.latent_dim)))
im_g = generator(z) * 0.5 + 0.5
first_sample = im_g[0].detach().cpu().permute(1, 2, 0)
plt.imshow(first_sample)

In [None]:
# Display the latent batch sampled in the previous cell.
# NOTE(review): this prints the whole tensor; z.shape or z[:1] would be enough.
z

In [None]:
# Saving model
MODEL_PATH = "models/dcgan/generator_cifar10.pth"
# torch.save does not create missing parent directories, so make sure it exists first
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(generator.state_dict(), MODEL_PATH)