In [8]:
import os
import copy
import time
import pickle
import numpy as np
import pandas as pd
from datetime import datetime
import time
import argparse
import random
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset, random_split,RandomSampler
from torchvision import datasets, transforms
import torchvision
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
from torchvision.datasets import MNIST, EMNIST
import torch.nn.functional as F
from matplotlib.pyplot import subplots
from torchvision.utils import save_image
import torch.optim as optim
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
bs = 128
n_epoch = 50
dim=100
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu = 1
device = 1

In [9]:
def test_inference(net, testloader):
    """ Returns the test accuracy and loss.
    """
    net.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    criterion = nn.CrossEntropyLoss()
    criterion.cuda(gpu)
    
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(testloader):
            images, labels = images.cuda(gpu), labels.cuda(gpu)
            
            # Inference
            outputs = net(images).squeeze()
            batch_loss = criterion(outputs, labels)
            loss += copy.deepcopy(batch_loss.item())

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
    accuracy = correct/total
    return accuracy, loss

class generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0)
        self.deconv1_bn = nn.BatchNorm2d(d*8)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*4)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d*2)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(d)
        self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input):
        # x = F.relu(self.deconv1(input))
        x = F.relu(self.deconv1_bn(self.deconv1(input)))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = F.tanh(self.deconv5(x))

        return x

class discriminator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.conv5 = nn.Conv2d(d*8, 5, 4, 1, 0)# 1-6
        #self.linear = nn.Linear(16384, 6)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        #print(x.size())
        x = self.conv5(x)
        #print(x.size())
        #x = F.sigmoid(self.conv5(x))
        #x = self.linear(x)
        return x

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [10]:

writer2 = SummaryWriter(os.path.join('../celeimg', 'cls3_testzoo'))
gpu1 = 1
device1 = 1
img_size = 64
isCrop = False
if isCrop:
    transform = transforms.Compose([
        transforms.Scale(108),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

data_dir = '../gender'          # this path depends on your computer
dset = datasets.ImageFolder(data_dir, transform)
#train_size = int(0.8 * len(dset))
#test_size = len(dset)- train_size
#train_set, test_set = random_split(dset, [train_size, test_size])
train_loader = torch.utils.data.DataLoader(dset, batch_size=bs,shuffle=True, drop_last=True)
#test_loader = torch.utils.data.DataLoader(test_set, batch_size=100,shuffle=False)

lr = 0.0002
G2 = generator(128)
D2 = discriminator(128)
G2.weight_init(mean=0.0, std=0.02)
D2.weight_init(mean=0.0, std=0.02)
G2.cuda(gpu1)
D2.cuda(gpu1)

#criterion = nn.BCELoss().cuda(gpu)
criterion = nn.CrossEntropyLoss().cuda(gpu1)
# Adam optimizer
G_optimizer2 = optim.Adam(G2.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer2 = optim.Adam(D2.parameters(), lr=lr, betas=(0.5, 0.999))




D2.train()
G2.train()
for e in range(0, 800):
    for real_x, real_y in train_loader:
        real_x, real_y = real_x.to(gpu1), real_y.to(gpu1)
        G2.zero_grad()
        z = Variable(torch.randn(bs, 100)).view(-1, 100, 1, 1).to(gpu1)
        fake_x = G2(z)
        target_y = copy.deepcopy(real_y).fill_(0)
        
        grad_est = torch.zeros_like(fake_x).to(gpu1)
        m = 20
        epsilon = 0.1
        N = fake_x.size(0)
        C = fake_x.size(1)
        S = fake_x.size(2)
        d = S**2 * C
        fake_img_tanh =fake_x#torch.tanh(fake_img)
        lossG_target = criterion(D2(fake_img_tanh).squeeze(), target_y)
        for i in range(m):
            u = torch.randn(fake_x.size()).cuda(gpu1)
            u_flat = u.view([N, -1])
            u_norm = u / torch.norm(u_flat, dim=1).view([-1, 1, 1, 1])
            x_mod_pre = fake_x + (epsilon * u_norm)
            x_mod_pre_tanh = x_mod_pre #torch.tanh(x_mod_pre)
            Tout = D2(x_mod_pre_tanh).squeeze()
        
            lossG_target_mod = criterion(Tout, target_y)
            grad_est += (
                (d / m) * (lossG_target_mod - lossG_target) / epsilon
                ).view([-1, 1, 1, 1]) * u_norm
        
        grad_est /= N
        g_loss = lossG_target_mod.mean()
        fake_x.backward(grad_est)
        G_optimizer2.step()
    
        D2.zero_grad()
        fake_x2 = G2(z)
        f_out = D2(fake_x2).squeeze()
        d_fake_loss = criterion(f_out, copy.deepcopy(real_y).fill_(2))
        r_out = D2(real_x).squeeze()
        d_real_loss = criterion(r_out, real_y)
        d_loss = d_fake_loss + d_real_loss
        d_loss.backward()
        D_optimizer2.step()
    if e % 1 == 0:
        print(f"g_loss: {g_loss.item()}, d_loss: {d_loss.item()}")
    with torch.no_grad():
        test_z = Variable(torch.randn(50, 100).view(-1, 100, 1, 1).to(device1))
        generated = G2(test_z)
        
        out0grid = torchvision.utils.make_grid(generated, nrow=50)
        writer2.add_image('images', out0grid, e)
        

g_loss: 13.510198593139648, d_loss: 0.055531296879053116
g_loss: 13.542549133300781, d_loss: 0.08645309507846832
g_loss: 13.635666847229004, d_loss: 0.043263014405965805
g_loss: 19.18427085876465, d_loss: 0.04111462086439133
g_loss: 18.022476196289062, d_loss: 0.06549662351608276
g_loss: 18.800886154174805, d_loss: 0.0463365763425827
g_loss: 31.526561737060547, d_loss: 0.02266242355108261
g_loss: 36.52996826171875, d_loss: 0.013080516830086708
g_loss: 20.804784774780273, d_loss: 0.015114218927919865
g_loss: 34.2892951965332, d_loss: 0.046089306473731995
g_loss: 25.4948787689209, d_loss: 0.013570869341492653
g_loss: 51.12068557739258, d_loss: 0.009643634781241417
g_loss: 21.865283966064453, d_loss: 0.013577502220869064
g_loss: 24.053010940551758, d_loss: 0.008960557170212269
g_loss: 30.558578491210938, d_loss: 0.0022932521533221006
g_loss: 51.90040969848633, d_loss: 0.010763954371213913
g_loss: 24.694761276245117, d_loss: 0.00936624500900507
g_loss: 22.159046173095703, d_loss: 0.0151261

KeyboardInterrupt: 