In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torchvision.utils import make_grid

from WGAN_GP_pytorch import Critic, Generator, gradient_penalty

In [2]:
# Run configuration. The training defaults match Algorithm 1 of Gulrajani et al.,
# "Improved Training of Wasserstein GANs": lr = 1e-4, 5 critic updates per
# generator update, and gradient-penalty weight 10.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42                 # global RNG seed so runs are repeatable
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMG_SIZE = 64             # size passed to transforms.Resize for the MNIST images
IMG_CHANNEL = 1           # MNIST is single-channel (grayscale)
Z_DIM = 100               # length of the latent noise vector fed to the generator
EPOCHS = 5
CRITIC_FEATURES = 16      # feature-count argument for Critic (presumably base width; confirm in WGAN_GP_pytorch)
GENERATOR_FEATURES = 16   # feature-count argument for Generator (same caveat)
CRITIC_ITERATIONS = 5     # critic updates per generator update
LAMBDA_GP = 10            # weight of the gradient-penalty term in the critic loss

# Seed torch's global RNG on all devices. This makes latent-noise sampling and
# DataLoader shuffling repeatable on Restart & Run All. Some cuDNN kernels may
# still be nondeterministic on GPU.
torch.manual_seed(SEED)

In [3]:
# Training data: the MNIST train split, resized to IMG_SIZE. ToTensor yields
# values in [0, 1]; Normalize with mean 0.5 / std 0.5 per channel then maps
# them to [-1, 1]. Batches of BATCH_SIZE are reshuffled every epoch.
mnist_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * IMG_CHANNEL, [0.5] * IMG_CHANNEL),
])

mnist_train = datasets.MNIST(
    root='../datasets',
    train=True,
    download=True,
    transform=mnist_transform,
)

data_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Build both networks on DEVICE and give each its own Adam optimizer.
# beta1 = 0.0 and beta2 = 0.9 are the Adam settings used in the WGAN-GP paper.
critic = Critic(IMG_CHANNEL, CRITIC_FEATURES).to(DEVICE)
generator = Generator(Z_DIM, IMG_CHANNEL, GENERATOR_FEATURES).to(DEVICE)

adam_settings = {'lr': LEARNING_RATE, 'betas': (0.0, 0.9)}
critic_optimizer = torch.optim.Adam(critic.parameters(), **adam_settings)
generator_optimizer = torch.optim.Adam(generator.parameters(), **adam_settings)

In [5]:
# WGAN-GP training loop.
# - Each batch runs CRITIC_ITERATIONS critic updates, then one generator update.
# - Every LOG_INTERVAL batches (skipping batch 0), losses are printed and
#   real/fake image grids are written to TensorBoard.
LOG_INTERVAL = 100      # batches between console / TensorBoard logs
NUM_LOG_SAMPLES = 32    # images per logged grid

# A fixed latent batch makes successive "fake" grids directly comparable,
# so TensorBoard shows the generator's progress on the same inputs.
fixed_noise = torch.randn(NUM_LOG_SAMPLES, Z_DIM, 1, 1, device=DEVICE)

writer_real = SummaryWriter('logs/wgan_gp/real')
writer_fake = SummaryWriter('logs/wgan_gp/fake')

step = 0

critic.train()
generator.train()

try:
    for epoch in range(EPOCHS):
        for batch_idx, (real, _) in enumerate(data_loader):
            real = real.to(DEVICE)

            # Critic: minimize -(E[C(real)] - E[C(fake)]) + LAMBDA_GP * GP.
            for _ in range(CRITIC_ITERATIONS):
                z = torch.randn(real.shape[0], Z_DIM, 1, 1, device=DEVICE)
                gen_img = generator(z)

                critic_real = critic(real).reshape(-1)
                critic_fake = critic(gen_img).reshape(-1)
                gp = gradient_penalty(critic, real, gen_img, device=DEVICE)

                critic_loss = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
                critic.zero_grad()
                # No retain_graph: the generator step below builds a fresh graph,
                # so this iteration's graph can be freed immediately.
                critic_loss.backward()
                critic_optimizer.step()

            # Generator: minimize -E[C(G(z))].
            # Re-run the generator on the last critic iteration's noise, because
            # the critic backward freed that graph. The critic backward also
            # wrote gradients into the generator's parameters; zero_grad()
            # below discards them.
            gen_img = generator(z)
            critic_fake = critic(gen_img).reshape(-1)
            generator_loss = -torch.mean(critic_fake)
            generator.zero_grad()
            generator_loss.backward()
            generator_optimizer.step()

            if batch_idx % LOG_INTERVAL == 0 and batch_idx > 0:
                print(f'Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(data_loader)} Critic_loss: {critic_loss.item():.4f}, Generator_loss: {generator_loss.item():.4f}')

                critic.eval()
                generator.eval()
                with torch.no_grad():
                    fake_samples = generator(fixed_noise)
                    img_grid_real = make_grid(real[:NUM_LOG_SAMPLES], normalize=True)
                    img_grid_fake = make_grid(fake_samples, normalize=True)

                    writer_real.add_image('real', img_grid_real, global_step=step)
                    writer_fake.add_image('fake', img_grid_fake, global_step=step)

                step += 1

                generator.train()
                critic.train()
finally:
    # Flush pending TensorBoard events to disk and release the event files,
    # even if training is interrupted (e.g. KeyboardInterrupt).
    writer_real.close()
    writer_fake.close()

Epoch [0/5] Batch 100/938 Critic_loss: -31.3290, Generator_loss: 6.8070
Epoch [0/5] Batch 200/938 Critic_loss: -20.9275, Generator_loss: 9.4686
Epoch [0/5] Batch 300/938 Critic_loss: -16.4957, Generator_loss: 15.0448
Epoch [0/5] Batch 400/938 Critic_loss: -12.1411, Generator_loss: 18.2889
Epoch [0/5] Batch 500/938 Critic_loss: -10.9591, Generator_loss: 18.6333
Epoch [0/5] Batch 600/938 Critic_loss: -10.5856, Generator_loss: 22.2330
Epoch [0/5] Batch 700/938 Critic_loss: -9.3268, Generator_loss: 23.6651
Epoch [0/5] Batch 800/938 Critic_loss: -9.4234, Generator_loss: 24.3751
Epoch [0/5] Batch 900/938 Critic_loss: -10.5744, Generator_loss: 22.2779
Epoch [1/5] Batch 100/938 Critic_loss: -9.3547, Generator_loss: 25.0233
Epoch [1/5] Batch 200/938 Critic_loss: -9.4256, Generator_loss: 24.2712
Epoch [1/5] Batch 300/938 Critic_loss: -8.9973, Generator_loss: 22.5422
Epoch [1/5] Batch 400/938 Critic_loss: -9.2134, Generator_loss: 26.9146
Epoch [1/5] Batch 500/938 Critic_loss: -8.0064, Generator_l