In [4]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from torchvision import transforms
from torch import nn
import torch
from torch import optim
import tensorboard
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets

# Ask PyTorch's CUDA caching allocator to release cached memory it is not currently using
# (presumably to reclaim GPU memory left over from earlier runs in this kernel).
torch.cuda.empty_cache()

In [5]:
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    The network is a flat ``nn.Sequential`` of five stride-2 convolutions
    (kernel size 4). The first four use padding 1 and therefore halve the
    spatial size of an even-sized input; the last one uses padding 0 and
    collapses a 4x4 feature map to a single value, which is passed through a
    sigmoid. With 64x64 inputs the output has shape ``(N, 1, 1, 1)``.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Channel width of the first convolution; the following
            three convolutions widen it to 2x, 4x and 8x.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        # Stem: strided convolution + LeakyReLU, no batch norm on the raw image.
        stages = [
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three downsampling stages, each doubling the channel count.
        width = features_d
        for _ in range(3):
            stages += [
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2),
            ]
            width *= 2

        # Head: reduce to one channel and squash to a probability in [0, 1].
        stages += [
            nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]

        # Kept as a single flat Sequential so sub-module indices stay 0..12.
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Return the per-image "real" probability map for the batch ``x``."""
        return self.disc(x)


class Generator(nn.Module):
    """DCGAN-style transposed-convolution generator.

    The network is a flat ``nn.Sequential``. The first transposed convolution
    (kernel 4, stride 1, padding 0) expands a 1x1 input to 4x4; each of the
    four following ones (kernel 4, stride 2, padding 1) doubles the spatial
    size, so a ``(N, channels_noise, 1, 1)`` input yields a
    ``(N, channels_img, 64, 64)`` output. The final ``Tanh`` bounds the
    output to [-1, 1].

    Args:
        channels_noise: Number of channels of the input noise tensor.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; hidden layers use 16x, 8x, 4x and 2x
            this value, in that order.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()

        # Projection of the noise to the widest feature map.
        layers = [
            nn.ConvTranspose2d(
                channels_noise, features_g * 16, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),
        ]

        # Three upsampling stages, each halving the channel count.
        for mult in (16, 8, 4):
            in_ch = features_g * mult
            out_ch = features_g * (mult // 2)
            layers.extend(
                [
                    nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(),
                ]
            )

        # Output stage: map to image channels, no batch norm.
        layers.extend(
            [
                nn.ConvTranspose2d(
                    features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
                ),
                nn.Tanh(),
            ]
        )

        # Kept as a single flat Sequential so sub-module indices stay 0..13.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of images from the noise tensor ``x``."""
        return self.net(x)


def initialize_weights(model):
    """Re-initialise convolution and batch-norm weights of ``model`` in place.

    Every ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weight is drawn from
    N(0, 0.02) and every ``nn.BatchNorm2d`` scale from N(1, 0.02). Biases are
    left untouched.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.

    Returns:
        None. The parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # BatchNorm2d(affine=False) has ``weight is None``; skipping it
            # avoids an AttributeError on layers without a learnable scale.
            nn.init.normal_(m.weight, mean=1.0, std=0.02)


In [6]:
# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 42
lr = 2e-4
bs = 128
img_size = 64
no_channels = 3
noise_size = 100
epochs = 5
fd = 64
fg = 64

# Seed PyTorch's random number generators so that weight initialisation and
# the sampled noise are repeatable from one run of this cell to the next.
torch.manual_seed(seed)

# Resize and crop the PIL image *before* converting it to a tensor, so the
# resize does not depend on torchvision's tensor-resize behaviour.
# Resize(int) only fixes the shorter side; CenterCrop guarantees that every
# sample is exactly img_size x img_size (it is a no-op for square images).
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        # Map each channel from [0, 1] to [-1, 1], the range of the generator's Tanh.
        transforms.Normalize([0.5] * no_channels, [0.5] * no_channels),
    ]
)

dataset = datasets.ImageFolder(root='Bitmoji-Faces', transform=transform)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)

gen = Generator(noise_size, no_channels, fg).to(device)
disc = Discriminator(no_channels, fd).to(device)

initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# for tensorboard plotting: the same 32 noise vectors are rendered at every
# logging step so successive image grids are directly comparable.
fixed_noise = torch.randn(32, noise_size, 1, 1, device=device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()

try:
    for epoch in range(epochs):
        for batch_idx, (real, _) in enumerate(dataloader):
            real = real.to(device)
            # One noise vector per real image; the last batch may be smaller than bs.
            noise = torch.randn(real.shape[0], noise_size, 1, 1, device=device)
            fake = gen(noise)

            # Discriminator step: max log(D(x)) + log(1 - D(G(z)))
            disc_real = disc(real).reshape(-1)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach() keeps this loss from back-propagating into the generator.
            disc_fake = disc(fake.detach()).reshape(-1)
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_disc = (loss_disc_real + loss_disc_fake) / 2
            opt_disc.zero_grad()
            loss_disc.backward()
            opt_disc.step()

            # Generator step: min log(1 - D(G(z))) <-> max log(D(G(z)))
            output = disc(fake).reshape(-1)
            loss_gen = criterion(output, torch.ones_like(output))
            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Every 100 batches: print the losses and log image grids to tensorboard.
            if batch_idx % 100 == 0:
                print(
                    f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(dataloader)} "
                    f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
                )

                with torch.no_grad():
                    fake_fixed = gen(fixed_noise)
                    img_grid_real = torchvision.utils.make_grid(
                        real[:32], normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_fixed, normalize=True
                    )

                    writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=step)

                step += 1
finally:
    # Flush pending events to disk and release the log files even when the
    # run is interrupted (e.g. the cell is stopped by hand).
    writer_real.close()
    writer_fake.close()

Epoch [0/5] Batch 0/1018                   Loss D: 0.6370, loss G: 3.9072
Epoch [0/5] Batch 100/1018                   Loss D: 0.3474, loss G: 6.3488
Epoch [0/5] Batch 200/1018                   Loss D: 0.2229, loss G: 3.6310
Epoch [0/5] Batch 300/1018                   Loss D: 0.6767, loss G: 0.7573
Epoch [0/5] Batch 400/1018                   Loss D: 0.4375, loss G: 1.5405
Epoch [0/5] Batch 500/1018                   Loss D: 0.3748, loss G: 2.1263
Epoch [0/5] Batch 600/1018                   Loss D: 0.5145, loss G: 2.2741
Epoch [0/5] Batch 700/1018                   Loss D: 0.3502, loss G: 2.7334
Epoch [0/5] Batch 800/1018                   Loss D: 0.3376, loss G: 3.6083
Epoch [0/5] Batch 900/1018                   Loss D: 0.3652, loss G: 3.6951
Epoch [0/5] Batch 1000/1018                   Loss D: 0.7770, loss G: 4.9429
Epoch [1/5] Batch 0/1018                   Loss D: 0.3682, loss G: 1.7557
Epoch [1/5] Batch 100/1018                   Loss D: 0.3116, loss G: 3.8016
Epoch [1/5] Bat