In [31]:
import argparse
import numpy
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
import torchvision.utils as vutils
from torchvision import transforms, datasets
import random
import os

In [32]:
# --- Data -------------------------------------------------------------------
# Images are resized to img_size x img_size. The Generator/Discriminator layer
# stacks are hard-wired for 256x256, so changing this alone won't resize them.
img_size = 256
channels = 3  # RGB
# Packed samples per step; each step draws 2 * batch_size real images.
batch_size = 32

# --- Model ------------------------------------------------------------------
latent_size = 128  # length of the generator's noise vector

# --- Optimisation (Adam) ----------------------------------------------------
# NOTE(review): 2e-3 is 10x the 2e-4 commonly paired with betas (0.5, 0.999)
# for DCGAN-style models -- confirm this is intended.
lr = 0.002
beta_1, beta_2 = 0.5, 0.999
# Used as the number of training *iterations* (one D step + one G step each),
# not as passes over the dataset.
num_epochs = 500

In [33]:
def weights_init_normal(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    * Class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); any bias is left untouched.
    * Class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
    * Any other module is left unchanged.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [34]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to ``channels``-channel images.

    The input is viewed as ``(-1, latent_size, 1, 1)``. A 4x4 transposed
    convolution projects it to a 1024 x 4 x 4 map; five stride-2 blocks
    (ConvTranspose2d -> BatchNorm2d -> ReLU) each double the spatial size while
    halving the width (1024 -> 32); a final stride-2 transposed convolution plus
    Tanh produces images in (-1, 1) at 256 x 256.
    """

    def __init__(self):
        super().__init__()

        def up_block(c_in, c_out, k_size=4, stride=2, padding=0):
            # Bias is omitted because BatchNorm immediately follows.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=k_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # 1x1 latent -> 1024 x 4 x 4.
        layers = up_block(latent_size, 1024, 4, 1, 0)
        widths = [1024, 512, 256, 128, 64, 32]
        for c_in, c_out in zip(widths, widths[1:]):
            layers += up_block(c_in, c_out, 4, 2, 1)
        # Output layer: to image channels, squashed into (-1, 1).
        layers += [nn.ConvTranspose2d(32, channels, 4, 2, 1), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z.view(-1, latent_size, 1, 1))

In [35]:
class Discriminator(nn.Module):
    """Convolutional critic over packed image pairs.

    Expects inputs with ``channels * 2`` channels (two images concatenated along
    the channel axis). Six stride-2 4x4 convolutions with LeakyReLU(0.2) halve
    the spatial size each time (BatchNorm from the third stage on); a final 4x4
    valid convolution then acts as the scoring layer. Returns the raw scores
    flattened to shape ``(N, -1)``.
    """

    def __init__(self):
        super().__init__()

        def down_block(c_in, c_out, k_size=4, stride=2, padding=0, bn=False):
            layers = [nn.Conv2d(c_in, c_out, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # (in_channels, out_channels, use_batchnorm) per downsampling stage.
        stages = [
            (channels * 2, 32, False),
            (32, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 1024, True),
        ]
        layers = []
        for c_in, c_out, use_bn in stages:
            layers.extend(down_block(c_in, c_out, 4, 2, 1, bn=use_bn))
        # FC with Conv: 4x4 valid convolution down to a single score map.
        layers.append(nn.Conv2d(1024, 1, 4, 1, 0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, imgs):
        scores = self.model(imgs)
        return scores.view(imgs.size(0), -1)

In [36]:
# Preprocessing: resize to the fixed square resolution, convert to a float
# tensor, then apply (x - 0.5) / 0.5 per RGB channel so [0, 1] maps to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# ImageFolder expects images inside class sub-folders of this directory; the
# path is relative to the notebook's working directory. Labels are unused.
data = datasets.ImageFolder(root='../abstract_1/smaller_abstract', transform=transform)
def generate_random_sample():
    """Endless generator of random image batches drawn from the global ``data``.

    Each yielded tensor stacks ``2 * batch_size`` images along a new leading
    dimension. Indices come from numpy's global RNG and are distinct within a
    batch (so ``len(data)`` must be at least ``2 * batch_size``), but batches
    are drawn independently of each other. Class labels are discarded.
    """
    while True:
        picks = numpy.random.choice(len(data), size=2 * batch_size, replace=False)
        images = [data[idx][0] for idx in picks]
        yield torch.stack(images, dim=0)
def mse_loss(input, target):
    """Sum of squared errors divided by the number of elements in ``input``.

    ``target`` may broadcast against ``input``; the normaliser is always
    ``input.numel()``. Differentiable with respect to both arguments.
    """
    squared_error = (input - target).pow(2)
    return squared_error.sum() / input.numel()
# Module-level, never-ending batch stream; train_GAN() pulls from it with
# __next__(), so this cell must have run before training starts.
random_sample = generate_random_sample()

In [43]:
def _pack_pairs(imgs):
    """Pack ``2k`` images into ``k`` samples with twice the channels.

    Image ``i`` of the first half is concatenated along dim 1 with image ``i``
    of the second half (pack size 2), matching the Discriminator's
    ``channels * 2`` input. An odd trailing image, if any, is dropped.
    """
    half = imgs.size(0) // 2
    return torch.cat((imgs[:half], imgs[half:2 * half]), dim=1)


def _sample_packed_batches(generator, device):
    """Draw one packed real batch from ``random_sample`` and one packed fake batch.

    Both contain ``batch_size`` packed samples. The fake batch keeps its
    autograd graph; the caller decides whether to detach it.
    """
    batch = next(random_sample)
    imgs_real = _pack_pairs(batch.to(device=device, dtype=torch.float32))
    noise = torch.randn(2 * batch_size, latent_size, device=device)
    imgs_fake = _pack_pairs(generator(noise))
    return imgs_real, imgs_fake


def _relativistic_ls_loss(c_pos, c_neg):
    """Relativistic average least-squares loss (RaLSGAN).

    Returns ``mse(c_pos, mean(c_neg) + 1) + mse(c_neg, mean(c_pos) - 1)``, i.e.
    pushes ``c_pos`` scores one unit above the mean ``c_neg`` score and vice
    versa. The discriminator uses ``(real, fake)``, the generator ``(fake, real)``.
    """
    ones = torch.ones_like(c_pos)
    return (mse_loss(c_pos, torch.mean(c_neg) + ones)
            + mse_loss(c_neg, torch.mean(c_pos) - ones))


def train_GAN(use_cuda=False, out_dir='prog3', log_every=25, save_full_grid=False):
    """Train a Generator/Discriminator pair with the RaLSGAN objective.

    Runs ``num_epochs`` iterations. Each iteration does one discriminator step
    and one generator step, each on a freshly drawn real batch from
    ``random_sample``. Every ``log_every`` iterations the current losses are
    recorded and a sample generated from a fixed noise batch is saved as a JPEG
    in ``out_dir``. Reads the module-level hyperparameters ``num_epochs``,
    ``lr``, ``beta_1``, ``beta_2``, ``batch_size`` and ``latent_size``.

    Args:
        use_cuda: train on the GPU when True *and* CUDA is available. Defaults
            to False, matching the notebook's previous forced-CPU behaviour.
        out_dir: directory for progress images; created if missing.
        log_every: record losses / save a sample every this many iterations.
        save_full_grid: if True, save a 5x5 grid of all 25 fixed-noise samples;
            otherwise only the first sample is saved.

    Returns:
        dict with ``'iterations'``, ``'d_losses'``, ``'g_losses'`` (one entry per
        logged iteration), ``'legend'``, and the trained ``'generator'`` and
        ``'discriminator'``.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got {}'.format(log_every))

    device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Optimizers are built after the models reach their final device.
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta_1, beta_2))
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))

    # Loss record, sampled every `log_every` iterations.
    iterations, d_losses, g_losses = [], [], []
    loss_legend = ['Discriminator', 'Generator']

    # Fixed noise so saved samples are comparable across iterations.
    noise_fixed = torch.randn(25, latent_size, device=device)
    os.makedirs(out_dir, exist_ok=True)

    for it in range(int(num_epochs)):
        print('Iter. {}'.format(it))

        # == Discriminator update == #
        imgs_real, imgs_fake = _sample_packed_batches(generator, device)
        optimizer_D.zero_grad()
        c_xr = discriminator(imgs_real)
        # detach(): the D step must not backpropagate into the generator.
        c_xf = discriminator(imgs_fake.detach())
        d_loss = _relativistic_ls_loss(c_xr, c_xf)
        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        imgs_real, imgs_fake = _sample_packed_batches(generator, device)
        # Freeze D so backward() only computes generator gradients (D grads
        # from this pass would be discarded by the next zero_grad anyway).
        discriminator.requires_grad_(False)
        c_xr = discriminator(imgs_real)
        c_xf = discriminator(imgs_fake)
        optimizer_G.zero_grad()
        g_loss = _relativistic_ls_loss(c_xf, c_xr)
        g_loss.backward()
        optimizer_G.step()
        discriminator.requires_grad_(True)

        if it % log_every == 0:
            iterations.append(it)
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

            # Inspection-only forward pass: no autograd graph needed.
            with torch.no_grad():
                imgs_fake_fixed = generator(noise_fixed)
            samples = imgs_fake_fixed if save_full_grid else imgs_fake_fixed[0]
            # normalize=True min/max-rescales the samples into [0, 1] for saving.
            grid = vutils.make_grid(samples, nrow=5, padding=2, normalize=True)
            # Random suffix avoids overwriting images from earlier runs.
            filename = str(it) + 'image' + str(random.random()) + '.jpg'
            vutils.save_image(grid, os.path.join(out_dir, filename))

    return {
        'iterations': iterations,
        'd_losses': d_losses,
        'g_losses': g_losses,
        'legend': loss_legend,
        'generator': generator,
        'discriminator': discriminator,
    }

In [None]:
# Run the training loop defined above. Long-running: num_epochs iterations,
# printing one progress line per iteration.
train_GAN()

Iter. 0
Iter. 1
Iter. 2
Iter. 3
Iter. 4
Iter. 5
Iter. 6
Iter. 7
Iter. 8
Iter. 9
Iter. 10
Iter. 11
Iter. 12
Iter. 13
Iter. 14
Iter. 15


In [39]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator; tensors
# still referenced are not freed.
# NOTE(review): this cell's execution count ([39]) is lower than the training
# cell's ([43]), i.e. it was run out of order and is not part of a clean
# Restart & Run All sequence.
torch.cuda.empty_cache() 