In [7]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [103]:
batch_size = 32

transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data = datasets.FashionMNIST('~/.pytorch/FMNIST_data', train=True, download=True,
                            transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True,
                                           batch_size=batch_size)

In [3]:
class Generator(nn.Module):
    def __init__(self, nz, nfeats, nchannels):
        super(Generator, self).__init__()

        # input is Z, going into a convolution
        self.conv1 = nn.ConvTranspose2d(nz, nfeats * 8, 4, 1, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(nfeats * 8)

        # state size. (nfeats*8) x 4 x 4
        self.conv2 = nn.ConvTranspose2d(nfeats * 8, nfeats * 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(nfeats * 4)
            
        # state size. (nfeats*4) x 8 x 8
        self.conv3 = nn.ConvTranspose2d(nfeats * 4, nfeats * 2, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(nfeats * 2)
        
        # state size. (nfeats*2) x 16 x 16
        self.conv4 = nn.ConvTranspose2d(nfeats * 2, nfeats, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(nfeats)
        
        # state size. (nfeats) x 32 x 32
        self.conv5 = nn.ConvTranspose2d(nfeats, nchannels, 3, 1, 1, bias=False)
        
        # state size. (nchannels) x 64 x 64

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = F.leaky_relu(self.bn4(self.conv4(x)))
        x = torch.tanh(self.conv5(x))
        
        return x

In [4]:
class Discriminator(nn.Module):
    def __init__(self, nchannels, nfeats):
        super(Discriminator, self).__init__()

        # input is (nc) x 32 x 32
        self.conv1 = nn.Conv2d(nchannels, nfeats, 4, 2, 1, bias=False)
        
        # state size. (ndf) x 16 x 16
        self.conv2 = nn.Conv2d(nfeats, nfeats * 2, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(nfeats * 2)
        # state size. (ndf*2) x 8 x 8
        self.conv3 = nn.Conv2d(nfeats * 2, nfeats * 4, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(nfeats * 4)
        # state size. (ndf*4) x 4 x 4
        self.conv4 = nn.Conv2d(nfeats * 4, nfeats * 8, 3, 1, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(nfeats * 8)
        # state size. (ndf*8) x 4 x 4
        self.conv5 = nn.Conv2d(nfeats * 8, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2)
        x = torch.sigmoid(self.conv5(x))
        
        return x.view(-1, 1)

In [104]:
lr = 0.0003
beta1 = 0.5

netG = Generator(100, 8, 1)
netD = Discriminator(1, 8)

criterion = nn.BCELoss()

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [105]:
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nz = 100
fixed_noise = torch.randn(30, nz, 1, 1, device=device)
real_label = 0.9
fake_label = 0
batch_size = train_loader.batch_size

step = 0
for epoch in range(epochs):
    for ii, (real_images, train_labels) in enumerate(train_loader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        labels = torch.full((batch_size, 1), real_label, device=device)

        output = netD(real_images)
        errD_real = criterion(output, labels)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        labels.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        labels.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, labels)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if step % 50 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, epochs, ii, len(train_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            
            valid_image = netG(fixed_noise)
            np_img = valid_image.to("cpu").detach().numpy()
            fig = img_grid(np_img, fig=fig)
            #plt.show()
            fig.savefig(f"saved_images/image_{step:05d}.png")
        step += 1

[0/10][0/1875] Loss_D: 1.4151 Loss_G: 0.6918 D(x): 0.5360 D(G(z)): 0.5357 / 0.5032
[0/10][50/1875] Loss_D: 0.9679 Loss_G: 1.3736 D(x): 0.5719 D(G(z)): 0.2873 / 0.2354


KeyboardInterrupt: 

In [88]:
def img_grid(images, figsize=(6, 5), w_pad=0, h_pad=0, fig=None, show=False):
    if fig is None:
        fig, _ = plt.subplots(nrows=5, ncols=6, figsize=figsize)
    
    for img, ax in zip(images[:30], fig.axes):
        ax.imshow(np.clip((img.squeeze()+1)/2, 0, 1))

        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
    fig.tight_layout(w_pad=w_pad, h_pad=h_pad)
    return fig