In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os

# Pick the compute device once; later cells move models and tensors onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size):
    """Build shuffled train and test DataLoaders over CIFAR10.

    Every image is converted to a tensor, resized to 32, and normalized per
    channel with mean 0.5 and std 0.5. The dataset lives under ``./cifar10/``
    and is requested with ``download=True``.

    Args:
        batch_size: Number of images per batch. Both loaders use
            ``drop_last=True``, so an incomplete final batch is discarded.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``.
    """
    transforms = torchvision.transforms
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(32),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def build_loader(is_train):
        # Same root, preprocessing and loader settings for both splits.
        dataset = torchvision.datasets.CIFAR10(
            root='./cifar10/', train=is_train, download=True, transform=preprocess)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    return build_loader(True), build_loader(False)

print("Downloading CIFAR10 dataset...")
# Module-level batch size; the sampling cell further down reads this name as well.
batch_size = 32
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size=batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class DCGAN_Generator(nn.Module):
    """Transposed-convolution generator: 100-channel noise in, 3-channel image out.

    Three ConvTranspose2d -> BatchNorm2d -> ReLU stages (1024, 512, 256
    channels) are followed by a final ConvTranspose2d to 3 channels and a
    Tanh. For a noise tensor of shape (N, 100, 1, 1) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32, giving an output of shape (N, 3, 32, 32).
    """

    # (in_channels, out_channels, stride, padding) of each hidden up-sampling stage
    _HIDDEN_STAGES = (
        (100, 1024, 1, 0),
        (1024, 512, 2, 1),
        (512, 256, 2, 1),
    )

    def __init__(self):
        super().__init__()
        layers = []
        for in_ch, out_ch, stride, padding in self._HIDDEN_STAGES:
            layers.extend([
                nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                   kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(num_features=out_ch),
                nn.ReLU(),
            ])
        # Output stage: project to RGB and squash with Tanh.
        layers.append(nn.ConvTranspose2d(in_channels=256, out_channels=3,
                                         kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())
        # Layers stay in the original order so the Sequential indices
        # (and therefore the state_dict keys) are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the noise tensor ``x`` through the layer stack and return the result."""
        generated = self.net(x)
        return generated

class DCGAN_Discriminator(nn.Module):
    """Convolutional discriminator: 3-channel image in, one sigmoid score out.

    Three stride-2 Conv2d stages (256, 512, 1024 channels) with LeakyReLU(0.2)
    activations, and BatchNorm2d on the second and third stages, are followed
    by a 4x4 Conv2d to a single channel and a Sigmoid. For a (N, 3, 32, 32)
    input the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1, giving an output
    of shape (N, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        slope = 0.2
        # The first stage has no batch normalization.
        stages = [
            nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(slope, inplace=True),
        ]
        # The remaining down-sampling stages double the channel count each time.
        for in_ch in (256, 512):
            out_ch = in_ch * 2
            stages += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(slope, inplace=True),
            ]
        # Scoring head: collapse to a single value and squash with Sigmoid.
        stages += [
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the image batch ``x`` through the layer stack and return the scores."""
        scores = self.net(x)
        return scores
    
print("Initializing generator and discriminator (DCGAN)...")
# Module.to() returns the module itself, so construction and device placement chain.
dcgan_generator = DCGAN_Generator().to(device)
dcgan_discriminator = DCGAN_Discriminator().to(device)
# A bare trailing expression keeps the cell displaying the discriminator's structure.
dcgan_discriminator


Initializing generator and discriminator (DCGAN)...


DCGAN_Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))
    (9): Sigmoid()
  )
)

In [4]:
# Optimisation settings; train() reads both as module-level globals.
learning_rate = 2e-4  # step size passed to both Adam optimizers
epochs = 50           # number of full passes over the training dataloader


def train(generator, discriminator, train_dataloader):
    """Adversarially train a generator/discriminator pair with binary cross entropy.

    For every batch the discriminator is first updated on real images (target 1)
    and generated images (target 0). The generator is then updated so that the
    discriminator scores fresh generated images as real. After each epoch 800
    images are generated, their inception score is printed and appended to
    ``IS_dcgan.csv`` (columns ``epoch, IS``).

    Reads the module-level globals ``device``, ``learning_rate`` and ``epochs``.

    Args:
        generator: Module fed noise of shape (N, 100, 1, 1). Its output is
            shifted by +1 and halved before scoring and asserted to lie in
            [0, 1], so it is expected to emit values in [-1, 1].
        discriminator: Module whose flattened output is compared against one
            label per image with ``nn.BCELoss``, so it must yield one value
            in [0, 1] per image.
        train_dataloader: Iterable of ``(images, labels)`` batches; the labels
            are ignored.

    Side effects:
        Creates the directory ``train_generated_images_dcgan`` if missing and
        (over)writes ``IS_dcgan.csv`` in the working directory.
    """
    loss = nn.BCELoss()  # Binary cross entropy loss
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs('train_generated_images_dcgan', exist_ok=True)

    # The context manager closes the CSV even if training raises or is interrupted.
    with open("IS_dcgan.csv", "w") as inception_score_file:
        inception_score_file.write('epoch, IS \n')

        for epoch in range(epochs):
            for real_images, _ in train_dataloader:
                real_images = real_images.to(device)
                # Size labels and noise from the batch itself so any dataloader
                # batch size (including a partial last batch) works.
                n_images = real_images.size(0)
                real_labels = torch.ones(n_images, device=device)
                fake_labels = torch.zeros(n_images, device=device)

                ### train discriminator
                # compute loss using real images
                preds = discriminator(real_images)
                disc_loss_real = loss(preds.flatten(), real_labels)

                # compute loss using fake images; detach so this backward pass
                # does not also backpropagate through the generator
                z = torch.randn(n_images, 100, 1, 1, device=device)
                fake_images = generator(z).detach()
                preds = discriminator(fake_images)
                disc_loss_fake = loss(preds.flatten(), fake_labels)

                # optimize discriminator
                disc_loss = disc_loss_real + disc_loss_fake
                discriminator.zero_grad()
                disc_loss.backward()
                optim_discriminator.step()

                ### train generator
                # compute loss with fake images (targets are "real" on purpose)
                z = torch.randn(n_images, 100, 1, 1, device=device)
                fake_images = generator(z)
                preds = discriminator(fake_images)
                gen_loss = loss(preds.flatten(), real_labels)

                # optimize generator
                generator.zero_grad()
                gen_loss.backward()
                optim_generator.step()

            # compute inception score every epoch; no autograd graph is needed
            # for evaluation samples, so skip building one for 800 images
            with torch.no_grad():
                z = torch.randn(800, 100, 1, 1, device=device)
                # normalize to [0, 1]
                samples = generator(z).add(1.0).mul(0.5)

            assert 0 <= samples.min() and samples.max() <= 1
            IS, _ = get_inception_score(samples)
            print("epoch: " + str(epoch) + ', inception score: ' + str(round(IS, 3)))

            # Per-epoch sample grids are currently disabled:
            # samples = samples[:10].data.cpu()
            # grid = utils.make_grid(samples, nrow = 5)
            # utils.save_image(grid, 'train_generated_images_dcgan/epoch_{}.png'.format(str(epoch)))

            inception_score_file.write(str(epoch) + ', ' + str(round(IS, 3)) + '\n')
            # Flush so completed epochs are on disk if a long run is interrupted.
            inception_score_file.flush()

In [5]:
# Run the DCGAN training loop on the CIFAR10 training batches
print("TRAINING DCGAN MODEL...")
train(generator=dcgan_generator, discriminator=dcgan_discriminator, train_dataloader=train_dataloader)

# Persist the learned weights (state_dicts only) of both networks
for network, checkpoint_path in (
    (dcgan_generator, 'DCGAN_generator.pkl'),
    (dcgan_discriminator, 'DCGAN_discriminator.pkl'),
):
    torch.save(network.state_dict(), checkpoint_path)

TRAINING DCGAN MODEL...
epoch: 0, inception score: 2.718
epoch: 1, inception score: 2.639
epoch: 2, inception score: 3.146
epoch: 3, inception score: 3.529
epoch: 4, inception score: 3.566
epoch: 5, inception score: 3.555
epoch: 6, inception score: 3.965
epoch: 7, inception score: 3.906
epoch: 8, inception score: 4.132
epoch: 9, inception score: 4.415
epoch: 10, inception score: 4.715
epoch: 11, inception score: 4.889
epoch: 12, inception score: 4.712
epoch: 13, inception score: 4.864
epoch: 14, inception score: 4.952
epoch: 15, inception score: 4.851
epoch: 16, inception score: 4.719
epoch: 17, inception score: 5.065
epoch: 18, inception score: 4.985
epoch: 19, inception score: 4.858
epoch: 20, inception score: 4.834
epoch: 21, inception score: 5.029
epoch: 22, inception score: 5.142
epoch: 23, inception score: 4.842
epoch: 24, inception score: 5.21
epoch: 25, inception score: 5.063
epoch: 26, inception score: 5.168
epoch: 27, inception score: 5.07
epoch: 28, inception score: 5.046
ep

In [6]:
# Load generator and discriminator weights from the checkpoint files.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU also loads on a CPU-only machine.
dcgan_generator.load_state_dict(torch.load('DCGAN_generator.pkl', map_location=device))
dcgan_discriminator.load_state_dict(torch.load('DCGAN_discriminator.pkl', map_location=device))

# Generate one batch and keep the first 10 samples; sampling needs no autograd graph
with torch.no_grad():
    z = torch.randn(batch_size, 100, 1, 1).to(device)
    samples = dcgan_generator(z)[:10]

# Shift and scale the generator output by (x + 1) / 2, then move it to the CPU for saving
samples = samples.add(1.0).mul(0.5).cpu()

# Save the 10 samples as one image grid with 5 images per row (2 rows)
grid = utils.make_grid(samples, nrow=5)
utils.save_image(grid, 'dcgan_generated_images.png')
