In [12]:
import os
import time
import datetime
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torchvision import transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from model import Generator, Discriminator
import matplotlib.pyplot as plt

In [85]:
def tensor2var(x, grad=False):
    """Wrap a tensor in an autograd ``Variable``.

    Args:
        x: Tensor to wrap. It is moved to the GPU first when CUDA is available.
        grad: Value forwarded as ``requires_grad`` of the returned Variable.

    Returns:
        ``Variable`` built from ``x`` (or from its CUDA copy).
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed, requires_grad=grad)

def var2tensor(x):
    """Return the underlying data of ``x`` (``x.data``) as a CPU tensor."""
    raw = x.data
    return raw.cpu()

def var2numpy(x):
    """Return the underlying data of ``x`` (``x.data``) as a NumPy array on the CPU."""
    raw = x.data
    return raw.cpu().numpy()

def denorm(x):
    """Map values from [-1, 1] to [0, 1] via ``(x + 1) / 2``.

    Results outside [0, 1] are clamped. The input tensor itself is not
    modified; the in-place clamp acts on the freshly computed tensor.
    """
    return x.add(1).div(2).clamp_(0, 1)

In [145]:
# Hyperparameters read by train() and by the data pipeline below.
batch_size = 64               # images per mini-batch
imsize = 64                   # side length images are resized to
g_conv_dim = d_conv_dim = 64  # conv widths passed to Generator / Discriminator
z_dim = 100                   # length of the latent noise vector
beta1, beta2 = 0.0, 0.9       # Adam coefficients
total_step = 1000000          # number of training iterations

# Preprocessing: centre-crop to 160, resize to imsize x imsize, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
options = [
    transforms.CenterCrop(160),
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Images are read from class sub-folders of the current working directory.
dataset = dsets.ImageFolder(os.getcwd(), transform=transforms.Compose(options))
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

In [95]:
def train():
    """Train the Generator/Discriminator pair with a Wasserstein loss and gradient penalty.

    Reads the notebook-level settings ``batch_size``, ``imsize``, ``z_dim``,
    ``g_conv_dim``, ``d_conv_dim``, ``beta1``, ``beta2`` and ``total_step``, and
    pulls mini-batches from the notebook-level ``loader``.

    Every iteration performs three optimiser steps, in this order:
      1. a discriminator step on ``-mean(D(real)) + mean(D(fake))``;
      2. a discriminator step on ``10 * gradient_penalty``;
      3. a generator step on ``-mean(D(G(z)))``.

    Side effects:
      * prints a log line every 10 steps;
      * every 100 steps writes generated samples to ``./samples`` and both
        state dicts to ``./models`` (the directories are created if missing).
    """
    # Use the same device rule as tensor2var so models and data always agree.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the output directories up front; otherwise the first save fails.
    sample_dir = './samples'
    model_dir = './models'
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Initialize model
    G = Generator(batch_size, imsize, z_dim, g_conv_dim).to(device)
    D = Discriminator(batch_size, imsize, d_conv_dim).to(device)

    # Initialize optimizers over trainable parameters only (lr, then Adam betas)
    g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, G.parameters()), 0.0001, [beta1, beta2])
    d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), 0.0004, [beta1, beta2])
    data_iter = iter(loader)
    start_time = time.time()

    # Fixed latent batch so the periodic samples are comparable across steps
    fixed_z = tensor2var(torch.randn(batch_size, z_dim))

    # total_step counts mini-batches, not epochs
    for step in range(total_step):
        # ================== Train D ================== #
        D.train()
        G.train()
        try:
            real_images, _ = next(data_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the dataset.
            data_iter = iter(loader)
            real_images, _ = next(data_iter)

        # Critic loss on real images (D and G also return two attention maps,
        # which are not used here).
        real_images = tensor2var(real_images)
        d_out_real, _, _ = D(real_images)
        d_loss_real = -torch.mean(d_out_real)

        # Critic loss on generated images from fresh noise
        z = tensor2var(torch.randn(real_images.size(0), z_dim))
        fake_images, _, _ = G(z)
        # detach(): this step only updates D, so skip back-propagating into G
        # (G's gradients are zeroed before its own step anyway).
        d_out_fake, _, _ = D(fake_images.detach())
        d_loss_fake = d_out_fake.mean()

        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Gradient penalty on random interpolations between real and fake images
        alpha = torch.rand(real_images.size(0), 1, 1, 1).to(device).expand_as(real_images)
        interpolated = (alpha * real_images.data + (1 - alpha) * fake_images.data).requires_grad_(True)
        out, _, _ = D(interpolated)

        # create_graph=True keeps the gradient differentiable so the penalty
        # itself can be back-propagated into D's parameters.
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(out),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

        # Backward + Optimize (penalty weight 10)
        d_loss = 10 * d_loss_gp
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================== Train G ================== #
        # Create random noise
        z = tensor2var(torch.randn(real_images.size(0), z_dim))
        fake_images, _, _ = G(z)

        # Generator loss: raise the critic score of generated images
        g_out_fake, _, _ = D(fake_images)  # batch x n
        g_loss_fake = -g_out_fake.mean()
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss_fake.backward()
        g_optimizer.step()

        # Print out log info
        if (step + 1) % 10 == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            # NOTE: the value printed under "d_out_real" is d_loss_real,
            # i.e. the negated mean critic output on real images.
            print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                  " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                  format(elapsed, step + 1, total_step, (step + 1),
                         total_step, d_loss_real.item(),
                         G.attn1.gamma.mean().item(), G.attn2.gamma.mean().item()))

        # Sample images from the fixed latent batch
        if (step + 1) % 100 == 0:
            # No gradients are needed for sampling; G stays in train mode as before.
            with torch.no_grad():
                fake_images, _, _ = G(fixed_z)
            save_image(denorm(fake_images.data),
                       os.path.join(sample_dir, '{}_fake.png'.format(step + 1)))

        # Save models
        if (step + 1) % 100 == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_dir, '{}_G.pth'.format(step + 1)))
            torch.save(D.state_dict(),
                       os.path.join(model_dir, '{}_D.pth'.format(step + 1)))

In [None]:
# Start training. train() loops for total_step iterations, prints a log line
# every 10 steps, and every 100 steps writes samples to ./samples and model
# state dicts to ./models.
train()

### Test

In [156]:
import torch
import torchvision
import torchvision.datasets as dsets
from torchvision import transforms

batch_size = 64

# MNIST pipeline: convert to a tensor, then normalise the single channel
# with mean 0.1307 and std 0.3081.
options = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
mnist = dsets.MNIST(
    root='./mnist_data',
    train=True,
    download=True,
    transform=transforms.Compose(options),
)
train_loader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Iterator used by the following cells to pull individual batches.
Iter = iter(train_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Processing...
Done!


In [157]:
# Pull one (images, labels) mini-batch and display the image tensor's size.
x, y = next(Iter)
x.size()

torch.Size([64, 1, 28, 28])

In [158]:
conv_dim = 64

# Three stride-2 3x3 convolutions (padding 1); each stage doubles the channel
# count from the second stage on and roughly halves the spatial size.

# Stage 1: 1 -> 64 channels, 28x28 -> 14x14
layer1 = [
    nn.Conv2d(1, conv_dim, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l1 = nn.Sequential(*layer1)

# Stage 2: 64 -> 128 channels, 14x14 -> 7x7
curr_dim = conv_dim
layer2 = [
    nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l2 = nn.Sequential(*layer2)

# Stage 3: 128 -> 256 channels, 7x7 -> 4x4
curr_dim *= 2
layer3 = [
    nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l3 = nn.Sequential(*layer3)

# Channel count of the final feature map (256); used by the attention cell.
curr_dim *= 2

In [159]:
# Push the batch through the three conv stages in sequence and show the
# size of the resulting feature map.
out = l3(l2(l1(x)))
out.size()

torch.Size([64, 256, 4, 4])

### ATTN

In [160]:
# Step-by-step self-attention over the spatial positions of `out`.
# Here out is [64, 256, 4, 4], so B = 64, C = 256 and N = W * H = 16.
m_batchsize, C, width, height = out.size()

# 1x1 convolutions are per-position linear projections of the channels:
# query/key are reduced to C // 8 channels, value keeps all C channels.
query_conv = nn.Conv2d(curr_dim, curr_dim // 8, kernel_size=1)
key_conv = nn.Conv2d(curr_dim, curr_dim // 8, kernel_size=1)
value_conv = nn.Conv2d(curr_dim, curr_dim, kernel_size=1)
softmax = nn.Softmax(dim=-1)
# gamma starts at zero, so out1 initially equals out.
gamma = nn.Parameter(torch.zeros(1))

# Queries: flatten the spatial dims, then put positions first -> B x N x C//8  [64, 16, 32]
proj_query = query_conv(out).flatten(2).transpose(1, 2)

# Keys: B x C//8 x N  [64, 32, 16]
proj_key = key_conv(out).flatten(2)

# Pairwise scores between all positions: B x N x N  [64, 16, 16]
energy = torch.bmm(proj_query, proj_key)

# Softmax over the last dim, so every row of weights sums to 1  [64, 16, 16]
attention = softmax(energy)

# Values: B x C x N  [64, 256, 16]
proj_value = value_conv(out).flatten(2)

# Weighted sum of values per position, reshaped back to B x C x W x H  [64, 256, 4, 4]
att = torch.bmm(proj_value, attention.transpose(1, 2)).view(m_batchsize, C, width, height)

# Residual connection scaled by the learnable gamma  [64, 256, 4, 4]
out1 = gamma * att + out

In [128]:
imsize = 64

# Same preprocessing as the training pipeline: centre-crop 160, resize to
# imsize x imsize, to tensor, normalise each channel with mean 0.5 / std 0.5.
options = [
    transforms.CenterCrop(160),
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
dataset = dsets.ImageFolder(os.getcwd(), transform=transforms.Compose(options))

# NOTE: this rebinds the notebook-level `loader` (which train() reads) to a
# loader with batch_size=1000.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1000,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Pull a single batch and display the image tensor's size.
Iter = iter(loader)
x, y = next(Iter)
x.size()

torch.Size([1000, 3, 64, 64])

In [129]:
conv_dim = 64

# Three stride-2 4x4 convolutions (padding 1) for 3-channel 64x64 input;
# each stage halves the spatial size.

# Stage 1: 3 -> 64 channels, 64x64 -> 32x32
layer1 = [
    nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l1 = nn.Sequential(*layer1)

# Stage 2: 64 -> 128 channels, 32x32 -> 16x16
curr_dim = conv_dim
layer2 = [
    nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l2 = nn.Sequential(*layer2)

# Stage 3: 128 -> 256 channels, 16x16 -> 8x8
curr_dim *= 2
layer3 = [
    nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.1),
]
l3 = nn.Sequential(*layer3)

# Channel count of the final feature map (256).
curr_dim *= 2

# Forward the batch through all three stages and show the resulting size.
out = l3(l2(l1(x)))
out.size()

torch.Size([1000, 256, 8, 8])

In [130]:
# Step-by-step self-attention over the spatial positions of `out`.
# Here out is [1000, 256, 8, 8], so B = 1000, C = 256 and N = W * H = 64.
m_batchsize, C, width, height = out.size()

# 1x1 convolutions are per-position linear projections of the channels:
# query/key are reduced to C // 8 channels, value keeps all C channels.
query_conv = nn.Conv2d(curr_dim, curr_dim // 8, kernel_size=1)
key_conv = nn.Conv2d(curr_dim, curr_dim // 8, kernel_size=1)
value_conv = nn.Conv2d(curr_dim, curr_dim, kernel_size=1)
softmax = nn.Softmax(dim=-1)
# gamma starts at zero, so out1 initially equals out.
gamma = nn.Parameter(torch.zeros(1))

# Queries: B x N x C//8  [1000, 64, 32]
proj_query = query_conv(out).flatten(2).transpose(1, 2)

# Keys: B x C//8 x N  [1000, 32, 64]
proj_key = key_conv(out).flatten(2)

# Pairwise scores between all positions: B x N x N  [1000, 64, 64]
energy = torch.bmm(proj_query, proj_key)

# Softmax over the last dim, so every row of weights sums to 1  [1000, 64, 64]
attention = softmax(energy)

# Values: B x C x N  [1000, 256, 64]
proj_value = value_conv(out).flatten(2)

# Weighted sum of values per position, reshaped back to B x C x W x H  [1000, 256, 8, 8]
att = torch.bmm(proj_value, attention.transpose(1, 2)).view(m_batchsize, C, width, height)

# Residual connection scaled by the learnable gamma  [1000, 256, 8, 8]
out1 = gamma * att + out

In [141]:
# Build a Generator with the same positional arguments train() passes
# (batch_size, imsize, z_dim, g_conv_dim).
G = Generator(64, 64, 100, 64)

# filter() is lazy, so the cell displays a filter object rather than the
# parameters themselves; this is the iterable train() hands to Adam.
filter(lambda param: param.requires_grad, G.parameters())

<filter at 0x123c5e8d0>

In [191]:
from spectral import SpectralNorm

batch_size = 64
imsize = 28
z_dim = 100
conv_dim = 64

# NOTE(review): spatial sizes below follow the ConvTranspose2d output formula
# (in - 1) * stride - 2 * padding + kernel. With these settings the stack
# ends at 32x32, not imsize = 28.

# Layer 1: 100 -> 64 channels, 1x1 -> 4x4
layer1 = [
    SpectralNorm(nn.ConvTranspose2d(in_channels=z_dim, out_channels=conv_dim, kernel_size=4)),
    nn.BatchNorm2d(conv_dim),
    nn.ReLU(),
]
curr_dim = conv_dim
l1 = nn.Sequential(*layer1)

# Layer 2: 64 -> 32 channels, 4x4 -> 8x8
layer2 = [
    SpectralNorm(nn.ConvTranspose2d(in_channels=curr_dim, out_channels=curr_dim // 2,
                                    kernel_size=4, stride=2, padding=1)),
    nn.BatchNorm2d(curr_dim // 2),
    nn.ReLU(),
]
curr_dim //= 2
l2 = nn.Sequential(*layer2)

# Layer 3: 32 -> 16 channels, 8x8 -> 16x16
layer3 = [
    SpectralNorm(nn.ConvTranspose2d(in_channels=curr_dim, out_channels=curr_dim // 2,
                                    kernel_size=4, stride=2, padding=1)),
    nn.BatchNorm2d(curr_dim // 2),
    nn.ReLU(),
]
curr_dim //= 2
l3 = nn.Sequential(*layer3)

# Attention layer (16 -> 16 channels) is left out of this prototype:
# self.attn1 = Self_Attn(curr_dim, 'relu')

# Output layer: 16 -> 3 channels, 16x16 -> 32x32, squashed by tanh
last = nn.Sequential(
    nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1),
    nn.Tanh(),
)

In [193]:
# Build the latent batch directly on the CPU. The l1/l2 layers in this
# prototype are never moved to a GPU, so routing z through tensor2var (which
# calls .cuda() whenever CUDA is available) would raise a device-mismatch
# error on GPU machines.
z = torch.randn(batch_size, z_dim)
z = z.view(z.size(0), z.size(1), 1, 1)  # torch.Size([64, 100, 1, 1])
out = l1(z)    # torch.Size([64, 64, 4, 4])
out = l2(out)  # torch.Size([64, 32, 8, 8])
#out=l3(out) # torch.Size([64, 16, 16, 16])
#out=last(out) # torch.Size([64, 3, 32, 32])
out.shape

torch.Size([64, 32, 8, 8])