In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# from model_utils import Discriminator, Generator # inspired from DCGAN paper

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator mapping a 64x64 image to one real/fake probability.

    Args:
        channels_img: number of channels in the input image (1 for MNIST).
        features_d: base width; the four downsampling convs produce features_d,
            2x, 4x and 8x this many channels.

    Input:  N x channels_img x 64 x 64
    Output: N x 1 x 1 x 1, passed through a sigmoid.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        def down(in_ch, out_ch, batch_norm=True):
            # One stride-2 stage (kernel 4, padding 1): halves height and width.
            stage = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2))
            return stage

        layers = [
            # N x channels_img x 64 x 64 -> N x features_d x 32 x 32 (no batch norm on this first stage)
            *down(channels_img, features_d, batch_norm=False),
            # -> N x features_d*2 x 16 x 16
            *down(features_d, features_d * 2),
            # -> N x features_d*4 x 8 x 8
            *down(features_d * 2, features_d * 4),
            # -> N x features_d*8 x 4 x 4
            *down(features_d * 4, features_d * 8),
            # 4x4 kernel over the 4x4 map, no padding -> N x 1 x 1 x 1
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-image probabilities, shape N x 1 x 1 x 1."""
        return self.net(x)


In [4]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise tensor to a 64x64 image.

    Args:
        channels_noise: channel count of the N x channels_noise x 1 x 1 noise input.
        channels_img: number of channels of the generated image (1 for MNIST).
        features_g: base width; the hidden layers use 16x, 8x, 4x and 2x this
            many channels.

    Output: N x channels_img x 64 x 64, passed through tanh (values in [-1, 1]).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()

        def up(in_ch, out_ch, stride, padding):
            # Transposed conv (kernel 4) followed by batch norm and ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # N x channels_noise x 1 x 1 -> N x features_g*16 x 4 x 4
        layers = up(channels_noise, widths[0], stride=1, padding=0)
        # Three stride-2 stages, each doubling height and width: 4 -> 8 -> 16 -> 32
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += up(in_ch, out_ch, stride=2, padding=1)
        # Final stride-2 stage to N x channels_img x 64 x 64, no batch norm before tanh
        layers += [
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images, shape N x channels_img x 64 x 64."""
        return self.net(x)

In [8]:
# hyperparameters
LR = 0.0002
BS = 64
IMAGE_SIZE = 64  # MNIST (28x28) is resized to 64x64 to fit the 64x64 networks
CHANNELS_IMG = 1  # MNIST digits are grayscale: 1 channel
CHANNELS_NOISE = 256  # channels of the N x CHANNELS_NOISE x 1 x 1 generator input
EPOCHS = 10
SEED = 42  # fixed seed so weight init, noise and shuffling repeat on Restart & Run All
NUM_FIXED_SAMPLES = 64  # size of the fixed noise batch used for TensorBoard previews
# NOTE(review): r_label / f_label are not read by the training loop; kept for compatibility.
r_label = 1
f_label = 0

# paper does 64, larger network, not needed for MNIST, use for face gen
features_d = 16
features_g = 16

# Seeds the RNG on all devices (CPU and CUDA); DataLoader shuffling draws from it too.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A single print: the conditional selects the message (no nested print in the else-branch).
print("Running on GPU" if device.type == "cuda" else "Running on CPU (SLOW)")

my_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),  # pixel values in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1], the same range as the generator's tanh
])
dataset = datasets.MNIST(root='dataset/', train=True, transform=my_transforms, download=True)
dataloader = DataLoader(dataset, batch_size=BS, shuffle=True)
num_batches = len(dataloader)
print(f"Batches: {num_batches}")

# Constant noise fed to the generator at every logging step, so previews show progress on the same inputs.
fixed_noise = torch.randn(NUM_FIXED_SAMPLES, CHANNELS_NOISE, 1, 1, device=device)

# Plain BCE (not BCEWithLogitsLoss) because the discriminator already ends in a sigmoid.
loss_function = nn.BCELoss()

class Net:
    """Owns the DCGAN discriminator/generator pair and trains them adversarially on MNIST.

    Relies on names defined in earlier cells: the Discriminator/Generator classes,
    the hyperparameters (LR, EPOCHS, CHANNELS_IMG, CHANNELS_NOISE, features_d,
    features_g), `device`, `dataloader`, `num_batches`, `fixed_noise` and
    `loss_function`.
    """

    # TensorBoard writers for the preview grids; created once, when the class is defined.
    writer_real = SummaryWriter("runs/GAN_MNIST/test_real")
    writer_fake = SummaryWriter("runs/GAN_MNIST/test_fake")

    # Smoothed discriminator targets for real and generated images.
    REAL_TARGET = 0.9
    FAKE_TARGET = 0.1
    LOG_EVERY = 100  # batches between console/TensorBoard logs
    PREVIEW_IMAGES = 32  # images per TensorBoard grid

    def main(self):
        """Build both networks and their Adam optimizers, then run `train`."""
        self.Dnet = Discriminator(CHANNELS_IMG, features_d).to(device)
        self.Gnet = Generator(CHANNELS_NOISE, CHANNELS_IMG, features_g).to(device)
        # NOTE(review): beta2=0.9999 differs from Adam's default 0.999 — confirm it is intentional.
        self.optimizerD = optim.Adam(self.Dnet.parameters(), lr=LR, betas=(0.5, 0.9999))
        self.optimizerG = optim.Adam(self.Gnet.parameters(), lr=LR, betas=(0.5, 0.9999))
        self.train()

    def train(self):
        """Run EPOCHS epochs of one discriminator step plus one generator step per batch.

        Every LOG_EVERY batches, prints the losses and mean D(x), and logs real and
        fixed-noise fake image grids to TensorBoard at the current global step.
        """
        for epoch in range(EPOCHS):
            for batch_idx, (data, _targets) in enumerate(dataloader):
                data = data.to(device)
                lossD, D_x, fake = self._train_discriminator(data)
                lossG = self._train_generator(fake)

                if batch_idx % self.LOG_EVERY == 0:
                    print(f"[e{epoch}/{EPOCHS}] [b{batch_idx}/{num_batches}] lossD: {lossD:.4f}, lossG: {lossG:.4f}, D(x): {D_x:.4f}")
                    # Distinct step per log so TensorBoard keeps the whole progression
                    # instead of stacking every grid on one step.
                    self._log_images(data, global_step=epoch * num_batches + batch_idx)

        # Push any buffered events to disk once training finishes.
        self.writer_real.flush()
        self.writer_fake.flush()

    def _train_discriminator(self, real):
        """One discriminator update: maximise log D(x) + log(1 - D(G(z))).

        Returns (lossD as a float, mean D(x) over the real batch, the generated fakes).
        The fakes keep their autograd graph so the generator step can backpropagate
        through them.
        """
        batch_size = real.shape[0]  # the last batch may be smaller than BS
        self.Dnet.zero_grad()

        # Real MNIST images, pushed towards the smoothed "real" target.
        output = self.Dnet(real).reshape(-1)
        lossD_real = loss_function(output, torch.full((batch_size,), self.REAL_TARGET, device=device))
        D_x = output.mean().item()  # mean prediction over the real batch

        # Generated images, pushed towards the smoothed "fake" target. detach() stops
        # this loss from sending gradients into the generator.
        noise = torch.randn(batch_size, CHANNELS_NOISE, 1, 1, device=device)
        fake = self.Gnet(noise)
        output = self.Dnet(fake.detach()).reshape(-1)
        lossD_fake = loss_function(output, torch.full((batch_size,), self.FAKE_TARGET, device=device))

        lossD = lossD_real + lossD_fake
        lossD.backward()
        self.optimizerD.step()
        return lossD.item(), D_x, fake

    def _train_generator(self, fake):
        """One generator update: maximise log D(G(z)) by scoring the fakes against target 1.

        Returns lossG as a float.
        """
        self.Gnet.zero_grad()
        output = self.Dnet(fake).reshape(-1)
        # Unsmoothed target: the generator wants its fakes judged fully real.
        lossG = loss_function(output, torch.ones_like(output))
        # This backward also writes grads into Dnet; Dnet.zero_grad() clears them on the next batch.
        lossG.backward()
        self.optimizerG.step()
        return lossG.item()

    def _log_images(self, real, global_step):
        """Write grids of the first PREVIEW_IMAGES real and fixed-noise fake images to TensorBoard."""
        # NOTE(review): Gnet stays in train mode here, so its BatchNorm running stats are
        # updated by the fixed-noise batch even under no_grad.
        with torch.no_grad():
            fake = self.Gnet(fixed_noise)
            img_grid_real = torchvision.utils.make_grid(real[:self.PREVIEW_IMAGES], normalize=True)
            img_grid_fake = torchvision.utils.make_grid(fake[:self.PREVIEW_IMAGES], normalize=True)
        self.writer_real.add_image("MNIST Real Images", img_grid_real, global_step=global_step)
        self.writer_fake.add_image("MNIST Fake Images", img_grid_fake, global_step=global_step)


if __name__ == "__main__":
    # In a notebook kernel __name__ is also "__main__", so running this cell starts training.
    # `net` stays at module scope so later cells can inspect net.Dnet / net.Gnet after training.
    net = Net()
    net.main()

Running on GPU
Batches: 938
[e0/10] [b0/938] lossD: 1.4448, lossG: 0.8309, D(x): 0.5351
[e0/10] [b100/938] lossD: 0.6863, lossG: 2.4938, D(x): 0.8328
[e0/10] [b200/938] lossD: 0.7002, lossG: 3.4931, D(x): 0.8649
[e0/10] [b300/938] lossD: 0.9801, lossG: 2.3773, D(x): 0.7922
[e0/10] [b400/938] lossD: 0.8193, lossG: 2.1091, D(x): 0.8244
[e0/10] [b500/938] lossD: 0.8174, lossG: 1.4320, D(x): 0.7110
[e0/10] [b600/938] lossD: 0.7788, lossG: 1.6278, D(x): 0.8347
[e0/10] [b700/938] lossD: 0.9507, lossG: 1.8831, D(x): 0.8843
[e0/10] [b800/938] lossD: 0.7929, lossG: 1.6713, D(x): 0.8093
[e0/10] [b900/938] lossD: 0.8150, lossG: 1.0476, D(x): 0.7167
[e1/10] [b0/938] lossD: 0.8399, lossG: 1.5816, D(x): 0.7269
[e1/10] [b100/938] lossD: 0.7721, lossG: 1.7401, D(x): 0.8041
[e1/10] [b200/938] lossD: 0.9087, lossG: 2.1355, D(x): 0.8711
[e1/10] [b300/938] lossD: 0.8008, lossG: 1.7515, D(x): 0.7953
[e1/10] [b400/938] lossD: 0.9490, lossG: 1.1562, D(x): 0.6901
[e1/10] [b500/938] lossD: 0.8305, lossG: 2.662