## GAN 훈련

### 라이브러리 임포트

In [None]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
import torch.utils.data as data_utils
import matplotlib.pyplot as plt

from GAN import Generator, Discriminator

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

In [None]:
num_epochs = 5
batch_size = 64
D_lr = 8e-4
G_lr = 4e-4
data_path = '../data/full_numpy_bitmap_camel.npy'
image_save_folder = './images/gan/'
model_save_path = './gan_camel.pth'
os.makedirs(image_save_folder, exist_ok=True)

### 데이터 적재

In [None]:
class MyDataset(Dataset):
    """Map-style dataset that serves items straight out of an in-memory container.

    Args:
        data: any object supporting ``len()`` and integer indexing (in this
            notebook, a tensor of images). Items are returned unchanged.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` exactly as stored."""
        item = self.data[idx]
        return item

In [None]:

# data url: https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap?pli=1
# Number of drawings kept for training, and an optional seed for choosing them.
# data_seed = None draws fresh OS entropy each run (a random subset every time);
# set an int to get a reproducible subset.
num_train_samples = 80000
data_seed = None

# View every drawing as a 1x28x28 image. This is a reshape of the raw array,
# not a float copy.
raw_data = np.load(data_path).reshape(-1, 1, 28, 28)

# Choose a random subset of images in random order BEFORE converting to
# float32, so the 4x larger float copy is only made for the images kept.
rng = np.random.default_rng(data_seed)
subset_idx = rng.permutation(len(raw_data))[:num_train_samples]
train_data = raw_data[subset_idx]
del raw_data  # release the full raw array once the subset is extracted

# Scale pixel values from [0, 255] to [-1, 1].
train_data = (train_data.astype('float32') - 127.5) / 127.5
train_data = torch.from_numpy(train_data)  # already float32

dataset = MyDataset(train_data)
dataset_size = len(dataset)
dataloader = DataLoader(dataset,
                        shuffle=True,
                        batch_size=batch_size,
                        num_workers=4,
                        # Pinned host memory only helps host->GPU copies; on a
                        # CPU-only machine PyTorch just warns about it.
                        pin_memory=(device.type == 'cuda'),
                        # Every batch is exactly batch_size long; the training
                        # loop's label tensors are sized for that.
                        drop_last=True)

In [None]:
print(train_data.shape)

# Arrange the first 25 training images in a 5x5 grid. normalize=True rescales
# the grid to [0, 1] for display.
preview_grid = vutils.make_grid(train_data[:25], nrow=5, padding=2, normalize=True)

plt.figure(figsize=(20, 20))
plt.axis("off")
# make_grid returns channels-first (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
plt.savefig(os.path.join(image_save_folder, 'original_images.png'))

### 모델 만들기

In [None]:
# Build both networks and move them to the selected device. The discriminator
# is still constructed first, so random weight initialisation consumes the RNG
# in the same order.
D_model, G_model = Discriminator().to(device), Generator().to(device)

In [None]:
# Print the generator's module structure.
print(G_model)

In [None]:
# Print the discriminator's module structure.
print(D_model)

### 모델 훈련

In [None]:
# Binary cross-entropy for both players. nn.BCELoss takes probabilities, so
# the discriminator presumably ends in a sigmoid. NOTE(review): confirm in GAN.py.
D_criterion, G_criterion = nn.BCELoss(), nn.BCELoss()

# BCE target values for "real" and "fake" images.
real_label, fake_label = 1., 0.

# Separate RMSprop optimisers, each with its own learning rate.
D_optimizer = optim.RMSprop(lr=D_lr, params=D_model.parameters())
G_optimizer = optim.RMSprop(lr=G_lr, params=G_model.parameters())

In [None]:
# Per-iteration training statistics, used by the plotting cells below.
G_losses = []
D_losses = []
D_losses_real = []
D_losses_fake = []
D_accs = []
D_accs_real = []
D_accs_fake = []

# BCE targets sized for a full batch. They are sliced per batch below, so a
# shorter final batch (if drop_last were ever disabled) still lines up.
real_labels = torch.full((batch_size, ), real_label,
                         dtype=torch.float, device=device, requires_grad=False)
fake_labels = torch.full((batch_size, ), fake_label,
                         dtype=torch.float, device=device, requires_grad=False)

# Global iteration counter. It is named `step` so the builtin `iter` is not
# shadowed in the notebook namespace.
step = 0

# A later cell switches both models to eval mode; put them back in training
# mode so re-running this cell trains correctly.
D_model.train()
G_model.train()

for epoch in range(num_epochs):
    for inputs in dataloader:
        D_model.zero_grad()

        # ---- Discriminator update ----
        # Real images: the discriminator should output values near real_label.
        inputs = inputs.to(device)
        cur_batch_size = inputs.size(0)
        batch_real_labels = real_labels[:cur_batch_size]
        batch_fake_labels = fake_labels[:cur_batch_size]

        output = D_model(inputs).view(-1)
        # The "accuracy" here is the mean discriminator output on real images,
        # not a thresholded accuracy.
        D_acc_real = torch.mean(output)
        D_loss_real = D_criterion(output, batch_real_labels)

        # Fake images: detach() so the discriminator loss does not
        # backpropagate into the generator.
        noise = torch.randn(cur_batch_size, 100, device=device)
        fake = G_model(noise)

        output = D_model(fake.detach()).view(-1)
        D_acc_fake = 1 - torch.mean(output)
        D_loss_fake = D_criterion(output, batch_fake_labels)

        D_acc = (D_acc_real + D_acc_fake) * 0.5
        D_loss = (D_loss_real + D_loss_fake) * 0.5
        D_loss.backward()
        D_optimizer.step()

        # ---- Generator update ----
        # Uses the discriminator that was just updated. G's loss is low when D
        # scores the fakes as real and high when D scores them as fake.
        # This backward pass also writes gradients into D's parameters; they
        # are cleared by D_model.zero_grad() at the start of the next iteration.
        G_model.zero_grad()
        output = D_model(fake).view(-1)
        G_loss = G_criterion(output, batch_real_labels)
        G_loss.backward()
        G_optimizer.step()

        print('iter: %d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f]' %
              (step, D_loss.item(), D_loss_real.item(), D_loss_fake.item(),
               D_acc.item(), D_acc_real.item(), D_acc_fake.item(), G_loss.item()))

        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        D_losses_real.append(D_loss_real.item())
        D_losses_fake.append(D_loss_fake.item())
        D_accs.append(D_acc.item())
        D_accs_real.append(D_acc_real.item())
        D_accs_fake.append(D_acc_fake.item())

        step += 1

    # Save a 5x5 grid of fakes from the epoch's final batch. These were
    # generated before that batch's generator update.
    epoch_grid = vutils.make_grid(fake[:25].detach().cpu(),
                                  nrow=5,
                                  padding=2,
                                  normalize=True)
    plt.figure(figsize=(20, 20))
    plt.axis("off")
    plt.imshow(np.transpose(epoch_grid, (1, 2, 0)))
    plt.savefig(os.path.join(image_save_folder, f'epoch_{epoch + 1}.png'))

# Checkpoint both networks' weights in a single file.
torch.save({'Discriminator': D_model.state_dict(), 'Generator': G_model.state_dict(),}, model_save_path)
        

In [None]:
# Training-loss curves. The training loop appends one value per batch, so the
# x axis is the batch (iteration) index, not the epoch.
fig = plt.figure()
plt.plot(D_losses, color='black', linewidth=0.1, label='D loss')
plt.plot(D_losses_real, color='green', linewidth=0.1, label='D loss (real)')
plt.plot(D_losses_fake, color='red', linewidth=0.1, label='D loss (fake)')
plt.plot(G_losses, color='orange', linewidth=0.1, label='G loss')

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)
plt.legend()

# Show the whole run by default. Set loss_plot_xmax to e.g. 2000 to zoom into
# the start of training. max(..., 1) avoids identical xlim bounds when no
# losses were recorded.
loss_plot_xmax = max(len(D_losses), 1)
plt.xlim(0, loss_plot_xmax)

plt.savefig(os.path.join(image_save_folder, 'loss_graph.png'))

In [None]:
# Discriminator "accuracy" curves, one value per batch. These are mean
# discriminator outputs: on real images, and 1 - mean output on fake images.
fig = plt.figure()
plt.plot(D_accs, color='black', linewidth=0.1, label='D acc')
plt.plot(D_accs_real, color='green', linewidth=0.1, label='D acc (real)')
plt.plot(D_accs_fake, color='red', linewidth=0.1, label='D acc (fake)')

plt.xlabel('batch', fontsize=18)
plt.ylabel('accuracy', fontsize=16)
plt.legend()

# Show the whole run by default. Set acc_plot_xmax to e.g. 2000 to zoom into
# the start of training.
acc_plot_xmax = max(len(D_accs), 1)
plt.xlim(0, acc_plot_xmax)

plt.savefig(os.path.join(image_save_folder, 'accuracy_graph.png'))

In [None]:
def l1_compare_images(img1, img2):
    """Return the mean absolute (L1) pixel difference between two images.

    Both arguments are NumPy arrays of broadcast-compatible shapes.
    A result of 0 means the images are identical.
    """
    abs_diff = np.abs(img1 - img2)
    return np.mean(abs_diff)

In [None]:
# Optional: restore the weights written by the training cell instead of
# retraining. Uncomment these lines to load the checkpoint at model_save_path
# onto the current device.
# loaded_models = torch.load(model_save_path, map_location=device)
# D_model.load_state_dict(loaded_models['Discriminator'])
# G_model.load_state_dict(loaded_models['Generator'])

In [None]:
# Nearest-neighbour check: for each of the first 25 training images, find the
# generated sample with the smallest mean absolute (L1) pixel difference, then
# plot those matches. Fakes that closely copy training images can hint that
# the generator is memorising its training data.
D_model.eval()
G_model.eval()

real_images = train_data[:25].numpy()

# Generate as many samples as one pass over the dataloader would: len(dataloader)
# batches of batch_size (every batch is full because drop_last=True). Only the
# batch count is needed, so real batches are neither loaded nor copied to the
# device.
fake_batches = []
with torch.no_grad():
    for _ in range(len(dataloader)):
        noise = torch.randn(batch_size, 100, device=device)
        fake_batches.append(G_model(noise).cpu().numpy())
fake_images = np.concatenate(fake_batches)

# Vectorised search: mean L1 distance over every non-batch axis, then argmin.
# Like the strict '<' scan, argmin keeps the first minimum on ties.
# Each subtraction creates a temporary the size of fake_images.
sample_axes = tuple(range(1, fake_images.ndim))
similar_images = np.zeros(real_images.shape, dtype=real_images.dtype)
for i, real_image in enumerate(real_images):
    l1_dists = np.abs(fake_images - real_image).mean(axis=sample_axes)
    similar_images[i] = fake_images[np.argmin(l1_dists)]

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(torch.from_numpy(similar_images),
                                         nrow=5,
                                         padding=2,
                                         normalize=True), (1, 2, 0)))
plt.savefig(os.path.join(image_save_folder, 'similar_images.png'))

In [None]:
# Re-render the first 25 training images as a 5x5 grid, for side-by-side
# comparison with the nearest generated samples.
originals_path = os.path.join(image_save_folder, 'original_images.png')
originals_grid = vutils.make_grid(train_data[:25],
                                  nrow=5,
                                  padding=2,
                                  normalize=True)

plt.figure(figsize=(20, 20))
plt.axis("off")
# Reorder channels-first (C, H, W) to (H, W, C) for imshow.
plt.imshow(np.transpose(originals_grid, (1, 2, 0)))
plt.savefig(originals_path)