In [1]:
import os
import time
import datetime
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torchvision import transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from model import Generator, Discriminator
import matplotlib.pyplot as plt

In [2]:
def tensor2var(x, grad=False):
    """Wrap tensor ``x`` in an autograd ``Variable``.

    The tensor is first moved to the GPU when CUDA is available; ``grad``
    is forwarded as the wrapper's ``requires_grad`` flag.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed, requires_grad=grad)

def var2tensor(x):
    """Return the raw data of ``x`` (no autograd history) as a CPU tensor."""
    raw = x.data
    return raw.cpu()

def var2numpy(x):
    """Return the raw data of ``x`` as a NumPy array (copied to the CPU first)."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clip anything outside that range.

    A new tensor is produced for the rescaled values, so ``x`` itself is
    left untouched; the clamp is applied in place on that new tensor.
    """
    return x.add(1).div(2).clamp_(0, 1)

In [9]:
# ---- Hyperparameters ----
batch_size = 64          # images per mini-batch
imsize = 32              # images are resized to imsize x imsize
g_conv_dim = 64          # passed to Generator
d_conv_dim = 64          # passed to Discriminator
z_dim = 100              # length of the latent noise vector
beta1, beta2 = 0.0, 0.9  # Adam beta coefficients
total_step = 1000000     # number of training iterations (batches)

# ---- Data pipeline ----
# Centre-crop, resize, convert to tensor, then normalise each of the three
# channels with mean 0.5 / std 0.5.
options = [
    transforms.CenterCrop(160),
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Images are read from class sub-folders of the current working directory.
dataset = dsets.ImageFolder(os.getcwd(), transform=transforms.Compose(options))
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

In [14]:
def train(log_step=10, sample_step=100, model_save_step=100,
          sample_dir='./samples', model_dir='./models'):
    """Train the generator G and the critic D for ``total_step`` iterations.

    Each iteration performs three optimizer steps:

    1. a critic step on ``-mean(D(real)) + mean(D(fake))``,
    2. a second critic step on a gradient penalty, weighted by 10, computed
       on random interpolations between real and generated images,
    3. a generator step on ``-mean(D(G(z)))``.

    Hyperparameters (``batch_size``, ``imsize``, ``z_dim``, ``g_conv_dim``,
    ``d_conv_dim``, ``beta1``, ``beta2``, ``total_step``) and the data
    ``loader`` are read from module scope.

    Args:
        log_step: print a progress line every ``log_step`` iterations.
        sample_step: save a grid generated from a fixed latent batch every
            ``sample_step`` iterations.
        model_save_step: save the G and D state dicts every
            ``model_save_step`` iterations.
        sample_dir: directory for sample images (created if missing).
        model_dir: directory for checkpoints (created if missing).
    """
    # Run everything on one device; previously only the latent vectors were
    # moved to CUDA, which breaks on GPU hosts because the models stayed on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # save_image / torch.save do not create missing directories.
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Initialize model
    G = Generator(batch_size, imsize, z_dim, g_conv_dim).to(device)
    D = Discriminator(batch_size, imsize, d_conv_dim).to(device)

    # Initialize optimizer with filter, lr and coefficients
    g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, G.parameters()), 0.0001, [beta1, beta2])
    d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), 0.0004, [beta1, beta2])
    data_iter = iter(loader)
    start_time = time.time()

    # Fix a random latent input for Generator so saved samples are comparable
    fixed_z = torch.randn(batch_size, z_dim, device=device)

    lambda_gp = 10  # weight of the gradient penalty term

    # Nothing below switches to eval mode, so setting train mode once suffices.
    D.train()
    G.train()

    # Training, total_step as the number of total batches trained
    for step in range(total_step):
        # ================== Train D ================== #
        try:
            real_images, _ = next(data_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the data. Any other
            # error (including KeyboardInterrupt) now propagates.
            data_iter = iter(loader)
            real_images, _ = next(data_iter)
        real_images = real_images.to(device)
        n_samples = real_images.size(0)

        # Loss on real images. G and D each return two extra values besides
        # their main output; they are not used during training.
        d_out_real, _, _ = D(real_images)
        d_loss_real = - torch.mean(d_out_real)

        # Loss on generated images. The fakes are detached: generator
        # gradients from this step were always zeroed before being used,
        # so back-propagating through G here was wasted work.
        z = torch.randn(n_samples, z_dim, device=device)
        fake_images, _, _ = G(z)
        fake_images = fake_images.detach()
        d_out_fake, _, _ = D(fake_images)
        d_loss_fake = d_out_fake.mean()

        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad(); g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Compute gradient penalty on per-sample random mixes of real and fake
        alpha = torch.rand(n_samples, 1, 1, 1, device=device).expand_as(real_images)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
        out, _, _ = D(interpolated)

        # create_graph=True keeps the gradient differentiable so the penalty
        # itself can be back-propagated into D's parameters.
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(out),
                                   retain_graph=True,
                                   create_graph=True)[0]

        grad = grad.reshape(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

        # Backward + Optimize
        d_loss = lambda_gp * d_loss_gp
        d_optimizer.zero_grad(); g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================== Train G ================== #
        # Create random noise
        z = torch.randn(n_samples, z_dim, device=device)
        fake_images, _, _ = G(z)

        # Compute loss with fake images
        g_out_fake, _, _ = D(fake_images)  # batch x n
        g_loss_fake = - g_out_fake.mean()
        d_optimizer.zero_grad(); g_optimizer.zero_grad()
        g_loss_fake.backward()
        g_optimizer.step()

        # Print out log info
        if (step + 1) % log_step == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            # NOTE: the value printed under "d_out_real" is d_loss_real,
            # i.e. the negated mean of D's output on the real batch.
            print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                  " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                  format(elapsed, step + 1, total_step, (step + 1),
                         total_step , d_loss_real.item(),
                         G.attn1.gamma.mean().item(), G.attn2.gamma.mean().item()))

        # Sample images from the fixed latent batch; no graph is needed here
        if (step + 1) % sample_step == 0:
            with torch.no_grad():
                sample_images, _, _ = G(fixed_z)
            save_image(denorm(sample_images),
                       os.path.join(sample_dir, '{}_fake.png'.format(step + 1)))

        # Save models
        if (step + 1) % model_save_step == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_dir, '{}_G.pth'.format(step + 1)))
            torch.save(D.state_dict(),
                       os.path.join(model_dir, '{}_D.pth'.format(step + 1)))

In [16]:
# Start the training run; train() loops for total_step iterations, so in
# practice this cell is stopped by interrupting the kernel.
train()

Elapsed [0:00:13.233412], G_step [10/1000000], D_step[10/1000000], d_out_real: -9.8039,  ave_gamma_l3: -0.0014, ave_gamma_l4: 0.0000


KeyboardInterrupt: 

### Test: stepping through one discriminator update by hand

In [18]:
# Module-level setup for stepping through one training iteration cell by cell.
# tensor2var() sends tensors to CUDA whenever it is available, so the models
# and the real batch must live on that same device or G(z) / D(fake) fail.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
G = Generator(batch_size, imsize, z_dim, g_conv_dim).to(device)
D = Discriminator(batch_size, imsize, d_conv_dim).to(device)

# Initialize optimizer with filter, lr and coefficients
g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, G.parameters()), 0.0001, [beta1, beta2])
d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), 0.0004, [beta1, beta2])
data_iter = iter(loader)
step_per_epoch = len(loader)
start_time = time.time()

# Fix a random latent input for Generator
fixed_z = tensor2var(torch.randn(batch_size, z_dim))

# Put both networks in training mode and pull the first real batch
D.train()
G.train()
real_images, _ = next(data_iter)
real_images = real_images.to(device)

In [19]:
# Forward the real batch through D. D returns its main output plus two extra
# values (dr1, dr2), which the training code describes as attention scores.
d_out_real,dr1,dr2 = D(real_images)

In [26]:
# Real-sample term of the critic loss: negated mean of D's output on real images.
d_loss_real = - torch.mean(d_out_real)

In [27]:
# Draw one standard-normal latent vector per real image in the batch
# (tensor2var moves it to CUDA when available), then display its shape.
z = tensor2var(torch.randn(real_images.size(0), z_dim))
z.shape

torch.Size([64, 100])

In [32]:
# Generate a fake batch, score it with D, and combine with the real-sample
# term to get the critic loss for this step. The last line displays the loss.
fake_images, gf1, gf2 = G(z)
d_out_fake, df1, df2 = D(fake_images)
d_loss_fake = torch.mean(d_out_fake)
d_loss = d_loss_fake + d_loss_real
d_loss

tensor(-0.0237, grad_fn=<AddBackward0>)

In [33]:
# Back-propagate the critic loss, then display it. backward() is called
# without retain_graph, so re-running this cell without first re-running
# the forward cells raises an error.
d_loss.backward()
d_loss

tensor(-0.0237, grad_fn=<AddBackward0>)

In [34]:
# Apply the critic update, then display the optimizer's configuration.
d_optimizer.step()
d_optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: [0.0, 0.9]
    eps: 1e-08
    lr: 0.0004
    weight_decay: 0
)

In [35]:
# Gradient-penalty input: mix each real image with its fake counterpart using
# one random weight per sample. alpha is created on the real batch's device
# (it was previously always on the CPU), and the mix is built from detached
# tensors so it is a fresh leaf that gradients can be taken with respect to.
alpha = torch.rand(real_images.size(0), 1, 1, 1, device=real_images.device).expand_as(real_images)
interpolated = (alpha * real_images.detach() + (1 - alpha) * fake_images.detach()).requires_grad_(True)
out, _, _ = D(interpolated)

In [38]:
# Gradient of D's output with respect to the interpolated images.
# ones_like(out) matches out's device and dtype (torch.ones(out.size()) was
# always a CPU float tensor). create_graph=True keeps the result
# differentiable so a penalty built from it can be back-propagated.
grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(out),
                           retain_graph=True,
                           create_graph=True)[0]
grad.shape

torch.Size([64, 3, 32, 32])

In [None]:
# Training, total_step as the number of total batches trained.
# Uses the module-level G, D, optimizers and data_iter created in the setup cell.
# Every tensor is created on / moved to the device D already lives on, so the
# loop works whether the models are on the CPU or on a GPU.
device = next(D.parameters()).device
D.train()
G.train()

for step in range(total_step):
    # ================== Train D ================== #
    # Get a new batch of data; restart the loader only when it is exhausted
    # (a bare except here would also swallow KeyboardInterrupt and real errors).
    try:
        real_images, _ = next(data_iter)
    except StopIteration:
        data_iter = iter(loader)
        real_images, _ = next(data_iter)
    real_images = real_images.to(device)

    # Compute loss with real images
    # dr1, dr2, df1, df2, gf1, gf2 are attention scores
    d_out_real, dr1, dr2 = D(real_images)
    d_loss_real = - torch.mean(d_out_real)

    # Compute loss with generated images. D is fed a detached copy: generator
    # gradients from this step were always zeroed before being used.
    z = torch.randn(real_images.size(0), z_dim, device=device)
    fake_images, gf1, gf2 = G(z)
    d_out_fake, df1, df2 = D(fake_images.detach())
    d_loss_fake = d_out_fake.mean()

    # Backward + Optimize
    d_loss = d_loss_real + d_loss_fake
    d_optimizer.zero_grad(); g_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Compute gradient penalty on per-sample random mixes of real and fake
    alpha = torch.rand(real_images.size(0), 1, 1, 1, device=device).expand_as(real_images)
    interpolated = (alpha * real_images.detach() + (1 - alpha) * fake_images.detach()).requires_grad_(True)
    out, _, _ = D(interpolated)

    # create_graph=True keeps the gradient differentiable so the penalty can
    # be back-propagated into D's parameters.
    grad = torch.autograd.grad(outputs=out,
                               inputs=interpolated,
                               grad_outputs=torch.ones_like(out),
                               retain_graph=True,
                               create_graph=True)[0]

    grad = grad.reshape(grad.size(0), -1)
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

    # Backward + Optimize
    d_loss = 10 * d_loss_gp
    d_optimizer.zero_grad(); g_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # ================== Train G ================== #
    # Create random noise
    z = torch.randn(real_images.size(0), z_dim, device=device)
    fake_images, _, _ = G(z)

    # Compute loss with fake images
    g_out_fake, _, _ = D(fake_images)  # batch x n
    g_loss_fake = - g_out_fake.mean()
    d_optimizer.zero_grad(); g_optimizer.zero_grad()
    g_loss_fake.backward()
    g_optimizer.step()