In [None]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
import torch.utils.data as data_utils
import matplotlib.pyplot as plt

from GAN import Generator, Discriminator

In [None]:
# Training hyper-parameters.
num_epochs = 6000   # number of training steps; the loop draws one mini-batch per step
batch_size = 64
d_lr, g_lr = 8e-4, 4e-4  # discriminator / generator learning rates

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Output folder for sample grids and training curves (created if missing).
image_save_path = './images/camel/'
os.makedirs(image_save_path, exist_ok=True)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset that serves items straight from an indexable container.

    Args:
        data: any object supporting ``len()`` and integer indexing (in this
            notebook, a tensor of images). Items are returned unchanged.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        # The dataset is exactly as long as the wrapped container.
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

In [None]:
# Load the camel drawings and reshape each row into a single-channel 28x28 image.
n_train = 80000  # number of drawings kept for training

raw_data = np.load('./data/full_numpy_bitmap_camel.npy')
raw_data = raw_data.reshape(-1, 1, 28, 28)

# Pick a random subset of n_train drawings. Seeding NumPy with a random number
# gives no reproducibility, so a local Generator is used instead. Set
# data_seed to an int to make the subset repeatable.
data_seed = None
rng = np.random.default_rng(data_seed)
subset = raw_data[rng.permutation(len(raw_data))[:n_train]]

# Scale pixels from [0, 255] (presumably uint8 bitmaps; confirm) to [-1, 1]
# so real images share the generator's output range.
# NOTE(review): assumes the Generator ends in tanh; confirm in GAN.py
# (use `/ 255.` instead if it ends in sigmoid).
train_data = (torch.from_numpy(subset).float() - 127.5) / 127.5
dataset = MyDataset(train_data)
dataloader = DataLoader(dataset,
                        shuffle=True,
                        batch_size=batch_size,
                        num_workers=4,
                        # Pinned host memory only helps when copying batches to a GPU.
                        pin_memory=(device.type == 'cuda'))

In [None]:
print(train_data.shape)

# Preview the first 25 training drawings as a 5x5 grid. normalize=True
# rescales the pixel values to [0, 1] for display.
preview_grid = vutils.make_grid(train_data[:25], nrow=5, padding=2, normalize=True)
plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(preview_grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.savefig(os.path.join(image_save_path, 'original_images.png'))

In [None]:
# Build both networks on the device chosen in the configuration cell instead
# of a hard-coded "cuda:0", so the notebook also runs on CPU-only machines.
# NOTE(review): assumes the GAN constructors accept a device string such as
# "cpu" as well as "cuda:0"; confirm in GAN.py.
d_model = Discriminator(str(device)).to(device)
g_model = Generator(str(device)).to(device)

In [None]:
print(f"{g_model}")  # module summary (repr) of the generator

In [None]:
print(f"{d_model}")  # module summary (repr) of the discriminator

In [None]:
# Both networks are trained with binary cross-entropy on the discriminator's outputs.
d_criterion, g_criterion = nn.BCELoss(), nn.BCELoss()

# BCE targets: 1 marks real images, 0 marks generated ones.
real_label, fake_label = 1.0, 0.0

# One RMSprop optimizer per network, each with its own learning rate.
d_optimizer, g_optimizer = (
    optim.RMSprop(params=model.parameters(), lr=lr)
    for model, lr in ((d_model, d_lr), (g_model, g_lr))
)

In [None]:
# Per-step training histories (plain Python floats) used by the plotting cells.
G_losses = []
D_losses = []
D_losses_real = []
D_losses_fake = []
D_accs = []
D_accs_real = []
D_accs_fake = []

# Steps after which a 5x5 grid of generated samples is saved to disk.
sample_steps = {20, 200, 400, 1000, 2000}
# Print a progress line every `log_every` steps (1 = every step).
log_every = 1


def _infinite_batches(loader):
    """Yield batches from `loader` forever, starting a new (reshuffled) pass when one ends.

    A single long-lived generator avoids calling ``iter(loader)`` on every
    step, which with ``num_workers > 0`` would spawn fresh worker processes
    each time.
    """
    if len(loader) == 0:
        raise ValueError("dataloader yields no batches")
    while True:
        yield from loader


batches = _infinite_batches(dataloader)

# Each "epoch" here is a single optimisation step on one mini-batch,
# not a full pass over the dataset.
for epoch in range(num_epochs):
    inputs = next(batches).to(device)
    b_size = inputs.size(0)

    # ---- Discriminator update ----
    d_model.zero_grad()

    # Train on real images (target = real_label).
    label = torch.full((b_size, ), real_label,
                       dtype=torch.float, device=device)
    output = d_model(inputs).view(-1)
    errD_real = d_criterion(output, label)
    errD_real.backward()
    # Accuracy on real images: fraction scored above 0.5 (i.e. judged real).
    accD_real = (output > 0.5).float().mean()

    # Train on generated images (target = fake_label). `detach` keeps this
    # backward pass from reaching the generator's parameters.
    noise = torch.randn(b_size, 100, device=device)
    fake = g_model(noise)
    label.fill_(fake_label)
    output = d_model(fake.detach()).view(-1)
    errD_fake = d_criterion(output, label)
    errD_fake.backward()
    # Accuracy on generated images: fraction scored at or below 0.5 (judged fake).
    accD_fake = (output <= 0.5).float().mean()

    errD = (errD_real + errD_fake) * 0.5
    accD = (accD_real + accD_fake) * 0.5
    d_optimizer.step()

    # ---- Generator update ----
    # Score the same fakes with the freshly updated discriminator, using real
    # labels. The loss is low when the discriminator calls the fakes real and
    # high when it correctly flags them as fake.
    g_model.zero_grad()
    label.fill_(real_label)
    output = d_model(fake).view(-1)
    errG = g_criterion(output, label)
    errG.backward()
    g_optimizer.step()

    # Record plain floats. Storing the tensors themselves would keep them (and
    # their autograd history) alive on the device, and matplotlib cannot plot
    # CUDA tensors that require grad.
    G_losses.append(errG.item())
    D_losses.append(errD.item())
    D_losses_real.append(errD_real.item())
    D_losses_fake.append(errD_fake.item())
    D_accs.append(accD.item())
    D_accs_real.append(accD_real.item())
    D_accs_fake.append(accD_fake.item())

    if epoch % log_every == 0:
        print('%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f]' %
              (epoch, D_losses[-1], D_losses_real[-1], D_losses_fake[-1],
               D_accs[-1], D_accs_real[-1], D_accs_fake[-1], G_losses[-1]))

    if epoch + 1 in sample_steps:
        grid = vutils.make_grid(fake[:25].detach(),
                                nrow=5,
                                padding=2,
                                normalize=True).cpu()
        plt.figure(figsize=(20, 20))
        plt.axis("off")
        plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        plt.savefig(os.path.join(image_save_path, f'epoch_{epoch + 1}.png'))
        

In [None]:
# Per-step discriminator losses (combined / real / fake) and generator loss.
# float() accepts both plain numbers and 0-d tensors (including CUDA tensors
# that require grad), so this cell works whichever form the histories hold.
loss_curves = {
    'D loss': (D_losses, 'black'),
    'D loss (real)': (D_losses_real, 'green'),
    'D loss (fake)': (D_losses_fake, 'red'),
    'G loss': (G_losses, 'orange'),
}

fig, ax = plt.subplots()
for curve_name, (values, color) in loss_curves.items():
    ax.plot([float(v) for v in values], color=color, linewidth=0.25, label=curve_name)

ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('loss', fontsize=16)

# NOTE(review): only the first 2000 steps are shown although training runs
# num_epochs steps; widen the limit if the full run should be visible.
ax.set_xlim(0, 2000)
ax.set_ylim(0, 2)
ax.legend()

fig.savefig(os.path.join(image_save_path, 'loss_graph.png'))

In [None]:
# Per-step discriminator accuracy histories: overall, on real images and on
# generated images. float() accepts both plain numbers and 0-d tensors
# (including CUDA tensors that require grad), so this cell works whichever
# form the histories hold.
acc_curves = {
    'D acc': (D_accs, 'black'),
    'D acc (real)': (D_accs_real, 'green'),
    'D acc (fake)': (D_accs_fake, 'red'),
}

fig, ax = plt.subplots()
for curve_name, (values, color) in acc_curves.items():
    ax.plot([float(v) for v in values], color=color, linewidth=0.25, label=curve_name)

ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('accuracy', fontsize=16)

# NOTE(review): only the first 2000 steps are shown although training runs
# num_epochs steps; widen the limit if the full run should be visible.
ax.set_xlim(0, 2000)
ax.legend()

fig.savefig(os.path.join(image_save_path, 'accuracy_graph.png'))