Training of DCGAN network on MNIST dataset with Discriminator
and Generator imported from models.py

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialise_weights



In [2]:
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms, download=True
)

# comment mnist above and uncomment below if train on CelebA
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialise_weights(gen)
initialise_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [3]:
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
    

Epoch [1/5] Batch 1/469               Loss D: 0.7050, loss G: 0.7408
Epoch [1/5] Batch 2/469               Loss D: 0.6494, loss G: 0.8091
Epoch [1/5] Batch 3/469               Loss D: 0.6066, loss G: 0.8678
Epoch [1/5] Batch 4/469               Loss D: 0.5669, loss G: 0.9202
Epoch [1/5] Batch 5/469               Loss D: 0.5299, loss G: 0.9704
Epoch [1/5] Batch 6/469               Loss D: 0.4983, loss G: 1.0173
Epoch [1/5] Batch 7/469               Loss D: 0.4673, loss G: 1.0656
Epoch [1/5] Batch 8/469               Loss D: 0.4346, loss G: 1.1145
Epoch [1/5] Batch 9/469               Loss D: 0.4069, loss G: 1.1621
Epoch [1/5] Batch 10/469               Loss D: 0.3851, loss G: 1.2111
Epoch [1/5] Batch 11/469               Loss D: 0.3592, loss G: 1.2593
Epoch [1/5] Batch 12/469               Loss D: 0.3386, loss G: 1.3072
Epoch [1/5] Batch 13/469               Loss D: 0.3157, loss G: 1.3551
Epoch [1/5] Batch 14/469               Loss D: 0.2976, loss G: 1.4034
Epoch [1/5] Batch 15/469     

KeyboardInterrupt: 