In [None]:
import random

import torch
import numpy as np


def same_seeds(seed):
    """Seed every RNG used in this notebook and pin cuDNN to deterministic kernels.

    Args:
        seed: integer seed applied to Python's ``random`` module, NumPy's
            global RNG and torch (CPU, plus every CUDA device when one exists).
    """
    # cuDNN: no autotuning, deterministic algorithm selection only.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


same_seeds(2001)

In [None]:
import os
import glob

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from qqdm.notebook import qqdm

In [None]:
# CelebA aligned faces. torchvision's ImageFolder expects images inside class
# sub-directories of `root`; the class labels are ignored by the GAN training loop.
# os.path.join keeps the path portable (a raw "archive\\img_align_celeba" string
# is a single, non-existent directory name on Linux/macOS).
data_root = os.path.join("archive", "img_align_celeba")
facedataset = torchvision.datasets.ImageFolder(
    root=data_root,
    transform=transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1],
        # the same range as the Generator's final Tanh.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)
print(f'We have {len(facedataset)} images')

In [None]:
dataset = facedataset

# First 16 training samples exactly as the model receives them. After the
# Normalize step pixel values lie in roughly [-1, 1]; imshow clips float RGB
# data to [0, 1], so the negative half is flattened and the faces look darker.
images = [dataset[idx][0] for idx in range(16)]
grid_img = torchvision.utils.make_grid(images, nrow=4)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# The same 16 samples mapped back from [-1, 1] to [0, 1] with (x + 1) / 2,
# so imshow shows the original colours.
images = [(dataset[idx][0] + 1) / 2 for idx in range(16)]
grid_img = torchvision.utils.make_grid(images, nrow=4)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, intended to be passed to ``module.apply``.

    * Any module whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); its bias is left as-is.
    * Any module whose class name contains ``BatchNorm`` (BatchNorm1d/2d, ...):
      weight ~ N(1, 0.02), bias set to 0.
    * Every other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an RGB image.

    Input shape: (N, in_dim)
    Output shape: (N, 3, 64, 64)

    Args:
        in_dim: length of the latent vector z.
        dim: base channel width; the 4x4 seed feature map has dim * 8 channels.
    """

    def __init__(self, in_dim, dim=64):
        super().__init__()

        def upsample_block(c_in, c_out):
            # ConvTranspose2d(k=5, s=2, p=2, output_padding=1) doubles H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        seed_features = dim * 8 * 4 * 4
        # Project z to a flattened (dim * 8, 4, 4) map; forward() reshapes it.
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, seed_features, bias=False),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(),
        )
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 RGB, squashed to [-1, 1] by Tanh.
        self.l2_5 = nn.Sequential(
            upsample_block(dim * 8, dim * 4),
            upsample_block(dim * 4, dim * 2),
            upsample_block(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh(),
        )
        self.apply(weights_init)

    def forward(self, x):
        flat = self.l1(x)
        seed_map = flat.view(flat.size(0), -1, 4, 4)
        return self.l2_5(seed_map)


class Discriminator(nn.Module):
    """WGAN critic that scores 64x64 RGB images.

    Input shape: (N, 3, 64, 64)
    Output shape: (N, )

    The last layer is deliberately linear (no Sigmoid): for WGAN the critic
    returns an unbounded score rather than a probability.

    Args:
        in_dim: number of input image channels.
        dim: base channel width of the first convolution.
    """

    def __init__(self, in_dim, dim=64):
        super().__init__()

        def downsample_block(c_in, c_out):
            # Conv2d(k=5, s=2, p=2) halves H and W for even sizes.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 5, 2, 2),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            )

        # 64 -> 32 -> 16 -> 8 -> 4, then a 4x4 conv collapses each image to 1x1.
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            downsample_block(dim, dim * 2),
            downsample_block(dim * 2, dim * 4),
            downsample_block(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
        )
        self.apply(weights_init)

    def forward(self, x):
        return self.ls(x).view(-1)

In [None]:
# Training hyperparameters
batch_size = 64
z_dim = 100
lr = 1e-4

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

n_epoch = 50
n_critic = 5       # critic (D) updates per generator (G) update
clip_value = 0.01  # D's weights are clamped to [-clip_value, clip_value] after each D step


# Model
G = Generator(in_dim=z_dim).to(device)
D = Discriminator(3).to(device)

# Resume from a previous run only when both checkpoints exist. On a fresh
# checkout the models keep their weights_init initialisation instead of the
# cell failing with FileNotFoundError. torch.load unpickles the file, so only
# load checkpoints you produced yourself.
if os.path.exists('G.pth') and os.path.exists('D.pth'):
    G.load_state_dict(torch.load('G.pth', map_location=device))
    D.load_state_dict(torch.load('D.pth', map_location=device))
    print('Resumed G and D from G.pth / D.pth')
else:
    print('No G.pth/D.pth checkpoint pair found; training from scratch')
G.train()
D.train()


opt_D = torch.optim.RMSprop(D.parameters(), lr=lr)
opt_G = torch.optim.RMSprop(G.parameters(), lr=lr)


# DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# WGAN training loop (weight clipping): one G update every n_critic D updates,
# D's weights clamped to [-clip_value, clip_value] after every D step.
sample_every = 1000  # steps between preview grids; every 5 steps rendered ~600
                     # figures/epoch and grew ImgList by ~3 MB each time
n_preview = 64       # images per preview grid (8 x 8)
# Fixed latents so successive preview grids show the same "faces" evolving.
z_sample = torch.randn(n_preview, z_dim, device=device)

steps = 0
ImgList = []
DLosses = []
GLosses = []
for e in range(n_epoch):
    progress_bar = qqdm(dataloader)
    for data, _ in progress_bar:  # ImageFolder labels are unused
        r_imgs = data.to(device)
        bs = r_imgs.size(0)

        # ---- Critic (D) step ----
        z = torch.randn(bs, z_dim, device=device)
        # detach(): the critic loss must not backpropagate through G.
        f_imgs = G(z).detach()
        loss_D = -torch.mean(D(r_imgs)) + torch.mean(D(f_imgs))

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-clip_value, clip_value)

        # ---- Generator (G) step, every n_critic iterations (including step 0) ----
        if steps % n_critic == 0:
            z = torch.randn(bs, z_dim, device=device)
            f_imgs = G(z)
            loss_G = -torch.mean(D(f_imgs))

            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

        steps += 1
        d_loss = round(loss_D.item(), 4)
        # loss_G only changes on G steps; between them its last value is
        # repeated so DLosses and GLosses stay aligned one entry per step.
        g_loss = round(loss_G.item(), 4)
        DLosses.append(d_loss)
        GLosses.append(g_loss)

        progress_bar.set_infos({
            'Loss_D': d_loss,
            'Loss_G': g_loss,
            'Epoch': e + 1,
            'Step': steps,
        })

        if steps % sample_every == 0:
            G.eval()
            with torch.no_grad():
                # Map Tanh output from [-1, 1] back to [0, 1] for display.
                f_imgs_sample = ((G(z_sample) + 1) / 2.0).cpu()
            G.train()
            ImgList.append(f_imgs_sample)
            grid_img = torchvision.utils.make_grid(f_imgs_sample, nrow=8)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()

    # Save (overwrite) the checkpoint pair at the end of every epoch.
    torch.save(G.state_dict(), 'G.pth')
    torch.save(D.state_dict(), 'D.pth')
