In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [6]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self.block(features_d, features_d * 2,4,2,1),
            self.block(features_d * 2, features_d * 4, 4,2,1),
            self.block(features_d * 4, features_d * 8, 4,2,1),
            nn.Conv2d(features_d * 8, 1, kernel_size=4,stride=2, padding=0), # 1 x 1
            nn.Sigmoid()
            )
        

    def block(self, in_channels, out_channels,kernel_size, stride, padding):
        return nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.LeakyReLU(0.2, inplace=True))
    
    def forward(self, x):
        x = self.disc(x)
        print(x.shape)
        return x
    

class Generator(nn.Module):
    def __init__(self, z_dim, channel_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            self.block(z_dim, features_g*16, 4, 1,0), # N x f_g * 4 * 4
            self.block(features_g*16, features_g*8, 4, 2,1), #8x8
            self.block(features_g*8, features_g*4, 4, 2,1), #16x16
            self.block(features_g*4, features_g*2, 4, 2,1), #32x32
            nn.ConvTranspose2d(features_g*2, channel_img, 4, 2,1),
            nn.Tanh() # Required for nomalization range   
        )

    def block(self, in_channels, out_channels,kernel_size, stride, padding):
        return nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(0.2))
    
    def forward(self, x):
        return self.gen(x)
        

def initialise_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialise_weights(disc)
    assert disc(x).shape == (N, 1,1,1)
    gen = Generator(z_dim, in_channels, 8)
    initialise_weights(gen)
    z = torch.randn((N, z_dim,1,1))
    assert gen(z).shape == (N, in_channels, H, W)
    print('test passed')

test()

torch.Size([8, 1, 1, 1])
test passed


In [7]:
# Training loop
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

lr = 2e-4
batch_size = 128
img_sz = 64
channel_img = 1
epochs = 10
feature_disc = 64
feature_gen = 64
noise_dim = 100

# DataLoader 

transforms = transforms.Compose([
    transforms.Resize((img_sz)),
        transforms.ToTensor(),
        transforms.Normalize([0.5 for _ in range(channel_img)], [0.5 for _ in range(channel_img)]),
])

dataset = datasets.MNIST(root='../data', download=True, transform=transforms)
loader = DataLoader(dataset, batch_size=batch_size,shuffle = True)

gen = Generator(noise_dim,channel_img=channel_img,features_g=feature_gen).to(device)
disc = Discriminator(channel_img, feature_disc).to(device)

initialise_weights(gen)
initialise_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

fixed_noise = torch.randn(32, noise_dim,1,1).to(device)

writer_fake = SummaryWriter(f'../data/runs/DCGAN_MNIST/fake')
writer_real = SummaryWriter(f'../data/runs/DCGAN_MNIST/real')
step = 0

gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2

In [14]:
for epoch in range(epochs):
    for idx, (real_img, _) in enumerate(loader):
        real_img = real_img.to(device)
        noise = torch.randn((batch_size, noise_dim,1,1)).to(device)
        fake = gen(noise)

        ### Train Discriminator ###
        disc_real = disc(real_img).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator ###
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if idx == 0:
            print(
                f'Epoch [{epoch / epochs}] \  '
                f'Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}'
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1,64,64)
                data = real_img.reshape(-1, 1, 64, 64)

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    'MNIST Fake Images', img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    'MNIST Real Images', img_grid_real, global_step=step
                )

                step += 1


        

Epoch [0.0] \  Loss D: 0.4769, Loss G: 1.0590
Epoch [0.1] \  Loss D: 0.5955, Loss G: 1.0394
Epoch [0.2] \  Loss D: 0.5441, Loss G: 0.9291
Epoch [0.3] \  Loss D: 0.3725, Loss G: 1.7559
Epoch [0.4] \  Loss D: 0.2918, Loss G: 2.8822
Epoch [0.5] \  Loss D: 0.2024, Loss G: 3.6823
Epoch [0.6] \  Loss D: 0.0918, Loss G: 3.3084
Epoch [0.7] \  Loss D: 0.1167, Loss G: 3.4646
Epoch [0.8] \  Loss D: 0.0862, Loss G: 3.7097
Epoch [0.9] \  Loss D: 0.0338, Loss G: 4.1216
