In [49]:
%pylab inline

import matplotlib.pyplot as plt

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [50]:
import torch
import torch.nn as nn

In [51]:
from torchvision import datasets
from torchvision.transforms import ToTensor, Grayscale, Compose, Normalize

In [52]:
import tqdm

In [53]:
figsize(3,3)

In [54]:
DEVICE='cuda:0'
# DEVICE='cpu'

In [55]:
BATCH_SIZE=16

In [56]:
%%capture

train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=Compose([
        ToTensor(),
        Normalize((0.5), (0.5))
    ]),
    download=True
)

In [57]:
train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    num_workers=8,
    shuffle=False
)

In [58]:
print(len(train_data_loader))

3750


In [59]:
IMAGE_DIM_X = None

In [60]:
for i, data in enumerate(train_data_loader):
    if i == 3:
        img_tensor = data[0][0]
        shape = list(img_tensor.shape)
        print(shape)
        IMAGE_DIM_X = shape[1]
        img = data[0][0].reshape(shape[1], shape[2])
        print(img.shape)
        plt.imshow(img.numpy())
        break

KeyboardInterrupt: 

In [None]:
discriminator_conv_out_channels = 32

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        kernel_size = 5
        padding = 0
        stride = 1

        output_size = int(((IMAGE_DIM_X - kernel_size + 2*padding) / stride) + 1)

        self.convs = nn.Sequential(
            nn.Conv2d(
                11,
                kernel_size=kernel_size,
                out_channels=discriminator_conv_out_channels,
                padding=padding,
                stride=stride
            ),
            nn.BatchNorm2d(discriminator_conv_out_channels),
            nn.ReLU()
        )
        self.expected_linear_input_size = discriminator_conv_out_channels * output_size * output_size

        self.classifier = nn.Sequential(
            nn.Linear(self.expected_linear_input_size, 1),  # 1 for single scalar value (confidence)
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.convs(x)
        x = x.reshape(-1, self.expected_linear_input_size)  # flatten and preserve batch dim
        x = self.classifier(x)
        return x

In [None]:
GENERATOR_INPUT_DIM = 25

In [None]:
generator_deconv_input_dim = 10
linear_out_size = discriminator_conv_out_channels * generator_deconv_input_dim * generator_deconv_input_dim


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.linear = nn.Sequential(
            nn.Linear(GENERATOR_INPUT_DIM, linear_out_size),
            nn.BatchNorm1d(linear_out_size),
            nn.ReLU(),
        )

        self.deconvs = nn.Sequential(
            nn.ConvTranspose2d(
                discriminator_conv_out_channels,
                discriminator_conv_out_channels,
                kernel_size=7
            ),
            nn.BatchNorm2d(discriminator_conv_out_channels),
            nn.ReLU(),

            nn.ConvTranspose2d(discriminator_conv_out_channels, 1, kernel_size=13),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(
            BATCH_SIZE,
            discriminator_conv_out_channels,
            generator_deconv_input_dim,
            generator_deconv_input_dim
        )
        x = self.deconvs(x)
        return x

In [None]:
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

In [None]:
bce = nn.BCELoss().to(DEVICE)

discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))
generator_opt = torch.optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))

In [None]:
test_generator = Generator().to(DEVICE)

test_noise = torch.randn(BATCH_SIZE, GENERATOR_INPUT_DIM).to(DEVICE)
print('noise shape', test_noise.shape)

test_out = test_generator(test_noise)
print('test out shape', test_out.shape)
print('test out 0 shape', test_out[0].shape)

plt.imshow(test_out[0].reshape(test_out.shape[2], test_out.shape[3]).cpu().detach())

In [None]:
def create_noise(offset=10):  # subtract this for labels
    n = torch.randn(BATCH_SIZE, GENERATOR_INPUT_DIM - offset).to(DEVICE)
    return n

In [None]:
def test_generate_from_noise():
    noise = create_noise(0)
    out = generator(noise)
    return out[0].reshape(out.shape[2], out.shape[3]).cpu().detach()  # [1, 28, 28] -> [28, 28]

In [None]:
plt.imshow(test_generate_from_noise())

In [None]:
def to_one_hot(labels, num_classes):
    one_hot = torch.zeros(len(labels), num_classes)
    one_hot.scatter_(1, labels.unsqueeze(1), 1)
    return one_hot.to(DEVICE)

In [None]:
EPOCHS = 11

for epoch in range(EPOCHS):
    for i, data in enumerate(train_data_loader):
        real_image = data[0].to(DEVICE)
        one_hot = to_one_hot(data[1], 10)

        # train discriminator

        # learn real image
        discriminator.zero_grad()
        oh_broadcasted = one_hot.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 28, 28)
        image_with_classes = torch.cat((real_image, oh_broadcasted), dim=1)
        discriminator_out_real = discriminator(image_with_classes)
        discriminator_loss_real = bce(
            discriminator_out_real,
            torch.ones(  # every real image is a 1
                discriminator_out_real.shape[0],
                discriminator_out_real.shape[1]
            ).to(DEVICE)
        )
        discriminator_loss_real.backward()
        discriminator_opt.step()

        # learn generated as fake
        discriminator.zero_grad()
        noise_dtrain = create_noise(10)
        noise_dtrain_with_classes = torch.cat(
            (
                noise_dtrain,
                one_hot
            ),
            dim=1
        )
        generator_out = generator(noise_dtrain_with_classes)
        out_with_classes = torch.cat((generator_out, oh_broadcasted), dim=1)
        discriminator_out = discriminator(out_with_classes)

        discriminator_loss_fake = bce(
            discriminator_out,
            torch.zeros(  # every fake image is a 0
                discriminator_out.shape[0],
                discriminator_out.shape[1]
            ).to(DEVICE)
        )

        discriminator_loss_fake.backward()
        discriminator_opt.step()

        # train generator

        generator.zero_grad()
        # recompute outputs to avoid backward 2 times error
        noise_gtrain = create_noise(0)
        generator_out = generator(noise_gtrain)
        out_with_classes = torch.cat((generator_out, oh_broadcasted), dim=1)
        discriminator_out = discriminator(out_with_classes)

        generator_loss = bce(
            discriminator_out,
            torch.ones(
                discriminator_out.shape[0],
                discriminator_out.shape[1]
            ).to(DEVICE)
        )

        generator_loss.backward()
        generator_opt.step()

    print('EPOCH:', epoch)
    noise_disp = create_noise(0)
    out = generator(noise_disp)
    plt.imshow(
        out[0].reshape(
            out.shape[2],
            out.shape[3]
        ).cpu().detach()
    )
    plt.show()

In [None]:
    BATCH_SIZE = 1

    label = torch.zeros(1, 10).to(DEVICE)
    label[0, 8] = 1

    noise_disp = create_noise(offset=10)
    generator.eval()
    with torch.no_grad():
        out = generator(torch.cat((noise_disp, label), dim=1))
        plt.imshow(
            out[0].reshape(
                out.shape[2],
                out.shape[3]
            ).cpu().detach()
        )
        plt.show()