In [1]:
import os

import numpy as np
import torch
import torchvision

In [2]:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset, Dataset

In [3]:
DATA = "data/"

In [4]:
%%capture
mnist_ds= torchvision.datasets.MNIST(DATA, download=True)

In [5]:
import torchvision.transforms as transforms

Make all desired image transformations here

In [6]:
# ToTensor maps PIL pixels to [0, 1]. Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1], the output range of the generator's final Tanh. This puts real and
# generated images on the same scale when the discriminator compares them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

Custom `Dataset` that wraps the raw MNIST dataset, applies `transform` to each image, and returns `(image_tensor, label)` pairs.

In [7]:
class MNISTDS(Dataset):
    """Wrap an indexable ``(image, label)`` dataset and apply ``transform`` to each image.

    Args:
        mnist_ds: dataset whose items are ``(image, label)`` pairs (here: torchvision's MNIST).
        transform: callable applied to the image of every item.
    """

    def __init__(self, mnist_ds, transform):
        super().__init__()
        self.mnist_ds = mnist_ds
        self.transform = transform

    def __len__(self):
        return len(self.mnist_ds)

    def __getitem__(self, i):
        """Return ``(transform(image), label)`` for item ``i``."""
        # Look the item up once and reuse it for both image and label.
        # Indexing the wrapped dataset separately for each would repeat the
        # per-sample load work.
        item = self.mnist_ds[i]
        return self.transform(item[0]), item[1]
    

In [8]:
# MNIST images passed through `transform`, paired with their labels.
ds = MNISTDS(mnist_ds=mnist_ds, transform=transform)

Define the model

In [9]:
import torch.nn as nn

The DCGAN architecture is a bunch of stacked "transposed convolutions". If, like me, you're wondering what the hell a transposed convolution is, see [here](https://github.com/vdumoulin/conv_arithmetic) for some very helpful visualizations. The punchline is: it's an ordinary convolution, but the *stride* is used to "inflate" the input image before feeding it to the conv filter. Concretely, `stride - 1` zeros are inserted between neighbouring pixels, so the outputs can end up being larger in spatial extent than the inputs.

I'm basically copying the [example pytorch implementation of DCGAN](https://github.com/pytorch/examples/blob/master/dcgan/main.py)

In [10]:
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to a single-channel 28x28 image.

    Four stacked transposed convolutions grow the spatial extent
    1 -> 4 -> 8 -> 16 -> 28 while shrinking the channel count
    latent_size -> 4*ngf -> 2*ngf -> ngf -> 1. The final Tanh keeps pixel
    values in [-1, 1].
    """

    def __init__(self, latent_size):
        """latent_size = size of the latent space (number of input channels)."""
        super().__init__()

        self.latent_size = latent_size
        # kernel size shared by every layer
        self.kernel_size = 4
        # (proportional to the) number of generator filters
        self.ngf = 64

        ngf, k = self.ngf, self.kernel_size
        # Transposed-conv output size: (in - 1) * stride - 2 * padding + kernel_size.
        # Increasing `padding` shrinks the output: it removes implicit zero-padding
        # from the (stride-inflated) input. Padding is used here only to keep the
        # shapes at the values noted below.
        self.upsample = nn.Sequential(
            # (N, latent_size, 1, 1) -> (N, 4*ngf, 4, 4): extent set by the kernel
            nn.ConvTranspose2d(latent_size, 4 * ngf, k, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(4 * ngf),
            nn.ReLU(),
            # -> (N, 2*ngf, 8, 8): stride 2 with padding 1 exactly doubles the extent
            nn.ConvTranspose2d(4 * ngf, 2 * ngf, k, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2 * ngf),
            nn.ReLU(),
            # -> (N, ngf, 16, 16)
            nn.ConvTranspose2d(2 * ngf, ngf, k, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # -> (N, 1, 28, 28): padding 3 (instead of 1, which would give 32)
            # trims the output to MNIST's 28x28
            nn.ConvTranspose2d(ngf, 1, k, stride=2, padding=3, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent noise.

        Input: (N, latent_size) or (N, latent_size, 1, 1) noise tensor; anything
            with N * latent_size elements is reshaped to (N, latent_size, 1, 1).
        Output: (N, 1, 28, 28) generated image tensor with values in [-1, 1].
        """
        z = z.reshape(-1, self.latent_size, 1, 1)
        return self.upsample(z)

In [11]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 28x28 single-channel images as real (1) or fake (0).

    This is more or less the generator stack run in reverse. Strided
    convolutions shrink the spatial extent 28 -> 16 -> 8 -> 4 -> 1 while the
    channel count grows 1 -> nf -> 2*nf -> 4*nf. A final 4x4 conv plus Sigmoid
    then yields one score per image.
    """

    def __init__(self):
        super().__init__()

        # scaling for the number of filters
        self.nf = 64
        # kernel size
        self.kernel_size = 4

        nf, k = self.nf, self.kernel_size
        # Conv output size: floor((in + 2 * padding - kernel_size) / stride) + 1
        self.main = nn.Sequential(
            # (N, 1, 28, 28) -> (N, nf, 16, 16): padding 3 rounds 28 up to 16 after striding
            nn.Conv2d(1, nf, k, stride=2, padding=3, bias=False),
            nn.LeakyReLU(.2),
            # -> (N, 2*nf, 8, 8): stride 2 with padding 1 halves the extent
            nn.Conv2d(nf, 2 * nf, k, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2 * nf),
            nn.LeakyReLU(.2),
            # -> (N, 4*nf, 4, 4)
            nn.Conv2d(2 * nf, 4 * nf, k, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(4 * nf),
            nn.LeakyReLU(.2),
            # -> (N, 1, 1, 1): the 4x4 kernel collapses the remaining 4x4 extent
            nn.Conv2d(4 * nf, 1, k, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Input: (N, 1, 28, 28) images; anything with N * 784 elements is reshaped
            to that shape.
        Output: (N,) tensor of sigmoid scores in (0, 1), one per image.
        """
        x = x.reshape(-1, 1, 28, 28)
        return self.main(x).view(-1)

In [12]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    devname = "cuda:0"
else:
    devname = "cpu"

In [13]:
device = torch.device(devname)

In [14]:
latent_size = 100

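Initialize the Conv and BatchNorm layers:
- conv weights are drawn from a normal distribution with mean 0 and std 0.02;
- BatchNorm weights are drawn from a normal distribution with mean 1 and std 0.01;
- BatchNorm biases are set to 0.

Apparently these values work well.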

In [15]:
def weight_init(module):
    """Initialize one module in place; meant to be passed to ``nn.Module.apply``.

    * Class name containing "Conv" (Conv2d, ConvTranspose2d, ...):
      weights ~ Normal(mean=0.0, std=0.02); biases are not touched.
    * Class name containing "BatchNorm":
      weights ~ Normal(mean=1.0, std=0.01), biases set to 0.
    * Any other module is left unchanged.
    """
    kind = type(module).__name__
    if "Conv" in kind:
        module.weight.data.normal_(0.0, .02)
        return
    if "BatchNorm" in kind:
        module.weight.data.normal_(1.0, .01)
        module.bias.data.fill_(0.0)

In [16]:
G = Generator(latent_size).to(device)
D = Discriminator().to(device)

In [17]:
G.apply(weight_init)
D.apply(weight_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)

In [18]:
batch_size = 64

In [19]:
dl = DataLoader(ds, batch_size=batch_size)

In [20]:
# Mean binary cross-entropy between the discriminator's sigmoid scores and 0/1 targets.
loss = nn.BCELoss(reduction="mean")

In [21]:
# Float targets: BCELoss needs floating-point targets, and torch.full infers an
# integer (int64) tensor from an int fill value such as 1.
real_label = 1.0
fake_label = 1.0 - real_label

In [22]:
lr = 2e-4
beta=.5

In [23]:
optimizerD = torch.optim.Adam(D.parameters(), lr=lr,betas=(beta, .999))
optimizerG = torch.optim.Adam(G.parameters(), lr=lr,betas=(beta, .999))

In [24]:
epochs = 30

In [25]:
from torchvision.utils import save_image

In [26]:
disc_losses = []
gen_losses = []


In [None]:
savestep=0

In [None]:
# Alternating DCGAN training: each iteration takes one discriminator step and one
# generator step. Losses are recorded every 100 iterations. At the end of each
# epoch, one batch of real and generated images and both networks' weights are saved.

# torch.save / save_image do not create missing directories.
os.makedirs("saved_models", exist_ok=True)
os.makedirs("saved_images", exist_ok=True)

for epoch in range(epochs):
    print("Starting epoch %d..." % epoch)
    for i, (x, _) in enumerate(dl):
        x = x.to(device)

        # size of current batch (the last batch can be smaller than batch_size)
        N = x.shape[0]

        ## update the discriminator: push D(x) -> real_label and D(G(z)) -> fake_label
        disc_real = D(x)
        # BCELoss targets must be floating point. An explicit dtype keeps this
        # working even if the labels are ints (torch.full would otherwise build int64).
        real_probs = torch.full((N,), real_label, dtype=disc_real.dtype, device=device)
        disc_loss_real = loss(disc_real, real_probs)

        z = torch.randn(N, latent_size, 1, 1, device=device)
        # detach here, so gradients don't flow to the generator
        fake_outputs = G(z).detach()
        disc_fake = D(fake_outputs)
        fake_probs = torch.full((N,), fake_label, dtype=disc_fake.dtype, device=device)
        disc_loss_fake = loss(disc_fake, fake_probs)

        D.zero_grad()
        disc_loss = disc_loss_real + disc_loss_fake
        disc_loss.backward()
        # update disc weights only
        optimizerD.step()

        ## update the generator: score fresh fakes against the *real* label
        # (non-saturating GAN loss); gradients flow through D into G
        z = torch.randn(N, latent_size, 1, 1, device=device)
        disc_fake = D(G(z))
        fake_labels = torch.full((N,), real_label, dtype=disc_fake.dtype, device=device)
        gen_loss = loss(disc_fake, fake_labels)

        G.zero_grad()
        gen_loss.backward()
        # update generator weights only; the gradients this pass leaves on D
        # are cleared by D.zero_grad() before D's next step
        optimizerG.step()

        if i % 100 == 0:
            disc_losses.append(disc_loss.item())
            gen_losses.append(gen_loss.item())

    # Save the last batch of real and fake images once per epoch. normalize=True
    # rescales each grid to [0, 1]; otherwise values below 0 (Tanh outputs,
    # normalized reals) would be clamped to black.
    save_image(x, "saved_images/real_epoch_%d.png" % savestep, normalize=True)
    save_image(fake_outputs, "saved_images/fake_epoch_%d.png" % savestep, normalize=True)
    savestep += 1

    # save the models
    torch.save(D.state_dict(), "saved_models/discriminator_epoch%d" % epoch)
    torch.save(G.state_dict(), "saved_models/generator_epoch%d" % epoch)

Starting epoch 0...


In [None]:
import numpy as np
# Persist the loss histories next to the model checkpoints (np.save appends ".npy").
np.save(file="saved_models/generator_loss", arr=gen_losses)
np.save(file="saved_models/discriminator_loss", arr=disc_losses)