In [1]:
# set gpu by number 
import os
import random
# Restrict this process to the GPU with index 0 via the CUDA_VISIBLE_DEVICES variable.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [2]:
# load packages
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio #### install with "pip install imageio"
from IPython.display import HTML

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torchvision.utils import make_grid

In [3]:
# Create the output folders used by the rest of the notebook.
# os.makedirs(..., exist_ok=True) replaces the exists()/mkdir() pairs: it is a
# no-op when the folder is already there, creates missing parents, and avoids
# the check-then-create race of the two-step form.
for folder in ('./checkpoint', './dataset', './img', './img/real', './img/fake'):
    os.makedirs(folder, exist_ok=True)
    

In [4]:
def vis_image(image):
    """Plot entry 0 of the torch tensor ``image`` as a grayscale picture."""
    first = image[0].detach().cpu().numpy()
    plt.imshow(first, cmap='gray')
    plt.show()

def save_gif(training_progress_images, images, value_range=(-1.0, 1.0)):
    '''
    Append a grid of ``images`` to the progress list and rewrite the GIF file.

        training_progress_images: list of frames (uint8 arrays) collected so far;
            the new frame is appended to this list in place
        images: batch of images generated in this iteration (a tensor accepted
            by ``make_grid``)
        value_range: (low, high) interval the pixel values of ``images`` lie in.
            Defaults to (-1, 1), the range produced by a Tanh output layer.
            Pass (0, 1) for images that are already in [0, 1].

    Returns the same list, now including the new frame.
    '''
    low, high = value_range
    # Map the pixels to [0, 1] BEFORE building the grid. Multiplying values in
    # [-1, 1] by 255 and casting to uint8 wraps the negative half around and
    # corrupts the frame. Scaling first also keeps make_grid's padding black.
    scaled = ((images.detach().float() - low) / (high - low)).clamp(0.0, 1.0)
    img_grid = make_grid(scaled)
    # CHW -> HWC, the layout imageio expects.
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    img_grid = (255. * img_grid).round().astype(np.uint8)
    training_progress_images.append(img_grid)
    # The whole GIF is rewritten on every call so the file on disk is always current.
    imageio.mimsave('./img/training_progress.gif', training_progress_images)
    return training_progress_images

# visualize gif file
def vis_gif(training_progress_images):
    """Build an HTML5 video from the collected frames.

    training_progress_images: list of image arrays, one per frame.

    Returns an ``IPython.display.HTML`` object. Make the call the last
    expression of a cell (or pass the result to ``display``) to render it.
    Previously the HTML object was created and discarded, so nothing was shown.
    """
    fig = plt.figure()

    # ArtistAnimation expects one list of artists per frame.
    ims = [[plt.imshow(frame, animated=True)] for frame in training_progress_images]

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)

    html = ani.to_html5_video()
    # Close the figure so the static last frame is not rendered next to the video.
    plt.close(fig)
    return HTML(html)

# plot evenly spaced frames of the training progress
def plot_gif(training_progress_images, plot_length=10):
    """Show up to ``plot_length`` evenly spaced frames as static plots.

    training_progress_images: list of image arrays, one per frame.
    plot_length: maximum number of frames to display.

    With at least ``plot_length`` frames, the frames shown are the same as
    before (every ``len // plot_length``-th frame). With fewer frames, each
    one is shown once instead of repeating frame 0; an empty list plots
    nothing instead of raising IndexError.
    """
    total_len = len(training_progress_images)
    if total_len == 0 or plot_length <= 0:
        return

    plt.close()
    plt.figure()

    # A stride of at least 1 keeps the indices distinct when there are few frames.
    stride = max(total_len // plot_length, 1)
    for i in range(min(plot_length, total_len)):
        plt.imshow(training_progress_images[stride * i])
        plt.show()

def save_image_list(dataset, real):
    """Write every item of ``dataset`` to disk as ``image_<index>.png``.

    dataset: indexable collection with ``len``; each item is passed to
        ``vutils.save_image``.
    real: if truthy the files go to './img/real', otherwise to './img/fake'.

    Returns the folder the images were written to.
    """
    base_path = './img/real' if real else './img/fake'
    # Ensure the target folder exists so the function also works on its own.
    os.makedirs(base_path, exist_ok=True)

    for i in range(len(dataset)):
        vutils.save_image(dataset[i], f'{base_path}/image_{i}.png')

    return base_path

In [5]:
# Convert images to tensors in [0, 1], then map them to [-1, 1].
# Normalize with mean 0.5 and std 0.5 computes (x - 0.5) / 0.5 = 2 * x - 1, the
# same mapping as the previous Lambda. Unlike a lambda, it can be pickled,
# which DataLoader worker processes require under the "spawn" start method.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)) # normalize to be in [-1, 1]
            ])

In [6]:
# FashionMNIST training split, downloaded into ./data and passed through `transform`.
dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



In [13]:
# Mini-batch size shared by the loader and the batch-count computation.
BATCH_SIZE = 100

# Reshuffle every epoch and load batches with three worker processes.
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=3)
data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

  cpuset_checked))


In [7]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to 1x28x28 images.

    Args:
        latent_dim: size of the input noise vector (default 64).
        hidden_dim: width of the four hidden layers (default 256).

    The defaults reproduce the original fixed architecture, and the layers
    keep their positions inside ``self.main``, so existing checkpoints still
    load. The output passes through Tanh, so pixel values lie in [-1, 1].
    """

    def __init__(self, latent_dim=64, hidden_dim=256):
        super(Generator, self).__init__()
        # Kept so callers can sample noise of the right size.
        self.latent_dim = latent_dim
        self.main = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 28 * 28),
            nn.Tanh()
        )

    def forward(self, input):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28)."""
        output = self.main(input)
        output = output.view(-1, 1, 28, 28)
        return output

In [14]:
# Number of mini-batches per epoch: len(dataset) / BATCH_SIZE rounded up (60000 / 100).
n_batches = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE

In [15]:
class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per 28x28 image."""

    def __init__(self):
        super().__init__()
        # Layer widths from the flattened image through the four hidden layers.
        widths = [28 * 28, 256, 256, 256, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.2))
        # Single-unit output squashed by a sigmoid.
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten the images to 784 features and return scores of shape (batch,)."""
        flat = input.view(-1, 28 * 28)
        return self.main(flat).squeeze(dim=1)

In [16]:
# Build both networks and move their parameters to the GPU.
netG = Generator().to('cuda')
netD = Discriminator().to('cuda')

# One Adam optimizer per network, both with the same learning rate.
optimizerG = optim.Adam(netG.parameters(), lr=2e-4)
optimizerD = optim.Adam(netD.parameters(), lr=2e-4)

In [17]:
#### Implement here ####
# Sample a batch of latent vectors. The second dimension must equal the input
# size of the Generator's first layer (64); the previous value of 100 did not
# match it. .cuda() moves the tensor to the GPU, where the networks live.
noise = torch.randn(128, 64).cuda()

In [None]:
# Size of the generator's input; must match the first Linear layer of Generator.
latent_dim = 64

# Latent vectors reused after every epoch, so the saved frames show how the
# output for the same inputs changes during training.
# .cuda() moves the tensor to the GPU, where netG's parameters live.
fixed_noise = torch.randn(128, latent_dim).cuda()

criterion = nn.BCELoss()

n_epoch = 200
training_progress_images_list = []
for epoch in range(n_epoch):
    for i, (data, _) in enumerate(data_loader):
        ####################################################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) #
        ###################################################
        # train with real
        netD.zero_grad()
        data = data.cuda()
        batch_size = data.size(0)

        # Targets are allocated directly on the data's device (no CPU -> GPU copy).
        real_label = torch.ones(batch_size, device=data.device)   # real label
        fake_label = torch.zeros(batch_size, device=data.device)  # fake label

        output = netD(data)
        errD_real = criterion(output, real_label)
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, latent_dim, device=data.device)
        fake = netG(noise)
        # detach() keeps the discriminator loss from back-propagating into netG.
        output = netD(fake.detach())
        errD_fake = criterion(output, fake_label)
        D_G_z1 = output.mean().item()

        # Loss backward
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        ########################################
        # (2) Update G network: maximize log(D(G(z))) #
        ########################################
        netG.zero_grad()
        # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, real_label)
        D_G_z2 = output.mean().item()

        errG.backward()
        optimizerG.step()

    # Values printed are those of the last mini-batch of the epoch.
    print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, n_epoch, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # save the output
    # No gradients are needed for the progress sample, so autograd is switched off.
    with torch.no_grad():
        fake = netG(fixed_noise)
    training_progress_images_list = save_gif(training_progress_images_list, fake)  # Save fake image while training!

    # Check pointing for every epoch
    torch.save(netG.state_dict(), './checkpoint/netG_epoch_%d.pth' % (epoch))
    torch.save(netD.state_dict(), './checkpoint/netD_epoch_%d.pth' % (epoch))

  cpuset_checked))


[0/200] Loss_D: 0.0765 Loss_G: 5.8945 D(x): 0.9886 D(G(z)): 0.0261 / 0.0198
[1/200] Loss_D: 0.4365 Loss_G: 3.4011 D(x): 0.9034 D(G(z)): 0.1348 / 0.0435
[2/200] Loss_D: 0.1532 Loss_G: 5.1542 D(x): 0.9555 D(G(z)): 0.0166 / 0.0077
[3/200] Loss_D: 0.0661 Loss_G: 5.1520 D(x): 0.9973 D(G(z)): 0.0521 / 0.0134
[4/200] Loss_D: 0.1727 Loss_G: 4.2334 D(x): 0.9746 D(G(z)): 0.0932 / 0.0285
[5/200] Loss_D: 0.5223 Loss_G: 4.1695 D(x): 0.8259 D(G(z)): 0.0524 / 0.0463
[6/200] Loss_D: 0.2502 Loss_G: 3.5737 D(x): 0.9181 D(G(z)): 0.0654 / 0.0480
[7/200] Loss_D: 0.4260 Loss_G: 3.1572 D(x): 0.8758 D(G(z)): 0.1034 / 0.0839
[8/200] Loss_D: 0.4663 Loss_G: 2.5721 D(x): 0.9233 D(G(z)): 0.2026 / 0.1111
[9/200] Loss_D: 0.2343 Loss_G: 3.1489 D(x): 0.9306 D(G(z)): 0.1102 / 0.0762
[10/200] Loss_D: 0.5125 Loss_G: 2.7035 D(x): 0.8274 D(G(z)): 0.1184 / 0.0953
[11/200] Loss_D: 0.5828 Loss_G: 2.7066 D(x): 0.8298 D(G(z)): 0.1714 / 0.0945
[12/200] Loss_D: 0.7571 Loss_G: 2.9407 D(x): 0.7796 D(G(z)): 0.1028 / 0.0992
[13/200] 