In [13]:
import wandb
import random
import numpy as np

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

# Simple Feed-forward GAN
class Generator(nn.Module):
    def __init__(self, z_dim=20, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)

class Discriminator(nn.Module):
    def __init__(self, in_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

# Data Preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# Initialize Networks
z_dim = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator(z_dim=z_dim).to(device)
D = Discriminator().to(device)

# Training Loop
num_epochs = 50
# Initialize Optimizers
lr = 0.0002
G_optimizer = Adam(G.parameters(), lr=lr)
D_optimizer = Adam(D.parameters(), lr=lr)

# Initialize Loss Function
criterion = nn.BCELoss()

# start a new wandb run to track this script
run = wandb.init(
    # set the wandb project where this run will be logged
    project="mnist-gan",
    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "epochs": num_epochs,
        "z_dim": z_dim,
    }
)

fixed_noise = torch.randn(100, z_dim).to(device)
for epoch in range(num_epochs):
    g_loss_avg = 0.0
    d_loss_avg = 0.0
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.view(real_images.size(0), -1).to(device)
        batch_size = real_images.size(0)

        # Train Discriminator
        D_real = D(real_images)
        D_real_loss = criterion(D_real, torch.ones(batch_size, 1).to(device))

        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = G(z)
        D_fake = D(fake_images)
        D_fake_loss = criterion(D_fake, torch.zeros(batch_size, 1).to(device))

        D_loss = D_real_loss + D_fake_loss

        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Train Generator
        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = criterion(D_fake, torch.ones(batch_size, 1).to(device))

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        wandb.log({"Losses": {"D_Loss": D_loss.item(), "G_Loss": G_loss.item()}}, step=epoch * len(dataloader) + i)
        d_loss_avg += D_loss.item()
        g_loss_avg += G_loss.item()

    print(f'[{epoch+1}/{num_epochs}] D_loss: {d_loss_avg/len(dataloader):.4f} G_loss: {g_loss_avg/len(dataloader):.4f}')
    fake_images = G(fixed_noise)
    img_grid = make_grid(fake_images.view(-1, 1, 28, 28), nrow=10, normalize=True)
    wandb.log({"examples": [wandb.Image(img_grid)]}, step=epoch)

    # Save models every 10 epochs
    # if (epoch+1) % 10 == 0:
    FOLDER = 'data/local/mnistModels'
    generator_artifact = wandb.Artifact('generator', type='model', description='Generator Model')
    torch.save(G.state_dict(), FOLDER + '/generator.pth')
    generator_artifact.add_file(FOLDER + '/generator.pth')
    run.log_artifact(generator_artifact)
    discriminator_artifact = wandb.Artifact('discriminator', type='model', description='Discriminator Model')
    torch.save(D.state_dict(), FOLDER + '/discriminator.pth')
    discriminator_artifact.add_file(FOLDER + '/discriminator.pth')
    run.log_artifact(discriminator_artifact)

# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

[1/50] D_loss: 0.7009 G_loss: 1.3495
[2/50] D_loss: 0.7604 G_loss: 1.2977
[3/50] D_loss: 0.7405 G_loss: 1.3603
[4/50] D_loss: 0.7994 G_loss: 1.3257
[5/50] D_loss: 0.9728 G_loss: 1.1402
[6/50] D_loss: 1.0511 G_loss: 1.1173
[7/50] D_loss: 0.7863 G_loss: 1.3897
[8/50] D_loss: 1.1030 G_loss: 1.0216


KeyboardInterrupt: 

In [12]:
FOLDER = 'data/local/mnistModels'
generator_artifact = wandb.Artifact('generator', type='model', description='Generator Model')
torch.save(G.state_dict(), FOLDER + '/generator.pth')
generator_artifact.add_file(FOLDER + '/generator.pth')
run.log_artifact(generator_artifact)
discriminator_artifact = wandb.Artifact('discriminator', type='model', description='Discriminator Model')
torch.save(D.state_dict(), FOLDER + '/discriminator.pth')
discriminator_artifact.add_file(FOLDER + '/discriminator.pth')
run.log_artifact(discriminator_artifact)


UsageError: Run (klwdrxgs) is finished. The call to `log_artifact` will be ignored. Please make sure that you are using an active run.

In [10]:
# [optional] finish the wandb run, necessary in notebooks
wandb.finish()