In [None]:
%matplotlib inline

import argparse

parser = argparse.ArgumentParser("PCGAN")

# Folder to output result samples
parser.add_argument('--result_dir', type=str, default='celeba_result')

# Data locations: root folder of the face images (torchvision ImageFolder
# layout) and the attribute list file whose rows are paired with those images.
parser.add_argument('--dataset_dir', type=str, default='data_faces')
parser.add_argument('--condition_file', type=str,
                    default='/content/gdrive/My Drive/list_attr_celeba.txt')

# Training Parameters
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--nepoch', type=int, default=15)
parser.add_argument('--lr_d', type=float, default=0.0002)
parser.add_argument('--lr_g', type=float, default=0.0002)
Betas = (0.5, 0.99) # adam optimizer beta1, beta2

# Model parameters
# GAN type: 'scgan' = standard conditional GAN, 'pcgan' = partially conditioned GAN.
# Change the default here or pass --gan_type on the command line; the parsed
# value is used as-is.
parser.add_argument('--gan_type', type=str, default='pcgan', choices=['scgan', 'pcgan'])
parser.add_argument('--nz', type=int, default=100) # number of noise dimension
parser.add_argument('--nc', type=int, default=3) # number of result channel
# Length of each image's attribute vector; must equal the number of value
# columns in --condition_file.
parser.add_argument('--nfeature', type=int, default=40)

# parse_known_args (not parse_args) so that unrelated command-line flags, e.g.
# those of a Jupyter kernel, are ignored instead of aborting.
config, _ = parser.parse_known_args()

# Setup: imports, random seeds, device and output folder

In [None]:
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import numpy as np
from datetime import datetime
# Let cuDNN benchmark and pick the fastest convolution algorithms for the fixed
# input size. Note: this may select non-deterministic kernels, so runs are not
# guaranteed to be bit-identical even with the same seed.
cudnn.benchmark = True

# A fresh seed is drawn for every run and printed. To repeat a run, replace the
# random draw with the printed value. numpy is seeded as well because
# np.random.binomial generates the attribute masks.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# make output dir (including missing parent folders; no error if it exists)
os.makedirs(config.result_dir, exist_ok=True)

Random Seed:  5397


In [None]:
class ImageFeatureFolder(dset.ImageFolder):
    """torchvision ImageFolder that also returns each image's attribute vector.

    ``landmark_file`` is read as a whitespace-separated table. The first two
    lines are skipped as a header, and every following line is
    ``<image file name> <value> <value> ...`` (the layout of CelebA's
    ``list_attr_celeba.txt``). Rows are matched to images by file name rather
    than by position, so the pairing stays correct when the image folder holds
    only a subset of the listed files or orders them differently.

    ``__getitem__`` returns ``(transformed image, attrs)``, where ``attrs`` is a
    1-D float tensor. The ImageFolder class label is discarded.

    Raises:
        ValueError: if an image under ``image_root`` has no row in ``landmark_file``.
    """

    def __init__(self, image_root, landmark_file, transform):
        super(ImageFeatureFolder, self).__init__(root=image_root, transform=transform)
        attrs_by_name = self._read_attribute_table(landmark_file)
        names = [os.path.basename(path) for path, _ in self.samples]
        missing = [name for name in names if name not in attrs_by_name]
        if missing:
            raise ValueError('%d image(s) under %r have no row in %r (first: %s)'
                             % (len(missing), image_root, landmark_file, missing[0]))
        # Row i of self.attrs belongs to self.samples[i], i.e. to dataset index i.
        self.attrs = torch.FloatTensor([attrs_by_name[name] for name in names])

    @staticmethod
    def _read_attribute_table(landmark_file):
        """Parse the attribute file into ``{file name: [float, ...]}``."""
        with open(landmark_file, 'r') as f:
            lines = f.read().strip().split('\n')
        table = {}
        for line in lines[2:]:  # skip the two header lines
            fields = line.split()
            if fields:  # tolerate stray blank lines
                table[fields[0]] = [float(v) for v in fields[1:]]
        return table

    def __getitem__(self, index):
        """Return ``(image, attribute vector)`` for ``index``."""
        img, _ = super(ImageFeatureFolder, self).__getitem__(index)
        return img, self.attrs[index]

In [None]:
# GENERATE DATASET and DATALOADER
# Each face is center-cropped to 178x178 and resized to 64 px. It is then
# converted to a tensor, and every channel is normalised from [0, 1] to [-1, 1],
# the same range as the generator's Tanh output.
dataset = ImageFeatureFolder(
    config.dataset_dir,
    config.condition_file,
    transform=transforms.Compose([
        transforms.CenterCrop(178),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# Shuffled mini-batches of (images, attribute vectors), loaded by 2 worker processes.
dataloader = data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# AUX FUNCTIONS

# Generate soft true/synthetic image labels
def make_target(tipo, batch_size, target_device=None):
    """Return noisy soft labels of shape (batch_size, 1).

    Args:
        tipo: 'synt' -> values in [0.9, 1.0]; 'real' -> values in [0, 0.1).
            In both cases exactly ``batch_size // 20`` distinct entries (about 5%)
            are flipped to the opposite range as label noise.
        batch_size: number of labels to generate.
        target_device: device for the result. Defaults to the notebook-level
            ``device`` when omitted.

    Raises:
        ValueError: if ``tipo`` is neither 'synt' nor 'real'.

    NOTE(review): the training loop uses hard labels 1 for real and 0 for fake,
    the opposite of this 'synt'/'real' convention. This helper is currently unused.
    """
    dev = device if target_device is None else target_device
    if tipo not in ('synt', 'real'):
        raise ValueError("tipo must be 'synt' or 'real', got %r" % (tipo,))
    target = 0.1 * torch.rand((batch_size, 1), device=dev)
    # randperm gives distinct indices, so no entry is flipped twice (randint
    # could repeat an index and push a label to about -0.9 or 1.8).
    flip = torch.randperm(batch_size, device=dev)[:batch_size // 20]
    if tipo == 'synt':
        target = target + 0.9
        target[flip] -= 0.9
    else:
        target[flip] += 0.9
    return target

# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style initialisation, intended for ``net.apply(weights_init)``.

    Modules whose class name contains 'Conv' (e.g. Conv2d, ConvTranspose2d) get
    weights drawn from N(0, 0.02). Modules whose class name contains 'BatchNorm'
    get weights from N(1, 0.02) and a zero bias. All other modules are left
    untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Nets

In [None]:
# Generator network
class Generator(nn.Module):
    """Conditional DCGAN-style generator producing (nc, 64, 64) images in [-1, 1].

    The noise vector and the attribute vector are each projected to a
    (256, 4, 4) map by their own transposed convolution. The two maps are
    concatenated to 512 channels and upsampled x2 four times
    (4 -> 8 -> 16 -> 32 -> 64), ending in Tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Attribute embedding: (B, nfeature, 1, 1) -> (B, 256, 4, 4); exposed as netF().
        self.FE = nn.ConvTranspose2d(config.nfeature, 256, 4, 1, 0, bias=False)
        # Noise projection: (B, nz, 1, 1) -> (B, 256, 4, 4).
        self.lay0 = nn.ConvTranspose2d(config.nz, 256, 4, 1, 0, bias=False)
        self.main = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, config.nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x, attr):
        """Generate images.

        Args:
            x: noise with B * nz elements per batch, e.g. (B, nz) or (B, nz, 1, 1).
            attr: attribute vectors with B * nfeature elements, e.g. (B, nfeature).

        Returns:
            Tensor of shape (B, nc, 64, 64).
        """
        batch_size = x.size(0)
        attr = self.netF(attr)
        x = self.lay0(x.view(batch_size, config.nz, 1, 1))
        x = torch.cat([x, attr], 1)
        return self.main(x)

    def netF(self, y):
        """Embed attribute vectors ``y`` (B, nfeature) into (B, 256, 4, 4) maps.

        The pcgan discriminator uses this embedding as its conditioning input.
        """
        return self.FE(y.view(y.size(0), config.nfeature, 1, 1))

In [None]:
# standard conditional GAN discriminator network
class Discriminator_scgan(nn.Module):
    """Standard conditional-GAN discriminator for 64x64 images.

    A linear layer maps the raw attribute vector (B, nfeature) to one extra
    64x64 plane, which is stacked onto the image channels. A DCGAN-style stack
    of stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4) followed by a final 4x4
    convolution then produces one score per sample. There is no output
    activation, and the result has shape (B, 1).
    """

    def __init__(self):
        super().__init__()
        self.feature_input = nn.Linear(config.nfeature, 64 * 64)
        # First block has no BatchNorm; the next three do.
        layers = [nn.Conv2d(config.nc + 1, 64, 4, 2, 1, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(width_in, width_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Collapse the remaining 4x4 map to a single value.
        layers.append(nn.Conv2d(512, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, attr):
        cond_plane = self.feature_input(attr).view(-1, 1, 64, 64)
        stacked = torch.cat((x, cond_plane), dim=1)
        return self.main(stacked).view(-1, 1)

In [None]:
# Partially conditioned GAN discriminator network
class Discriminator_pcgan(nn.Module):
    """Partially-conditioned GAN discriminator for 64x64 images.

    The conditioning input is a (B, 256, 4, 4) attribute embedding. It is
    flattened to 4096 values and mapped by a linear layer to one extra 64x64
    plane, which is stacked onto the image channels. The convolutional trunk
    matches the standard conditional discriminator and returns one score per
    sample, shape (B, 1), with no output activation.
    """

    def __init__(self):
        super().__init__()
        self.feature_input = nn.Linear(256*4*4, 64 * 64)
        # First block has no BatchNorm; the next three do.
        layers = [nn.Conv2d(config.nc + 1, 64, 4, 2, 1, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(width_in, width_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Collapse the remaining 4x4 map to a single value.
        layers.append(nn.Conv2d(512, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, attr):
        flat_embedding = attr.view(x.size(0), 256*4*4)
        cond_plane = self.feature_input(flat_embedding).view(-1, 1, 64, 64)
        stacked = torch.cat((x, cond_plane), dim=1)
        return self.main(stacked).view(-1, 1)

# Training

In [None]:
# Generate fixed sampling vectors. The same inputs are rendered after every
# epoch, so the saved sample grids are comparable across epochs.
# Layout of the 64 samples (saved 8 per grid row): four groups of 16, and every
# sample in a group shares one noise vector. Within group k, samples 0-7 (grid
# row 2k) use the attribute vectors of dataset items 4..11 unchanged. Samples
# 8-15 (grid row 2k+1) use the same vectors with each attribute kept with
# probability 0.7 and zeroed otherwise.
disp_z = torch.empty(64, config.nz, device=device)
disp_y = torch.empty(64, config.nfeature, device=device)

# Fetch each reference attribute vector once; every dataset access reads and
# transforms an image from disk.
ref_attrs = [dataset[j + 4][1] for j in range(8)]

for k in range(4):
    disp_z[16 * k:16 * (k + 1), :] = torch.randn(1, config.nz)
for k in range(4):
    for j, attr_vec in enumerate(ref_attrs):
        keep = torch.from_numpy(np.random.binomial(1, 0.7, config.nfeature)).to(attr_vec.dtype)
        disp_y[j + 16 * k, :] = attr_vec
        disp_y[j + 16 * k + 8, :] = attr_vec * keep

import torch.optim as optim
# Least-squares (MSE) adversarial loss on the discriminator scores.
loss = nn.MSELoss().to(device)

In [None]:
# CHOOSE NET
# 'scgan': the discriminator is conditioned on the raw attribute vector.
# 'pcgan': the discriminator is conditioned on netG.netF() of the attribute
#          vector after each entry is kept with probability config.no_miss_p
#          (zeroed otherwise). The generator always receives the full vector.
if config.gan_type == 'scgan':
    netD = Discriminator_scgan().to(device)
    config.no_miss_p = 1
elif config.gan_type == 'pcgan':
    netD = Discriminator_pcgan().to(device)
    config.no_miss_p = 0.85
else:
    raise ValueError("config.gan_type must be 'scgan' or 'pcgan', got %r" % (config.gan_type,))
netG = Generator().to(device)
netG.apply(weights_init)
netD.apply(weights_init)
optim_d = optim.Adam(netD.parameters(), lr=config.lr_d, betas=Betas)
optim_g = optim.Adam(netG.parameters(), lr=config.lr_g, betas=Betas)

for epoch in range(config.nepoch):
    # Per-epoch loss sums. Detached losses are added below: adding the loss
    # tensors themselves would chain every iteration's autograd graph onto these
    # accumulators for the whole epoch.
    running_g_loss = torch.tensor([0.0], device=device)
    running_d_loss = torch.tensor([0.0], device=device)

    # The batch is named `images`, not `data`: `data` is the torch.utils.data
    # module used to build the DataLoader, and rebinding it breaks re-running that cell.
    for i, (images, attr) in enumerate(dataloader):
        batch_size = images.size(0)

        # ---- Train DISCRIMINATOR ----
        netD.zero_grad()

        noise = torch.randn(batch_size, config.nz, 1, 1, device=device)
        label_real = torch.ones(batch_size, 1, device=device)
        label_fake = torch.zeros(batch_size, 1, device=device)

        attr = attr.to(device)
        real = images.to(device)

        fake = netG(noise, attr)

        if config.gan_type == 'pcgan':
            keep = np.random.binomial(1, config.no_miss_p, (batch_size, config.nfeature))
            keep = torch.from_numpy(keep).to(device=device, dtype=attr.dtype)
            # Computed without gradient: the discriminator's conditioning input
            # does not backpropagate into netG.FE.
            with torch.no_grad():
                attr = netG.netF(attr * keep)

        d_real = netD(real, attr)
        d_fake = netD(fake.detach(), attr)  # detached: do not update the generator here
        d_loss = loss(d_real, label_real) + loss(d_fake, label_fake)
        d_loss.backward()
        optim_d.step()
        running_d_loss += d_loss.detach()

        # ---- Train GENERATOR ----
        # Gradients also reach netD here; they are discarded by netD.zero_grad()
        # at the start of the next iteration.
        netG.zero_grad()
        d_fake = netD(fake, attr)
        g_loss = loss(d_fake, label_real)  # trick the fake into being real
        g_loss.backward()
        optim_g.step()
        running_g_loss += g_loss.detach()

    # Reported losses are sums over all batches of the epoch (not means).
    print('[%d/%d] Loss_D: %.2f Loss_G: %.2f' % (epoch, config.nepoch, running_d_loss.item(), running_g_loss.item()))
    print('saving the output')
    # torch.save(netG.state_dict(), config.result_dir+'/'+config.gan_type+'netG.pth')
    # torch.save(netD.state_dict(), config.result_dir+'/'+config.gan_type+'netD.pth')
    with torch.no_grad():
        samples = netG(disp_z, disp_y)
        sample_path = os.path.join(config.result_dir, '%ssamples_e_%03d.png' % (config.gan_type, epoch))
        vutils.save_image(samples, sample_path, normalize=True, nrow=8)

print('Finished Training')



[0/15] Loss_D: 2371.24 Loss_G: 5473.91
saving the output
[1/15] Loss_D: 1200.01 Loss_G: 2553.60
saving the output
[2/15] Loss_D: 999.71 Loss_G: 2481.35
saving the output
[3/15] Loss_D: 970.67 Loss_G: 2490.55
saving the output
[4/15] Loss_D: 887.10 Loss_G: 2559.71
saving the output
[5/15] Loss_D: 813.21 Loss_G: 2653.67
saving the output


KeyboardInterrupt: ignored