<a href="https://colab.research.google.com/github/onebottlekick/bhban_ai_pytorch/blob/main/CNN/gan/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import datetime
import glob
import os

import imageio
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid, save_image

In [2]:
def plot(generator, device, n_imgs=100, show=True, save_img=False, save_path='results',
         nrow=10, latent_dim=100):
    """Sample images from `generator` and show and/or save them as a single grid.

    Args:
        generator: model mapping a (n_imgs, latent_dim) noise batch to images.
        device: device on which the noise batch is created.
        n_imgs: number of images to sample.
        show: if True, display the grid with matplotlib.
        save_img: if True, write the grid as a timestamped PNG into `save_path`.
        save_path: output directory for saved grids (created if missing).
        nrow: number of images per grid row.
        latent_dim: length of each noise vector; must match the generator's input size.
    """
    # Visualisation only: no need to build an autograd graph through the generator.
    # The generator is run in whatever mode it is currently in (train mode inside `train`).
    with torch.no_grad():
        z = torch.randn(n_imgs, latent_dim, device=device)
        gen_imgs = generator(z)

    # Pad with white: for float images white is 1.0 (the old 255 only worked thanks to the clamp).
    grid = make_grid(gen_imgs, nrow=nrow, pad_value=1.0)
    # Values outside [0, 1] (e.g. negative Tanh outputs) are clipped for display and saving.
    grid.clamp_(0, 1)

    if show:
        plt.axis('off')
        plt.imshow(grid.cpu().permute(1, 2, 0), cmap='gray')
        plt.show()

    if save_img:
        os.makedirs(save_path, exist_ok=True)
        # Microsecond resolution: with whole seconds, two grids saved within the same second
        # could overwrite each other. The zero-padded fields keep lexical order == chronological
        # order, which the GIF assembly relies on.
        stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')
        save_image(grid, os.path.join(save_path, f'{stamp}.png'))

In [3]:
def img2gif(root='results', to='result.gif', remove_imgs=False, duration=0.5):
    """Combine the PNG frames found in `root` into one animation written to `root/to`.

    Frames are the `.png` files directly inside `root` (sub-directories and the output
    file itself are ignored), taken in filename order.

    Args:
        root: directory holding the frames; the animation is written here as well.
        to: file name of the animation inside `root`.
        remove_imgs: if True, delete the PNG frames that were used once the animation
            has been written. Other files in `root` are left untouched.
        duration: per-frame duration forwarded to `imageio.mimsave`.
            NOTE(review): seconds vs. milliseconds depends on the installed imageio
            version/plugin — confirm against the environment.

    Raises:
        ValueError: if `root` contains no `.png` frames.
    """
    frame_paths = [
        os.path.join(root, name)
        for name in sorted(os.listdir(root))
        if name.endswith('.png')
        and name != to
        and os.path.isfile(os.path.join(root, name))
    ]
    if not frame_paths:
        raise ValueError(f'no .png frames found in {root!r}')

    frames = []
    for path in frame_paths:
        # Close each file promptly so the frames can be deleted below (e.g. on Windows).
        with Image.open(path) as frame:
            frames.append(np.array(frame))
    imageio.mimsave(os.path.join(root, to), frames, duration=duration)

    if remove_imgs:
        # Delete only the frames that went into the animation, never unrelated files.
        for path in frame_paths:
            os.remove(path)

In [4]:
def train(generator, discriminator, dataloader, g_optimizer, d_optimizer, criterion, device,
          latent_dim=100, plot_every=30):
    """Run one epoch of GAN training and return the mean generator/discriminator losses.

    For every batch the generator is updated first (so that the discriminator labels its
    samples as real), then the discriminator is updated on the real batch and on the same
    generated batch, detached from the generator's graph.

    Args:
        generator: model mapping (batch, latent_dim) noise to images.
        discriminator: model mapping images to (batch, 1) scores compatible with `criterion`.
        dataloader: iterable of (images, labels) batches; labels are ignored.
        g_optimizer: optimizer over the generator's parameters.
        d_optimizer: optimizer over the discriminator's parameters.
        criterion: loss called as criterion(scores, targets) with 1 = real, 0 = fake.
        device: device the batches, targets and noise are placed on.
        latent_dim: length of the noise vectors fed to `generator`.
        plot_every: every `plot_every` batches a sample grid is saved via `plot`;
            0 or None disables plotting.

    Returns:
        (mean generator loss, mean discriminator loss) over the epoch's batches.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    g_loss_total = 0.0
    d_loss_total = 0.0
    n_batches = 0
    for idx, (img, _) in enumerate(dataloader):
        real = img.to(device)
        batch_size = real.size(0)
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Generator step: push the discriminator to call generated images real.
        g_optimizer.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        gen = generator(z)
        g_loss = criterion(discriminator(gen), valid)
        g_loss.backward()
        g_optimizer.step()

        if plot_every and idx % plot_every == 0:
            plot(generator, device, save_img=True)

        # Discriminator step. zero_grad also discards the gradients that
        # g_loss.backward() accumulated on the discriminator's parameters.
        d_optimizer.zero_grad()
        real_loss = criterion(discriminator(real), valid)
        fake_loss = criterion(discriminator(gen.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        g_loss_total += g_loss.item()
        d_loss_total += d_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches')
    return g_loss_total / n_batches, d_loss_total / n_batches

In [5]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
NUM_EPOCHS = 5
BETAS = (0.5, 0.999)  # Adam (beta1, beta2)
LEARNING_RATE = 0.0001
BATCH_SIZE = 64

# Seed torch's RNGs so weight initialisation, data shuffling and noise sampling repeat
# across runs (some GPU kernels may still be nondeterministic).
torch.manual_seed(SEED)

# The generator ends in Tanh (outputs in [-1, 1]), so real images are mapped from
# ToTensor's [0, 1] range into the same [-1, 1] range: x -> (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [6]:
# MNIST training split (downloaded into ./data on first use), served in shuffled batches.
mnist_train = MNIST(root='data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [7]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores images as real (close to 1) or fake (close to 0).

    Four stride-2 conv blocks downsample the input, then a linear layer and a sigmoid
    produce one probability per image, shape (batch, 1).

    Args:
        in_channels: number of channels of the input images.
        img_size: side length of the square input images.
    """

    # Number of stride-2 downsampling blocks in self.model.
    _N_BLOCKS = 4

    def __init__(self, in_channels=1, img_size=32):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            *self._block(in_channels, 16, normalize=False),
            *self._block(16, 32),
            *self._block(32, 64),
            *self._block(64, 128)
        )

        # A 3x3 conv with stride 2 and padding 1 maps a side of length s to ceil(s / 2).
        # Computing that exactly (instead of img_size // 16) keeps the linear layer's
        # input size right for sizes that are not multiples of 16, e.g. 28 -> 14 -> 7 -> 4 -> 2.
        feat_size = img_size
        for _ in range(self._N_BLOCKS):
            feat_size = (feat_size + 1) // 2
        self.fc = nn.Linear(128 * feat_size ** 2, 1)

        self.sigmoid = nn.Sigmoid()

    def _block(self, in_channels, out_channels, normalize=True):
        """Return the layers of one downsampling block.

        The block is a 3x3 stride-2 conv, then LeakyReLU(0.2), then Dropout2d(0.2),
        then BatchNorm2d when `normalize` is True.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.2)
        ]

        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))

        return layers

    def forward(self, img):
        """Return sigmoid scores of shape (batch, 1) for a batch of images."""
        x = self.model(img)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x
    
    
class Generator(nn.Module):
    """Maps noise vectors to images with values in [-1, 1] (Tanh output).

    A linear layer projects each noise vector to a 128-channel feature map of side
    img_size // 4. Two upsample+conv blocks double the side twice, and a final conv
    produces `img_channels` channels.

    Args:
        in_features: length of each noise vector.
        img_channels: number of channels of the generated images.
        img_size: side length of the square output images; must be a positive multiple of 4.

    Raises:
        ValueError: if `img_size` is not a positive multiple of 4.
    """

    def __init__(self, in_features=100, img_channels=1, img_size=32):
        super(Generator, self).__init__()
        # Two x2 upsamplings restore (img_size // 4) to img_size only when img_size is
        # divisible by 4; otherwise the output would silently come out smaller.
        if img_size < 4 or img_size % 4 != 0:
            raise ValueError(f'img_size must be a positive multiple of 4, got {img_size}')
        self.img_size = img_size

        self.fc = nn.Linear(in_features, 128 * ((img_size // 4) ** 2))

        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            *self._block(128, 128),
            *self._block(128, 64),
            nn.Conv2d(64, img_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels):
        """Return the layers of one upsampling block.

        The block is a x2 Upsample, then a 3x3 same-padding conv, then BatchNorm2d,
        then LeakyReLU(0.2).
        """
        layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        return layers

    def forward(self, z):
        """Map a (batch, in_features) noise batch to (batch, img_channels, img_size, img_size) images."""
        side = self.img_size // 4
        x = self.fc(z).view(z.shape[0], 128, side, side)
        return self.model(x)

In [8]:
# Build both networks on DEVICE. The generator is still created first, so parameter
# initialisation draws from the RNG in the same order.
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# The discriminator ends in a sigmoid, so binary cross-entropy is applied to probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the same hyper-parameters.
g_optimizer, d_optimizer = (
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=BETAS)
    for net in (generator, discriminator)
)

In [None]:
 for epoch in range(NUM_EPOCHS):
    g_loss, d_loss = train(generator, discriminator, dataloader, g_optimizer, d_optimizer, criterion, DEVICE)
    
    print()
    print('='*30)
    print(f'Epoch [{epoch+1:03}/{NUM_EPOCHS}] G_Loss: {g_loss:.4f}, D_Loss: {d_loss:.4f}')
    print('='*30)
    print()

In [10]:
# Stitch the sample grids saved during training into results/result.gif and remove the frame images.
img2gif(root='results', to='result.gif', remove_imgs=True, duration=0.1)