In [2]:
import os
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

In [58]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores 3-channel images with a probability.

    Note that ``img_size`` is used as the *base channel width* (the number of
    filters in the first convolution), not as a spatial size. Every stride-2
    convolution halves height/width and each block doubles the channel count.
    The final 4x4 / stride-2 / no-padding convolution maps a 4x4 feature map
    to 1x1. For an ``(N, 1, 1, 1)`` output the input should therefore be
    ``4 * 2 ** (num_blocks + 1)`` pixels square (32 for ``num_blocks=2``,
    64 for ``num_blocks=3``).

    Args:
        img_size: Number of output channels of the first convolution.
        num_blocks: Number of Conv-BatchNorm-LeakyReLU downsampling blocks
            after the first convolution. Must be >= 0.

    Raises:
        ValueError: if ``num_blocks`` is negative.
    """

    def __init__(self, img_size, num_blocks):
        super(Discriminator, self).__init__()
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")
        layers = [
            # The first layer has no BatchNorm.
            nn.Conv2d(3, img_size, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Track the running channel count explicitly. The previous version
        # read the loop variable after the loop, which raised NameError when
        # num_blocks == 0.
        channels = img_size
        for _ in range(num_blocks):
            layers.append(self._block(channels, channels * 2, 4, 2, 1))
            channels *= 2
        layers.extend([
            # 4x4 -> 1x1 single logit, squashed to (0, 1) for use with BCELoss.
            nn.Conv2d(channels, 1, 4, 2, 0),
            nn.Sigmoid(),
        ])
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (no bias, since BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        """Return per-sample probabilities; shape (N, 1, 1, 1) for correctly sized input."""
        return self.net(x)

In [59]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise ``(N, z_dim, 1, 1)`` to RGB images.

    ``img_size`` is used as a channel-width multiplier, not a spatial size.
    The first block projects the 1x1 noise to a 4x4 map with
    ``img_size * 2 ** num_blocks`` channels. Each following block doubles
    height/width and halves the channel count. The final transposed
    convolution doubles height/width once more and emits 3 channels through
    Tanh. Outputs are therefore ``4 * 2 ** num_blocks`` pixels square
    (32 for ``num_blocks=3``, 64 for ``num_blocks=4``), with values in [-1, 1].

    Args:
        z_dim: Number of channels of the input noise tensor.
        img_size: Channel-width multiplier (the last hidden layer has
            ``img_size * 2`` channels).
        num_blocks: Controls depth; must be >= 1.

    Raises:
        ValueError: if ``num_blocks < 1``. Such values previously built a
            network whose last layer expected the wrong number of input
            channels, so it failed only at forward time.
    """

    def __init__(self, z_dim, img_size, num_blocks):
        super(Generator, self).__init__()
        if num_blocks < 1:
            raise ValueError(f"Generator needs num_blocks >= 1, got {num_blocks}")
        channels = img_size * 2 ** num_blocks
        layers = [
            # 1x1 noise -> 4x4 feature map.
            self._block(z_dim, channels, 4, 1, 0)
        ]
        # num_blocks - 1 upsampling blocks, ending at img_size * 2 channels.
        for _ in range(num_blocks - 1):
            layers.append(self._block(channels, channels // 2, 4, 2, 1))
            channels //= 2
        layers.extend([
            nn.ConvTranspose2d(channels, 3, 4, 2, 1),
            nn.Tanh()
        ])
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose (no bias, since BatchNorm follows) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Map noise ``x`` to images in [-1, 1]."""
        return self.net(x)

In [60]:
def initialize_weights(model):
    """Re-draw, in place, the weights of every conv, transposed-conv and BatchNorm2d layer.

    Each matching module's ``weight`` is filled from a normal distribution
    with mean 0.0 and std 0.02. Biases (including BatchNorm's shift) and all
    other layer types keep their current values.

    NOTE(review): BatchNorm scales are drawn around 0 here. The official
    PyTorch DCGAN tutorial draws them from N(1, 0.02) and zeroes the bias
    instead; confirm which is intended.

    Args:
        model: any ``nn.Module``. ``model.modules()`` is walked, so the model
            itself and all nested submodules are visited.
    """
    init_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    targets = (mod for mod in model.modules() if isinstance(mod, init_types))
    for layer in targets:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

In [90]:
def test(image_size=64, num_blocks_disc=3, num_blocks_gen=4, z_dim=100, batch_size=8):
    """Smoke-test that Discriminator and Generator produce the expected output shapes.

    Self-contained: it does not read the global ``IMAGE_SIZE``, which is only
    assigned in a later cell, so a fresh Restart & Run All would hit a
    NameError otherwise. ``image_size`` is passed to both networks as their
    ``img_size`` (base channel width) and is also the spatial size of the
    random input images. The defaults reproduce the 64x64 configuration.

    Raises:
        AssertionError: if either network's output shape is wrong.
    """
    x = torch.randn((batch_size, 3, image_size, image_size))
    disc = Discriminator(image_size, num_blocks_disc)
    initialize_weights(disc)
    assert disc(x).shape == (batch_size, 1, 1, 1)

    gen = Generator(z_dim, image_size, num_blocks_gen)
    initialize_weights(gen)
    z = torch.randn((batch_size, z_dim, 1, 1))
    assert gen(z).shape == (batch_size, 3, image_size, image_size)

# Check both configurations that are trained below.
test(32, num_blocks_disc=2, num_blocks_gen=3)
test(64, num_blocks_disc=3, num_blocks_gen=4)

In [114]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 2e-4   # Adam learning rate for both networks
Z_DIM = 100            # channels of the generator's (N, Z_DIM, 1, 1) noise input
NUM_EPOCHS = 10
BATCH_SIZE = 128
SEED = 42

# Seed torch's RNGs (CPU and CUDA) so weight init and noise sampling are
# repeatable across runs. DataLoader shuffling presumably also follows the
# default generator; confirm. Bit-exact GPU results may additionally need
# deterministic cuDNN settings.
torch.manual_seed(SEED)

In [115]:
def train(img_size, dataloader, num_blocks_gen, num_blocks_disc):
    """Train a Generator/Discriminator pair with BCE losses and log to TensorBoard.

    Args:
        img_size: Passed to both networks as ``img_size`` (base channel width).
            It also names the log directory ``logs/{img_size}``.
        dataloader: Iterable of ``(images, labels)`` batches with ``len()``.
            Labels are ignored. Presumably images are normalized to [-1, 1]
            to match the generator's Tanh output; confirm against the
            transform used.
        num_blocks_gen: ``num_blocks`` for ``Generator``.
        num_blocks_disc: ``num_blocks`` for ``Discriminator``.

    Reads the globals ``Z_DIM``, ``LEARNING_RATE``, ``NUM_EPOCHS`` and ``device``.

    Side effects: deletes and recreates ``logs/{img_size}``. Every step it
    logs both losses. Every 100 batches it prints progress and logs
    real/fake image grids. All writers are closed on exit, including when
    training is interrupted.
    """
    log_dir = f'logs/{img_size}'
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    writer = SummaryWriter(log_dir)
    writer_real = SummaryWriter(f'{log_dir}/real')
    writer_fake = SummaryWriter(f'{log_dir}/fake')

    gen = Generator(Z_DIM, img_size, num_blocks_gen).to(device)
    initialize_weights(gen)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    disc = Discriminator(img_size, num_blocks_disc).to(device)
    initialize_weights(disc)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    criterion = nn.BCELoss()

    # Fixed noise so the fake-image grids logged at different steps are comparable.
    fixed_noise = torch.randn(16, Z_DIM, 1, 1, device=device)

    gen.train()
    disc.train()

    step = 0
    try:
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (real, _) in enumerate(dataloader):
                real = real.to(device)
                # Match the real batch size: the final batch of an epoch can
                # be smaller than BATCH_SIZE (e.g. 50000 % 128 = 80).
                noise = torch.randn(real.size(0), Z_DIM, 1, 1, device=device)
                fake = gen(noise)

                # Discriminator: real -> target 1, fake -> target 0.
                disc_real = disc(real).reshape(-1)
                loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
                # detach() so this loss does not backpropagate into the generator.
                disc_fake = disc(fake.detach()).reshape(-1)
                loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
                loss_disc = (loss_disc_real + loss_disc_fake) / 2
                disc.zero_grad()
                loss_disc.backward()
                opt_disc.step()

                # Generator: push D(G(z)) toward 1 (non-saturating loss).
                output = disc(fake).reshape(-1)
                loss_gen = criterion(output, torch.ones_like(output))
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()

                loss_d, loss_g = loss_disc.item(), loss_gen.item()
                writer.add_scalar('Loss/Discriminator Loss', loss_d, global_step=step)
                writer.add_scalar('Loss/Generator Loss', loss_g, global_step=step)
                writer.add_scalars('Comb_Loss/Losses', {
                    'Discriminator': loss_d,
                    'Generator': loss_g,
                }, step)

                if batch_idx % 100 == 0:
                    print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)}\n"
                          f"LossD: {loss_d:.4f}, LossG: {loss_g:.4f}\n")

                    # gen is still in train mode here, so BatchNorm uses batch
                    # statistics (and updates running stats) for this sample.
                    with torch.no_grad():
                        fake_fixed = gen(fixed_noise)
                        img_grid_real = torchvision.utils.make_grid(real[:16], normalize=True)
                        img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)

                    writer_real.add_image('Real', img_grid_real, global_step=step)
                    writer_fake.add_image('Fake', img_grid_fake, global_step=step)

                step += 1
    finally:
        # Flush pending events to disk and release file handles.
        for summary_writer in (writer, writer_real, writer_fake):
            summary_writer.close()


In [116]:
IMAGE_SIZE = 32

# Resize CIFAR-10 to IMAGE_SIZE x IMAGE_SIZE. ToTensor yields pixels in
# [0, 1]; normalizing each channel with mean 0.5 / std 0.5 maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader_32 = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


In [117]:
# 32x32 run: 3 generator blocks (-> 32px output), 2 discriminator blocks (32px input).
train(IMAGE_SIZE, trainloader_32, num_blocks_gen=3, num_blocks_disc=2)

Epoch [0/10] Batch 0/391
LossD: 0.6910, LossG: 0.7095

Epoch [0/10] Batch 100/391
LossD: 0.1574, LossG: 2.0112

Epoch [0/10] Batch 200/391
LossD: 0.0658, LossG: 3.1542

Epoch [0/10] Batch 300/391
LossD: 0.3034, LossG: 2.1225

Epoch [1/10] Batch 0/391
LossD: 0.4970, LossG: 1.8705

Epoch [1/10] Batch 100/391
LossD: 0.5904, LossG: 1.1150

Epoch [1/10] Batch 200/391
LossD: 0.4490, LossG: 1.4704

Epoch [1/10] Batch 300/391
LossD: 0.6010, LossG: 1.6019

Epoch [2/10] Batch 0/391
LossD: 0.5286, LossG: 1.2366

Epoch [2/10] Batch 100/391
LossD: 0.5254, LossG: 1.3205

Epoch [2/10] Batch 200/391
LossD: 0.5388, LossG: 1.2472

Epoch [2/10] Batch 300/391
LossD: 0.5082, LossG: 1.2911

Epoch [3/10] Batch 0/391
LossD: 0.4895, LossG: 1.2482

Epoch [3/10] Batch 100/391
LossD: 0.4425, LossG: 1.4490

Epoch [3/10] Batch 200/391
LossD: 0.4804, LossG: 1.4477

Epoch [3/10] Batch 300/391
LossD: 0.3768, LossG: 1.7165

Epoch [4/10] Batch 0/391
LossD: 0.4379, LossG: 1.6753

Epoch [4/10] Batch 100/391
LossD: 0.4921,

In [118]:
IMAGE_SIZE = 64

# Same preprocessing as the 32px cell, at 64x64: resize, then map pixels
# from [0, 1] to [-1, 1] to match the generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader_64 = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


In [119]:
# 64x64 run: 4 generator blocks (-> 64px output), 3 discriminator blocks (64px input).
train(IMAGE_SIZE, trainloader_64, num_blocks_gen=4, num_blocks_disc=3)

Epoch [0/10] Batch 0/391
LossD: 0.6881, LossG: 0.8069

Epoch [0/10] Batch 100/391
LossD: 0.0325, LossG: 3.6405

Epoch [0/10] Batch 200/391
LossD: 0.2608, LossG: 2.3286

Epoch [0/10] Batch 300/391
LossD: 0.4734, LossG: 1.4808

Epoch [1/10] Batch 0/391
LossD: 0.4822, LossG: 1.5392

Epoch [1/10] Batch 100/391
LossD: 0.7259, LossG: 0.8036

Epoch [1/10] Batch 200/391
LossD: 0.5151, LossG: 1.6969

Epoch [1/10] Batch 300/391
LossD: 0.5888, LossG: 1.5069

Epoch [2/10] Batch 0/391
LossD: 0.5908, LossG: 1.1664

Epoch [2/10] Batch 100/391
LossD: 0.6890, LossG: 1.4918

Epoch [2/10] Batch 200/391
LossD: 0.5097, LossG: 1.5576

Epoch [2/10] Batch 300/391
LossD: 0.5803, LossG: 1.5948

Epoch [3/10] Batch 0/391
LossD: 0.6765, LossG: 1.1348

Epoch [3/10] Batch 100/391
LossD: 0.5655, LossG: 1.2603

Epoch [3/10] Batch 200/391
LossD: 0.5615, LossG: 1.0650

Epoch [3/10] Batch 300/391
LossD: 0.6415, LossG: 0.8331

Epoch [4/10] Batch 0/391
LossD: 0.5932, LossG: 1.7194

Epoch [4/10] Batch 100/391
LossD: 0.5089,

In [123]:
# Discriminator architecture used by the 64x64 run (img_size=64, num_blocks=3).
Discriminator(64, 3)

Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [125]:
# Generator architecture used by the 64x64 run (z_dim=100, img_size=64, num_blocks=4).
Generator(100, 64, 4)

Generator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Tanh()
  )
)

In [126]:
# Discriminator architecture used by the 32x32 run (img_size=32, num_blocks=2).
Discriminator(32, 2)

Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
    (5): Sigmoid()
  )
)

In [127]:
# Generator architecture used by the 32x32 run (z_dim=100, img_size=32, num_blocks=3).
Generator(100, 32, 3)

Generator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Tanh()
  )
)