In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from model import Generator, Discriminator
from utils import init_weights_normal, visualize_outputs
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Preprocessing applied to every MNIST sample: convert to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
T = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)

# Training split of MNIST, stored under ./data (downloaded if not already there).
dataset = MNIST("./data", train=True, transform=T, download=True)

In [3]:
# --- Training configuration -------------------------------------------------
EPOCHS = 1000          # number of passes over the dataset
BATCH_SIZE = 128       # samples per mini-batch
LR = 0.0002            # learning rate passed to both optimizers
BETAS = (0.5, 0.999)   # beta coefficients passed to both optimizers
Z_DIM = 100            # size of the latent noise vector fed to the generator
SEED = 42              # seed for torch's random number generators

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple's Metal backend.
# `torch.backends.mps` does not exist on older torch builds, hence the getattr.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed torch so weight initialisation and noise sampling are repeatable
# across "Restart Kernel -> Run All".
torch.manual_seed(SEED)

In [4]:
# Shuffled mini-batch iterator over the dataset.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build both networks on the target device, then re-initialise their weights.
# NOTE(review): the (0, 0.02) arguments are presumably the mean and std of a
# normal distribution -- confirm against utils.init_weights_normal.
gen = Generator(Z_DIM).to(DEVICE)
disc = Discriminator().to(DEVICE)
for net in (gen, disc):
    init_weights_normal(net, 0, 0.02)
del net  # keep the notebook namespace free of the loop variable

# One Adam optimizer per network, sharing the same hyperparameters.
gen_opt, disc_opt = (
    optim.Adam(model.parameters(), lr=LR, betas=BETAS) for model in (gen, disc)
)

# Binary cross-entropy between discriminator outputs and real/fake targets.
crit = nn.BCELoss()

In [5]:
writer = SummaryWriter(log_dir="logs")

# Rendering sample images costs an extra generator pass plus image encoding,
# so log them periodically instead of on every batch.
SAMPLE_EVERY = 100

# Adversarial training loop. Each batch does one discriminator update
# (real -> 1, fake -> 0) followed by one generator update (fake -> 1).
# try/finally makes sure pending TensorBoard events are flushed to disk even
# when the run is interrupted (e.g. KeyboardInterrupt).
try:
    for e in range(EPOCHS):
        loop = tqdm(enumerate(loader), total=len(loader), leave=True, position=0)
        loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
        gen.train()
        disc.train()
        for idx, (real, _) in loop:
            global_step = idx + len(loader) * e
            real = real.to(DEVICE)
            # Match the real batch size: the last batch of an epoch can be
            # smaller than BATCH_SIZE. Sample directly on the target device
            # to avoid a host-to-device copy.
            noise = torch.randn(real.size(0), Z_DIM, 1, 1, device=DEVICE)

            # Discriminator training
            disc_real_yhat = disc(real).view(-1)
            loss_disc_real = crit(disc_real_yhat, torch.ones_like(disc_real_yhat))

            fake = gen(noise)
            # detach(): the discriminator update must not backpropagate into
            # the generator. This also leaves the generator graph untouched
            # for the generator step below, so retain_graph is unnecessary.
            disc_fake_yhat = disc(fake.detach()).view(-1)
            loss_disc_fake = crit(disc_fake_yhat, torch.zeros_like(disc_fake_yhat))

            disc_loss = (loss_disc_fake + loss_disc_real) / 2
            # Also clears the gradients that the previous generator step left
            # on the discriminator parameters.
            disc.zero_grad()
            disc_loss.backward()
            disc_opt.step()

            # Generator training: score the same fakes with the freshly
            # updated discriminator and push them towards the "real" label.
            fake_yhat = disc(fake).view(-1)
            gen_loss = crit(fake_yhat, torch.ones_like(fake_yhat))
            gen.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            loop.set_postfix(gen_loss = gen_loss.item(), disc_loss = disc_loss.item())
            writer.add_scalar("Generator Loss", gen_loss.item(), global_step)
            writer.add_scalar("Discriminator Loss", disc_loss.item(), global_step)

            if global_step % SAMPLE_EVERY == 0:
                writer.add_images("Sample outputs", visualize_outputs(gen, Z_DIM, DEVICE), global_step)
                # NOTE(review): visualize_outputs is defined elsewhere; in case
                # it switches the generator to eval mode, restore training
                # mode before the next batch.
                gen.train()
finally:
    writer.close()

Epoch : [0/1000]:   1%|          | 5/469 [00:06<09:26,  1.22s/it, disc_loss=0.577, gen_loss=0.858]


KeyboardInterrupt: 