In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size): 
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Resize(32), 
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=False, download=True, transform=transform)

    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

print("Downloading CIFAR10 dataset...")
batch_size = 32
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [7]:
class ACGAN_Generator(nn.Module):
    def __init__(self):
        super(ACGAN_Generator, self).__init__()
        self.emb = nn.Embedding(10, 100)
        self.fc = nn.Linear(100, 128 * 8 ** 2)
        self.main = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        x = torch.mul(self.emb(labels), noise)
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 8, 8)
        x = self.main(x)
        return x

class ACGAN_Discriminator(nn.Module):
    def __init__(self):
        super(ACGAN_Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),     
        )

        self.adv_layer = nn.Sequential(nn.Linear(128 * 2 ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * 2 ** 2, 10), nn.Softmax())

    def forward(self, img):
        x = self.main(img)
        x = x.view(x.shape[0], -1)
        validity = self.adv_layer(x)
        label = self.aux_layer(x)
        return validity, label


print("Initializing generator and discriminator (ACGAN)...")
acgan_generator = ACGAN_Generator()
acgan_discriminator = ACGAN_Discriminator()
acgan_generator.to(device)
acgan_discriminator.to(device)
print()

Initializing generator and discriminator (ACGAN)...



In [11]:
learning_rate = 0.0002
epochs = 50


def train(generator, discriminator, train_dataloader):
    source_criterion = nn.BCELoss()
    class_criterion = nn.NLLLoss()
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    if not os.path.exists('train_generated_images_acgan_real/'): 
        os.makedirs('train_generated_images_acgan_real')
    if not os.path.exists('train_generated_images_acgan_fake/'): 
        os.makedirs('train_generated_images_acgan_fake')
        
    inception_score_file = open("IS_acgan.csv", "w")
    inception_score_file.write('epoch, IS \n')

    for epoch in range(epochs): 
        for images, labels in train_dataloader:
            batch_size = images.shape[0]
            real_images = Variable(images.type(torch.cuda.FloatTensor)).to(device)
            real_labels = Variable(labels.type(torch.cuda.LongTensor)).to(device)

            # adversarial ground truth
            fake = torch.zeros(batch_size).to(device)
            valid = torch.ones(batch_size).to(device)

            ### train generator
            optim_generator.zero_grad()
            z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
            generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))

            # generate image batch
            generated_images = generator(z, generated_labels)

            # compute generator loss, optimize generator
            validity, predicted_label = discriminator(generated_images)
            gen_loss = 0.5 * (source_criterion(validity, valid.unsqueeze(1)) + class_criterion(predicted_label, generated_labels))
            gen_loss.backward()
            optim_generator.step()

            ### train discriminator
            optim_discriminator.zero_grad()

            # compute real images loss
            real_pred, real_aux = discriminator(real_images)
            disc_loss_real = 0.5 * (source_criterion(real_pred, valid.unsqueeze(1)) + class_criterion(real_aux, real_labels))

            # compute fake images loss
            fake_pred, fake_aux = discriminator(generated_images.detach())
            disc_loss_fake = 0.5 * (source_criterion(fake_pred, fake.unsqueeze(1)) + class_criterion(fake_aux, generated_labels))

            # compute overall discriminator loss, optimize discriminator
            disc_loss = 0.5 * (disc_loss_real + disc_loss_fake)
            disc_loss.backward()
            optim_discriminator.step()
            
        # compute inception score and samples every epoch
        z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
        generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))
        samples = generator(z, generated_labels)

        # normalize to [0, 1]
        samples = samples.add(1.0).mul(0.5)
        
        assert 0 <= samples.min() and samples.max() <= 1
        IS, IS_std = get_inception_score(samples)
        print("epoch: " + str(epoch) + ', inception score: ' + str(round(IS, 3)))

        # samples = samples[:10].data.cpu()
        # grid = utils.make_grid(samples, nrow = 5)
        # utils.save_image(samples, 'train_generated_images_acgan_fake/epoch_{}.png'.format(str(epoch)))
        # utils.save_image(real_images, 'train_generated_images_acgan_real/epoch_{}.png'.format(str(epoch)))
        inception_score_file.write(str(epoch) + ', ' + str(round(IS, 3)) + '\n')

    inception_score_file.close()

In [12]:
# train ACGAN
print("TRAINING ACGAN MODEL...")
train(acgan_generator, acgan_discriminator, train_dataloader)

# save ACGAN to file
torch.save(acgan_generator.state_dict(), 'ACGAN_generator.pkl')
torch.save(acgan_discriminator.state_dict(), 'ACGAN_discriminator.pkl')

TRAINING ACGAN MODEL...
epoch: 0, inception score: 1.333
epoch: 1, inception score: 1.349
epoch: 2, inception score: 1.425
epoch: 3, inception score: 1.436
epoch: 4, inception score: 1.462
epoch: 5, inception score: 1.642
epoch: 6, inception score: 1.503
epoch: 7, inception score: 1.499
epoch: 8, inception score: 1.513
epoch: 9, inception score: 1.597
epoch: 10, inception score: 1.41
epoch: 11, inception score: 1.491
epoch: 12, inception score: 1.643
epoch: 13, inception score: 1.533
epoch: 14, inception score: 1.561
epoch: 15, inception score: 1.511
epoch: 16, inception score: 1.587
epoch: 17, inception score: 1.646
epoch: 18, inception score: 1.697
epoch: 19, inception score: 1.768
epoch: 20, inception score: 1.622
epoch: 21, inception score: 1.608
epoch: 22, inception score: 1.602
epoch: 23, inception score: 1.53
epoch: 24, inception score: 1.571
epoch: 25, inception score: 1.626
epoch: 26, inception score: 1.609
epoch: 27, inception score: 1.717
epoch: 28, inception score: 1.574
ep

In [14]:
#Load generator and discriminator model from model file
acgan_generator.load_state_dict(torch.load('ACGAN_generator.pkl'))
acgan_discriminator.load_state_dict(torch.load('ACGAN_discriminator.pkl'))

#Get 10 samples
z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))
samples = acgan_generator(z, generated_labels)
samples = samples[:10]
samples = samples.add(1.0).mul(0.5)
samples = samples.data.cpu()

#Save 10 samples as immage in 2X10 grid
grid = utils.make_grid(samples, nrow=5)
utils.save_image(grid, 'acgan_generated_images.png')