In [8]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Seed every RNG the notebook draws from (torch: weights, latents, DataLoader
# shuffling; numpy) so Restart & Run All gives repeatable results.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size, root='./cifar10/'):
    """Build CIFAR-10 train and test DataLoaders with pixels scaled to [-1, 1].

    Each image is converted to a float tensor in [0, 1] by ``ToTensor``, resized
    to 32 px (CIFAR-10 images are already 32x32, so this should leave them
    unchanged), then normalised per channel with mean 0.5 / std 0.5, which maps
    values to [-1, 1], the same range as a Tanh generator output.

    Args:
        batch_size: number of samples per batch for both loaders.
        root: directory the dataset is downloaded to / read from.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and drops
        the final incomplete batch (so every training batch has ``batch_size``
        samples). The test loader keeps a fixed order and every sample, so
        evaluation is deterministic and covers the whole test set.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(32),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluation data: no shuffling and no dropped samples.
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

# Batch size used for both CIFAR-10 loaders.
batch_size = 32
print("Downloading CIFAR10 dataset...")
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [11]:
class WGAN_Generator(nn.Module):
    """MLP generator mapping latent vectors ``z`` to images in [-1, 1].

    Layers: Linear(latent_dim, 128) -> ReLU, then three Linear -> BatchNorm1d ->
    ReLU stages (256, 512, 1024 units), then Linear(1024, C*H*W) -> Tanh. The
    flat output is reshaped to ``(batch, *img_shape)``.

    Args:
        latent_dim: length of each input noise vector (default 100).
        img_shape: per-sample output shape ``(C, H, W)`` (default ``(3, 32, 32)``).
        bn_eps: ``eps`` for every BatchNorm1d layer. NOTE(review): 0.8 is an
            unusually large eps. It was passed as BatchNorm1d's second positional
            argument, which is ``eps``, not ``momentum``, and was likely meant as
            momentum. The default keeps 0.8 so already-trained checkpoints produce
            identical outputs; PyTorch's own default is 1e-5.
    """

    def __init__(self, latent_dim=100, img_shape=(3, 32, 32), bn_eps=0.8):
        super(WGAN_Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        out_features = 1
        for dim in self.img_shape:
            out_features *= dim
        # Layer order matches the original layout so state_dict keys
        # (net.0 ... net.12) are unchanged and saved checkpoints still load.
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, eps=bn_eps),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, eps=bn_eps),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, eps=bn_eps),
            nn.ReLU(),
            nn.Linear(1024, out_features),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        x = self.net(z)
        return x.view(x.shape[0], *self.img_shape)

class WGAN_Discriminator(nn.Module):
    """WGAN critic: flattens each image and scores it with a small MLP.

    Layers: Linear(3*32*32, 512) -> LeakyReLU(0.2) -> Linear(512, 256) ->
    LeakyReLU(0.2) -> Linear(256, 1). There is no final sigmoid, so the output is
    an unbounded real-valued score of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(3 * 32 * 32, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch; each sample must contain 3*32*32 elements."""
        batch = x.shape[0]
        flat = x.view(batch, -1)
        return self.model(flat)

# Build both networks and move their parameters to the selected device
# (Module.to moves in place and returns the same module).
print("Initializing generator and Discriminator (WGAN)...")
wgan_generator = WGAN_Generator().to(device)
wgan_discriminator = WGAN_Discriminator().to(device)
wgan_discriminator  # cell output: the critic's architecture

Initializing generator and Discriminator (WGAN)...


WGAN_Discriminator(
  (model): Sequential(
    (0): Linear(in_features=3072, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [17]:
# Optimisation hyperparameters (read as globals by train()).
learning_rate = 5e-4  # Adam step size for both networks
epochs = 50           # passes over the training set
batch_size = 32       # NOTE(review): duplicates the value used to build the dataloaders; train() does not read it

# WGAN-specific settings.
weight_clip = 0.01    # critic parameters are clamped to [-weight_clip, weight_clip]
nCritic = 4           # the generator is updated on every nCritic-th batch

def _sample_generator_images(generator, n_samples, latent_dim=100):
    """Return ``n_samples`` generated images rescaled from [-1, 1] to [0, 1].

    Latents are drawn from N(0, 1) on the global ``device``. Sampling runs in
    eval mode (BatchNorm uses its running statistics and does not update them)
    without autograd. The generator's previous train/eval mode is restored.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, latent_dim, device=device)
            images = generator(z)
    finally:
        generator.train(was_training)
    return images.add(1.0).mul(0.5)


def train(generator, discriminator, train_dataloader, num_is_samples=None):
    """Train a weight-clipped WGAN and log the Inception Score after every epoch.

    Reads the module-level globals ``learning_rate``, ``epochs``, ``weight_clip``,
    ``nCritic`` and ``device``.

    The critic is updated on every batch with loss ``-mean(D(real)) + mean(D(fake))``,
    after which all critic parameters are clamped to ``[-weight_clip, weight_clip]``.
    The generator is updated on batches 0, nCritic, 2*nCritic, ... with loss
    ``-mean(D(G(z)))``, reusing that batch's latents.

    After each epoch, generated samples in [0, 1] are scored with
    ``get_inception_score``. The score is printed and appended to ``IS_wgan.csv``,
    which is created fresh with an ``epoch, IS`` header.

    Args:
        generator: maps (N, latent_dim) noise to images in [-1, 1]; latent_dim is
            taken from ``generator.latent_dim`` if present, else 100.
        discriminator: critic returning one unbounded score per image.
        train_dataloader: yields ``(images, labels)`` batches; labels are ignored.
        num_is_samples: number of generated images used for the Inception Score.
            ``None`` uses the size of the epoch's last training batch. NOTE: a
            score from ~32 images is very noisy; thousands are customary.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches and
            ``num_is_samples`` is not given.
    """
    latent_dim = getattr(generator, 'latent_dim', 100)

    # NOTE(review): the WGAN paper used RMSprop for weight-clipped critics and
    # reported momentum-based optimisers as unstable; Adam is kept to preserve
    # the existing training setup.
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Directory reserved for sample grids; nothing in this function writes to it.
    os.makedirs('train_generated_images_wgan', exist_ok=True)

    generator.train()
    discriminator.train()

    # `with` guarantees the CSV is closed (and flushed) even if training is interrupted.
    with open("IS_wgan.csv", "w") as inception_score_file:
        inception_score_file.write('epoch, IS \n')

        for epoch in range(epochs):
            last_batch_size = 0
            for i, (images, _) in enumerate(train_dataloader):
                # Use the selected device rather than assuming CUDA is present.
                real_images = images.to(device, dtype=torch.float)
                last_batch_size = real_images.shape[0]

                # --- critic step ---
                optim_discriminator.zero_grad()
                z = torch.randn(last_batch_size, latent_dim, device=device)
                fake_images = generator(z).detach()  # no generator gradients in this step
                disc_loss = -torch.mean(discriminator(real_images)) + torch.mean(discriminator(fake_images))
                disc_loss.backward()
                optim_discriminator.step()

                # Weight clipping (in place, outside autograd).
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-weight_clip, weight_clip)

                # --- generator step, every nCritic batches ---
                # gen_loss.backward() also leaves gradients on the critic's parameters;
                # they are cleared by zero_grad() at the next critic step.
                if i % nCritic == 0:
                    optim_generator.zero_grad()
                    gen_loss = -torch.mean(discriminator(generator(z)))
                    gen_loss.backward()
                    optim_generator.step()

            # Compute the Inception Score on fresh samples at the end of every epoch.
            n_samples = num_is_samples or last_batch_size
            if n_samples <= 0:
                raise ValueError("train_dataloader yielded no batches; cannot compute the inception score")
            samples = _sample_generator_images(generator, n_samples, latent_dim)

            assert 0 <= samples.min() and samples.max() <= 1
            IS, IS_std = get_inception_score(samples)
            print("epoch: " + str(epoch) + ', inception score: ' + str(round(IS, 3)))

            inception_score_file.write(str(epoch) + ', ' + str(round(IS, 3)) + '\n')
            inception_score_file.flush()  # keep the log current while training runs
                

In [18]:
# Train the WGAN, then persist the weights (state_dicts) of both networks.
print("TRAINING WGAN MODEL...")
train(wgan_generator, wgan_discriminator, train_dataloader)

checkpoints = (
    (wgan_generator, 'WGAN_generator.pkl'),
    (wgan_discriminator, 'WGAN_discriminator.pkl'),
)
for network, checkpoint_path in checkpoints:
    torch.save(network.state_dict(), checkpoint_path)

TRAINING WGAN MODEL...
epoch: 0, inception score: 1.663
epoch: 1, inception score: 1.577
epoch: 2, inception score: 1.726
epoch: 3, inception score: 1.595
epoch: 4, inception score: 1.543
epoch: 5, inception score: 1.707
epoch: 6, inception score: 1.529
epoch: 7, inception score: 1.739
epoch: 8, inception score: 1.688
epoch: 9, inception score: 1.441
epoch: 10, inception score: 1.682
epoch: 11, inception score: 1.593
epoch: 12, inception score: 1.642
epoch: 13, inception score: 1.571
epoch: 14, inception score: 1.824
epoch: 15, inception score: 1.735
epoch: 16, inception score: 1.555
epoch: 17, inception score: 1.598
epoch: 18, inception score: 1.734
epoch: 19, inception score: 1.644
epoch: 20, inception score: 1.667
epoch: 21, inception score: 1.635
epoch: 22, inception score: 1.72
epoch: 23, inception score: 1.691
epoch: 24, inception score: 1.795
epoch: 25, inception score: 1.607
epoch: 26, inception score: 1.545
epoch: 27, inception score: 1.679
epoch: 28, inception score: 1.72
epo

In [21]:
# Load generator and discriminator weights from the checkpoint files.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
wgan_generator.load_state_dict(torch.load('WGAN_generator.pkl', map_location=device))
wgan_discriminator.load_state_dict(torch.load('WGAN_discriminator.pkl', map_location=device))

# Draw 10 samples in eval mode (BatchNorm uses its running statistics) without
# building an autograd graph, then map them from [-1, 1] to [0, 1].
num_samples = 10
wgan_generator.eval()
with torch.no_grad():
    z = torch.randn(num_samples, 100, device=device)
    samples = wgan_generator(z).add(1.0).mul(0.5).cpu()

# Save the 10 samples as one image: a 2x5 grid (nrow=5 images per row).
grid = utils.make_grid(samples, nrow=5)
utils.save_image(grid, 'wgan_generated_images.png')