In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torchvision.utils import make_grid

from WGAN_pytorch import Critic, Generator

In [2]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's global RNG so a fresh "Restart & Run All" draws the same
# random numbers (torch.manual_seed seeds the CPU and CUDA generators).
# Some CUDA kernels are still non-deterministic, so GPU runs may differ
# slightly between executions.
SEED = 42
torch.manual_seed(SEED)

# --- Runtime ---------------------------------------------------------------
# Train on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Data ------------------------------------------------------------------
BATCH_SIZE = 64    # samples per mini-batch
IMG_SIZE = 64      # size images are resized to before training
IMG_CHANNEL = 1    # number of image channels

# --- Model -----------------------------------------------------------------
Z_DIM = 100               # length of the generator's latent noise vector
CRITIC_FEATURES = 32      # width argument passed to Critic
GENERATOR_FEATURES = 32   # width argument passed to Generator

# --- Optimisation ----------------------------------------------------------
LEARNING_RATE = 5e-5      # RMSprop learning rate, shared by both networks
EPOCHS = 5                # full passes over the training set
CRITIC_ITERATIONS = 5     # critic updates per generator update
WEIGHT_CLIP = 0.01        # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

In [3]:
# Per-channel normalisation statistics: mean 0.5 and std 0.5 for each channel.
channel_mean = [0.5] * IMG_CHANNEL
channel_std = [0.5] * IMG_CHANNEL

# Preprocessing applied to every image: resize, convert to a tensor, normalise.
mnist_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# MNIST training split, downloaded into ../datasets when not already present.
mnist_train = datasets.MNIST(
    root='../datasets',
    download=True,
    train=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of BATCH_SIZE images for the training loop.
data_loader = DataLoader(dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Build both networks first, then move each one onto the training device.
critic = Critic(IMG_CHANNEL, CRITIC_FEATURES)
critic = critic.to(DEVICE)
generator = Generator(Z_DIM, IMG_CHANNEL, GENERATOR_FEATURES)
generator = generator.to(DEVICE)

# One RMSprop optimizer per network (critic first), sharing the learning rate.
critic_optimizer, generator_optimizer = (
    torch.optim.RMSprop(network.parameters(), lr=LEARNING_RATE)
    for network in (critic, generator)
)

In [5]:
# WGAN training loop with weight clipping.
#
# For every batch of real images the critic is updated CRITIC_ITERATIONS
# times, then the generator is updated once. Every 100 batches the current
# losses are printed and grids of real / generated images are written to
# TensorBoard under logs/wgan/.
writer_real = SummaryWriter('logs/wgan/real')
writer_fake = SummaryWriter('logs/wgan/fake')

# Number of images shown in each logged grid.
NUM_LOGGED_IMAGES = 32

# Latent batch reused for every logged sample grid, so that successive grids
# show the same latent points and differ only because the generator changed.
fixed_noise = torch.randn(NUM_LOGGED_IMAGES, Z_DIM, 1, 1, device=DEVICE)

step = 0  # global_step used for the TensorBoard image logs

critic.train()
generator.train()

try:
    for epoch in range(EPOCHS):
        for batch_idx, (real, _) in enumerate(data_loader):
            real = real.to(DEVICE)

            # ---- Critic updates ------------------------------------------
            for _ in range(CRITIC_ITERATIONS):
                # Sample the noise directly on the target device instead of
                # creating it on the CPU and copying it over.
                z = torch.randn(real.shape[0], Z_DIM, 1, 1, device=DEVICE)
                gen_img = generator(z)

                critic_real = critic(real).reshape(-1)
                # detach(): the critic step must not backpropagate into the
                # generator. This removes the need for retain_graph=True and
                # keeps the generator graph of the last gen_img intact for
                # the generator step below.
                critic_fake = critic(gen_img.detach()).reshape(-1)

                # The critic maximises mean(critic(real)) - mean(critic(fake)),
                # so the negated difference is minimised.
                critic_loss = -(torch.mean(critic_real) - torch.mean(critic_fake))
                critic.zero_grad()
                critic_loss.backward()
                critic_optimizer.step()

                # Weight clipping: clamp every critic parameter in place,
                # outside of autograd tracking.
                with torch.no_grad():
                    for p in critic.parameters():
                        p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

            # ---- Generator update ----------------------------------------
            # Re-score the last generated batch with the updated critic; the
            # generator minimises -mean(critic(fake)).
            critic_fake = critic(gen_img).reshape(-1)
            generator_loss = -torch.mean(critic_fake)
            generator.zero_grad()
            generator_loss.backward()
            generator_optimizer.step()

            # ---- Periodic logging ----------------------------------------
            if batch_idx % 100 == 0 and batch_idx > 0:
                critic.eval()
                generator.eval()
                print(f'Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(data_loader)} Critic_loss: {critic_loss:.4f}, Generator_loss: {generator_loss:.4f}')

                with torch.no_grad():
                    sample_img = generator(fixed_noise)
                    img_grid_real = make_grid(real[:NUM_LOGGED_IMAGES], normalize=True)
                    img_grid_fake = make_grid(sample_img, normalize=True)

                    writer_real.add_image('real', img_grid_real, global_step=step)
                    writer_fake.add_image('fake', img_grid_fake, global_step=step)

                step += 1

                # Back to training mode for the next batch.
                generator.train()
                critic.train()
finally:
    # Flush and close the event files even if training is interrupted.
    writer_real.close()
    writer_fake.close()

Epoch [0/5] Batch 100/938 Critic_loss: -0.3831, Generator_loss: 0.1926
Epoch [0/5] Batch 200/938 Critic_loss: -0.3731, Generator_loss: 0.1423
Epoch [0/5] Batch 300/938 Critic_loss: -0.2906, Generator_loss: 0.2912
Epoch [0/5] Batch 400/938 Critic_loss: -0.3808, Generator_loss: 0.0975
Epoch [0/5] Batch 500/938 Critic_loss: -0.3934, Generator_loss: 0.2984
Epoch [0/5] Batch 600/938 Critic_loss: -0.5114, Generator_loss: 0.2973
Epoch [0/5] Batch 700/938 Critic_loss: -0.3401, Generator_loss: -0.0094
Epoch [0/5] Batch 800/938 Critic_loss: -0.3816, Generator_loss: 0.1272
Epoch [0/5] Batch 900/938 Critic_loss: -0.4180, Generator_loss: 0.1357
Epoch [1/5] Batch 100/938 Critic_loss: -0.3816, Generator_loss: 0.0784
Epoch [1/5] Batch 200/938 Critic_loss: -0.3919, Generator_loss: 0.2955
Epoch [1/5] Batch 300/938 Critic_loss: -0.4767, Generator_loss: 0.2704
Epoch [1/5] Batch 400/938 Critic_loss: -0.3531, Generator_loss: 0.2926
Epoch [1/5] Batch 500/938 Critic_loss: -0.3849, Generator_loss: 0.2720
Epoch