## WGAN 훈련

### 라이브러리 임포트

In [None]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.utils as vutils
from torch.utils.data import Subset

from WGAN import Critic, Generator

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

In [None]:
# Training schedule.
epochs = 2000
batch_size = 64
n_critic = 5  # the generator is updated once every n_critic critic batches

# Model and optimiser settings.
z_dim = 100  # length of the noise vector fed to the generator
D_lr = G_lr = 5e-5
clip_threshold = 0.01  # critic weights are clamped to [-clip_threshold, clip_threshold]

# Filesystem locations; the image folder is created up front so later
# savefig calls do not fail on a missing directory.
data_path = '../data'
image_save_folder = './images/wgan'
os.makedirs(image_save_folder, exist_ok=True)

### 데이터 적재

In [None]:
# Transform handed to the CIFAR10 constructors. The code below indexes the
# datasets' `.data` arrays directly and rescales them by hand, so the pixel
# scaling used for training is the manual one further down.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Class index to keep; index 7 is 'horse' in the `classes` list below.
specific_label = 7

trainset = torchvision.datasets.CIFAR10(root=data_path, train=True,
                                        download=True, transform=transform)
# Compare labels by value. `is` tests object identity, which only happens to
# work for small integers because CPython caches them.
train_indices = [idx for idx, target in enumerate(trainset.targets) if target == specific_label]
trainset = trainset.data[train_indices]

testset = torchvision.datasets.CIFAR10(root=data_path, train=False,
                                       download=True, transform=transform)
test_indices = [idx for idx, target in enumerate(testset.targets) if target == specific_label]
testset = testset.data[test_indices]

# Merge the selected train and test images into one array.
dataset = np.concatenate((trainset, testset), axis=0)
# Move the last axis to position 1 (assumes `.data` is laid out as
# (N, H, W, C), giving (N, C, H, W) -- TODO confirm against torchvision).
dataset = np.transpose(dataset, (0, 3, 1, 2))
# Rescale so that 0 maps to -1 and 255 maps to +1.
dataset = (dataset - 127.5) / 127.5
dataset = torch.from_numpy(dataset).float()

dataset_size = len(dataset)

# drop_last=True keeps every batch at exactly `batch_size` samples.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2, drop_last=True)

classes = ['plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
sample_image_batch_size = 25


def imshow(image):
    """Render a batch of images as a 5-column grid and save it.

    The batch is shifted with ``(image + 1) * 0.5`` and then passed to
    ``make_grid`` with ``normalize=True``. The figure is written to
    'sample_images.png' inside ``image_save_folder``.

    Args:
        image: batch of images accepted by ``vutils.make_grid``.
    """
    shifted = (image + 1) * 0.5
    grid = vutils.make_grid(shifted, nrow=5, padding=2, normalize=True)

    plt.figure(figsize=(18, 18))
    plt.axis("off")
    # Reorder the grid's axes with (1, 2, 0) before handing it to imshow.
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.savefig(os.path.join(image_save_folder, 'sample_images.png'))


# Preview the first images of the first batch.
images = next(iter(dataloader))
imshow(images[:sample_image_batch_size])

### 모델 생성

In [None]:
# Build the generator first, then the critic, and move each to `device`.
generator = Generator(z_dim)
generator = generator.to(device)
critic = Critic()
critic = critic.to(device)

# Switch both networks to training mode.
generator.train()
critic.train()

In [None]:
# Print the generator object's string representation for inspection.
print(generator)

In [None]:
# Print the critic object's string representation for inspection.
print(critic)

### 모델 훈련

In [None]:
# Label constants for real and generated samples.
real_label = 1.
fake_label = -1.

# One RMSprop optimiser per network, each with its own learning rate.
D_optimizer = optim.RMSprop(critic.parameters(), lr=D_lr)
G_optimizer = optim.RMSprop(generator.parameters(), lr=G_lr)
# Alternative kept for reference: Adam with betas=[0.5, 0.999]
# D_optimizer = optim.Adam(params=critic.parameters(), lr=D_lr, betas=[0.5, 0.999])
# G_optimizer = optim.Adam(params=generator.parameters(), lr=G_lr, betas=[0.5, 0.999])

In [None]:
# Per-epoch loss histories; one float is appended to each list per epoch.
G_losses = []
D_losses = []
D_losses_real = []
D_losses_fake = []

# Not referenced by the loop below; kept in case other cells rely on them.
one = torch.FloatTensor([1]).to(device)
mone = one * -1

# Number of samples the generator is trained on per epoch: it is updated once
# every `n_critic` batches, each time on `batch_size` noise vectors.
num_G_input_batches = (len(dataloader) // n_critic) * batch_size

for epoch in range(epochs):
    epoch_G_loss = 0.0
    epoch_D_loss = 0.0
    epoch_D_loss_real = 0.0
    epoch_D_loss_fake = 0.0
    num_inputs = 0

    for i, inputs in enumerate(dataloader):
        critic.zero_grad()

        # Critic on real data: the loss is the negated mean critic output.
        inputs = inputs.to(device)

        output = critic(inputs)
        D_loss_real = -output.mean().view(-1)

        # Critic on generated data. The samples are detached so this backward
        # pass stops at the critic; generator gradients computed here would be
        # zeroed before the generator's own step anyway.
        noise = torch.randn(batch_size, z_dim, device=device)
        fake = generator(noise).detach()
        output = critic(fake)
        D_loss_fake = output.mean().view(-1)

        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_optimizer.step()

        # Weight the batch means by batch size so the epoch totals can be
        # divided by the number of samples seen.
        epoch_D_loss += D_loss.item() * batch_size
        epoch_D_loss_real += D_loss_real.item() * batch_size
        epoch_D_loss_fake += D_loss_fake.item() * batch_size

        num_inputs += inputs.size(0)

        # Weight clipping: keep every critic parameter within
        # [-clip_threshold, clip_threshold] (the WGAN way of approximating a
        # 1-Lipschitz continuous critic).
        for p in critic.parameters():
            p.data.clamp_(-clip_threshold, clip_threshold)

        # Update the generator once every `n_critic` critic updates.
        if (i + 1) % n_critic == 0:
            # Freeze the critic so only generator gradients are computed.
            for p in critic.parameters():
                p.requires_grad = False

            # Generator: the loss is the negated mean critic output on fresh
            # generated samples.
            generator.zero_grad()
            noise = torch.randn(batch_size, z_dim, device=device)
            fake = generator(noise)
            output = critic(fake)
            G_loss = -output.mean().view(-1)
            G_loss.backward()
            G_optimizer.step()

            epoch_G_loss += G_loss.item() * batch_size

            for p in critic.parameters():
                p.requires_grad = True

    # Convert the accumulated totals into per-sample averages.
    epoch_D_loss /= num_inputs
    epoch_D_loss_real /= num_inputs
    epoch_D_loss_fake /= num_inputs
    epoch_G_loss /= num_G_input_batches

    D_losses.append(epoch_D_loss)
    D_losses_real.append(epoch_D_loss_real)
    D_losses_fake.append(epoch_D_loss_fake)
    G_losses.append(epoch_G_loss)

    print('%d [D loss: (%.3f)(R %.3f, F %.3f)] [G loss: %.3f]' %
          (epoch + 1, epoch_D_loss, epoch_D_loss_real, epoch_D_loss_fake, epoch_G_loss))

    # Save a 5x5 grid of the most recently generated samples at fixed epochs.
    if epoch + 1 in [50, 100, 200, 500, 1000, 2000]:
        plt.figure(figsize=(20, 20))
        plt.axis("off")
        plt.imshow(np.transpose(vutils.make_grid((fake[:25] + 1) * 0.5,
                                                 nrow=5,
                                                 padding=2,
                                                 normalize=False).detach().cpu(), (1, 2, 0)))
        plt.savefig(os.path.join(image_save_folder, f'epoch_{epoch + 1}.png'))


In [None]:
# Plot the per-epoch loss histories and save the figure as 'loss_graph.png'.
fig = plt.figure(figsize=(20, 10))

# The histories are already plain lists, so they are plotted directly.
plt.plot(D_losses, color='black', linewidth=0.25)        # critic total
plt.plot(D_losses_real, color='green', linewidth=0.25)   # critic, real part
plt.plot(D_losses_fake, color='red', linewidth=0.25)     # critic, fake part
plt.plot(G_losses, color='orange', linewidth=0.25)       # generator

plt.xlabel('epoch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Follow the configured number of epochs instead of a hard-coded 2000.
plt.xlim(0, epochs)

plt.savefig(os.path.join(image_save_folder, 'loss_graph.png'))

In [None]:
def compare_images(img1, img2):
    """Return the mean absolute element-wise difference between two images."""
    difference = img1 - img2
    return difference.abs().mean()

In [None]:
# Draw 25 noise vectors and run them through the generator; the result is
# detached and moved to the CPU for plotting.
noise = torch.randn(25, z_dim, device=device)
gen_imgs = generator(noise).detach().cpu()

# Halve and shift by 0.5 for display (presumably mapping [-1, 1] to [0, 1] --
# confirm the generator's output range), then tile as a 5-column grid.
grid = vutils.make_grid(gen_imgs / 2 + 0.5, nrow=5, padding=2, normalize=False)

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.savefig(os.path.join(image_save_folder, 'gen_imgs.png'))

In [None]:
# For every generated image, find the real image with the smallest mean
# absolute pixel difference and plot the matches as a 5-column grid.
#
# The search runs over the whole `dataset` tensor in one vectorised pass per
# generated image. This stays entirely in torch (no NumPy/tensor mixing),
# avoids re-iterating the DataLoader for each generated image, and also covers
# the tail samples that the DataLoader's drop_last=True would skip.
real_flat = dataset.reshape(len(dataset), -1)
closest_imgs = []

for gen_img in gen_imgs:
    # Mean absolute difference between this generated image and each real one.
    diffs = torch.abs(real_flat - gen_img.reshape(1, -1)).mean(dim=1)
    closest_imgs.append(dataset[int(torch.argmin(diffs))])

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid((torch.stack(closest_imgs) + 1) * 0.5,
                                         nrow=5,
                                         padding=2,
                                         normalize=False), (1, 2, 0)))
plt.savefig(os.path.join(image_save_folder, 'closest_real_imgs.png'))

