# 再現整個 Fashion MNIST 訓練過程

In [None]:
# Call python-dotenv's load_dotenv(); its return value is deliberately discarded.
from dotenv import load_dotenv
_ = load_dotenv()

# IPython magics: (re)load the autoreload extension and switch it to mode 2.
%reload_ext autoreload
%autoreload 2

In [None]:
from practical_ai.data import get_dataset, get_data_loader

## 引入已實現好的模組

In [None]:
import sys

# Put the local checkout first on the import path.
# Presumably this is what lets the later `import module` resolve — confirm.
lib_path = "/home/ec2-user/SageMaker/workspace/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch"
sys.path.insert(0, lib_path)

In [None]:
import os
import torch
from easydict import EasyDict as edict

## 參數

In [None]:
# Run configuration with attribute-style access (EasyDict).
# gradient_penalty_mode is read by the model cell to pick the discriminator's
# normalization; batch_size is passed to the data loader.
args = edict({
    "dataset": "mnist",
    "batch_size": 32,
    "gradient_penalty_mode": "none",
    "adversarial_loss_mode": "gan",
    "z_dim": 128,
    "lr": 0.0002,
    "beta_1": 0.5,
})

In [None]:
# Pick the compute device: CUDA when torch reports it available, otherwise CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")
# Echo both values as the cell output.
use_gpu, device

(True, device(type='cuda'))

In [None]:
# Generator upsampling and discriminator downsampling stage counts share one value.
n_G_upsamplings = 3
n_D_downsamplings = n_G_upsamplings

## 數據

In [None]:
# Build the dataset and its loader from the run configuration.
# The dataset name is read from args.dataset (currently "mnist") rather than
# repeated as a literal, so the configuration cell stays the single source of
# truth. pin_memory follows GPU availability.
dataset = get_dataset(args.dataset)
data_loader = get_data_loader(
    dataset,
    batch_size=args.batch_size,
    pin_memory=use_gpu,
)

INFO:root:MNIST will be resized to (28, 28).


In [None]:
# Display the dataset's input shape; the recorded output is (28, 28, 1).
# The model cell passes the last entry to ConvDiscriminator as its first argument.
dataset.input_shape

(28, 28, 1)

## 模型

In [None]:
import module
import torchprob as gan

In [None]:
# setup the normalization function for discriminator
if args.gradient_penalty_mode == 'none':
    d_norm = 'batch_norm'
else:  # cannot use batch normalization with gradient penalty
    # The configuration cell does not define gradient_penalty_d_norm, so plain
    # attribute access would raise AttributeError here; read it with a fallback.
    # NOTE(review): 'layer_norm' is assumed to match the upstream training
    # script's default — confirm before enabling a gradient penalty.
    d_norm = args.get('gradient_penalty_d_norm', 'layer_norm')

# networks
# Both networks take the channel count from the dataset (last entry of
# input_shape) so the generator's output matches what the discriminator
# consumes; the generator previously hard-coded 3 channels.
# NOTE(review): assumes ConvGenerator's second positional argument is the
# number of output channels — confirm against `module`.
# The stage counts reuse n_G_upsamplings / n_D_downsamplings defined earlier
# (both 3) instead of repeating the literal.
G = module.ConvGenerator(
    args.z_dim,
    dataset.input_shape[-1],
    n_upsamplings=n_G_upsamplings,
).to(device)
D = module.ConvDiscriminator(
    dataset.input_shape[-1],
    n_downsamplings=n_D_downsamplings,
    norm=d_norm,
).to(device)
# print(G)
print(D)

# # adversarial_loss_functions
# d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)

# # optimizer
# G_optimizer = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta_1, 0.999))
# D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta_1, 0.999))

ConvDiscriminator(
  (net): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)


# 改成 PyTorch Lightning 格式
- 加上 TensorBoard 

# 訓練李宏毅的 Anime

# Future

In [None]:
# general
import os
import numpy as np
from argparse import ArgumentParser
from collections import OrderedDict

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

## Generator

In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to image tensors.

    The layer stack reproduces the architecture shown in the recorded
    ``g.model`` output of this notebook.

    Args:
        latent_dim: Width of the latent vector consumed by the first layer.
        input_shape: Shape of a single generated sample without the batch
            dimension, e.g. ``(1, 28, 28)``. Its product is the width of the
            last linear layer.
    """

    def __init__(self, latent_dim, input_shape):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_shape = input_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm1d] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # eps=0.8 matches the recorded repr (BatchNorm1d(..., eps=0.8)).
                # NOTE(review): 0.8 may have been meant as momentum — confirm.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        n_outputs = int(np.prod(input_shape))
        # NOTE(review): the last stage also ends in BatchNorm1d + LeakyReLU, as
        # in the recorded repr; confirm this is the intended output activation.
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            *block(1024, n_outputs),
        )

    def forward(self, z):
        """Generate samples for a batch of latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *input_shape)``.
        """
        flat = self.model(z)
        # Restore the per-sample image shape from the flat linear output.
        return flat.view(flat.size(0), *self.input_shape)


In [None]:
# Instantiate a generator producing samples of shape (1, 28, 28).
# NOTE(review): latent_dim is 50 here, while the recorded `g.model` output
# further down shows in_features=100 — confirm which value is intended.
input_shape = (1, 28, 28)
latent_dim = 50

g = Generator(latent_dim, input_shape)

In [None]:
# NOTE(review): work-in-progress cell. base_channels and base_output_shape are
# assigned but not used by the model expression below.
base_channels = 1024
base_output_shape = 4

model = nn.Sequential(
    # NOTE(review): `k` is passed positionally after keyword arguments, which
    # is a syntax error, and `k` is not defined anywhere in this notebook.
    # The intended kernel size must be filled in before this cell can run.
    nn.Conv2d(in_channels=1, out_channels=512, k)
)

In [None]:
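# Planning notes (not executable); presumably the per-stage layout of the
# convolutional generator sketched above — confirm.
#   spatial size: 4 -> 8 -> 16 -> 28
#   channels:     512 -> 256 -> 128 -> 3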

In [None]:
# Smoke-test the generator on a small batch of random latent vectors.
# The noise width is read from the generator itself so it always matches the
# latent_dim it was built with; a hard-coded 100 disagrees with the
# `latent_dim = 50` used when `g` is constructed.
bs = 3
z = torch.randn(bs, g.latent_dim)
out = g(z)
out.size()

torch.Size([3, 1, 28, 28])

In [None]:
# Display the generator's layer stack (its repr is the recorded output below).
g.model

Sequential(
  (0): Linear(in_features=100, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=128, out_features=256, bias=True)
  (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Linear(in_features=256, out_features=512, bias=True)
  (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Linear(in_features=512, out_features=1024, bias=True)
  (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (10): LeakyReLU(negative_slope=0.2, inplace=True)
  (11): Linear(in_features=1024, out_features=784, bias=True)
  (12): BatchNorm1d(784, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (13): LeakyReLU(negative_slope=0.2, inplace=True)
)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per image.

    Args:
        img_shape: Shape of a single image without the batch dimension; the
            product of its entries is the width of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()
        self.img_shape = img_shape

        n_pixels = int(np.prod(self.img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image in the batch and return a ``(batch, 1)`` score tensor."""
        batch_size = img.size(0)
        flattened = img.view(batch_size, -1)
        return self.model(flattened)

In [None]:
# Score the generator's sample batch with a freshly initialised discriminator.
# `img_shape` is not defined at notebook level; the shape used to build the
# generator is `input_shape`, so the discriminator is built from the same value.
d = Discriminator(input_shape)
d(out)

tensor([[0.4853],
        [0.4854],
        [0.4855]], grad_fn=<SigmoidBackward>)

## GAN

In [None]:
class GAN(pl.LightningModule):
    """LightningModule that trains a generator/discriminator pair on MNIST.

    Args:
        gan_type: Variant to build; only ``"gan"`` is implemented.
        hparams: Namespace-like object providing ``latent_dim``, ``lr``,
            ``b1``, ``b2`` and ``batch_size``.

    Raises:
        NotImplementedError: If ``gan_type`` is not ``"gan"``.
    """

    def __init__(self, gan_type, hparams):
        super().__init__()
        self.hparams = hparams
        self.gan_type = gan_type
        self.data_root = os.path.join(os.getcwd(), "data")

        # The sample shape has to be known before either network is built.
        mnist_shape = (1, 28, 28)

        # networks
        if self.gan_type == "gan":
            self.generator = Generator(hparams.latent_dim, mnist_shape)
            self.discriminator = Discriminator(mnist_shape)
        else:
            raise NotImplementedError(f"unsupported gan_type: {gan_type!r}")

        # cache: last generator output and last real batch seen in training_step
        self.generated_imgs = None
        self.last_imgs = None

    def forward(self, z):
        """Run the generator on latent vectors ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and target labels."""
        return F.binary_cross_entropy(y_hat, y)

    @staticmethod
    def _step_output(name, loss):
        """Package a loss as the dict training_step returns (loss, progress bar, log)."""
        tqdm_dict = {name: loss}
        return OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict,
        })

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the generator (optimizer 0) or discriminator (optimizer 1) loss.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.

        Returns:
            An ``OrderedDict`` with ``loss``, ``progress_bar`` and ``log``
            entries, or ``None`` for any other ``optimizer_idx``.
        """
        imgs, _ = batch
        self.last_imgs = imgs
        batch_size = imgs.shape[0]
        # Tensors created here are allocated directly on the batch's device.
        device = imgs.device

        # train generator
        if optimizer_idx == 0:
            # sample noise and generate images
            z = torch.randn(batch_size, self.hparams.latent_dim, device=device)
            self.generated_imgs = self(z)

            # The generator wants its samples to be classified as real, so the
            # target for the generated batch is all ones.
            valid = torch.ones(batch_size, 1, device=device)
            g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
            return self._step_output('g_loss', g_loss)

        # train discriminator
        if optimizer_idx == 1:
            # how well can it label real images as real?
            valid = torch.ones(batch_size, 1, device=device)
            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label generated images as fake?
            # detach() keeps this loss from back-propagating into the generator.
            fake = torch.zeros(batch_size, 1, device=device)
            fake_loss = self.adversarial_loss(
                self.discriminator(self.generated_imgs.detach()), fake)

            d_loss = (real_loss + fake_loss) / 2
            return self._step_output('d_loss', d_loss)

    def configure_optimizers(self):
        """Return one Adam optimizer per network: generator first, then discriminator."""
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        return [opt_g, opt_d], []

    def prepare_data(self):
        """Create the data directory if needed and download MNIST into it."""
        os.makedirs(self.data_root, exist_ok=True)
        MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor())

    def train_dataloader(self):
        """Split MNIST 55000/5000 into train/val and return the training loader."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        dataset = MNIST(self.data_root, train=True, download=False, transform=transform)
        self.train_dataset, self.val_dataset = random_split(dataset, [55000, 5000])

        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)

    def val_dataloader(self):
        """Return the validation loader.

        ``self.val_dataset`` is only assigned inside ``train_dataloader``, so
        that method must have run first.
        """
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size)

    def on_epoch_end(self):
        """Log a grid of up to 32 generated samples to the experiment logger."""
        if self.last_imgs is None:
            # No training batch has been seen yet, so there is no batch size
            # or device to sample with.
            return

        n_samples = min(self.last_imgs.shape[0], 32)
        z = torch.randn(n_samples, self.hparams.latent_dim, device=self.last_imgs.device)

        # log sampled images; no gradients are needed for visualisation
        with torch.no_grad():
            sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch + 1)

        

## Training

In [None]:
# !rm -rf lightning_logs/

In [None]:
from argparse import Namespace

# Latent vector width for the Lightning GAN. The name was referenced below
# without being defined anywhere in the notebook; 100 matches the recorded
# Namespace output (latent_dim=100).
LATENT_DIM = 100

# Hyperparameters read by GAN: batch_size (data loaders), lr/b1/b2 (Adam) and
# latent_dim (noise width).
# NOTE(review): this rebinds the notebook-level name `args`, which earlier
# cells use for the EasyDict configuration.
args = {
    'batch_size': 1024,
    'lr': 0.0002,
    'b1': 0.5,
    'b2': 0.999,
    'latent_dim': LATENT_DIM,
}
hparams = Namespace(**args)
hparams

Namespace(b1=0.5, b2=0.999, batch_size=1024, latent_dim=100, lr=0.0002)

In [None]:
# Build the Lightning module. GAN.__init__ takes (gan_type, hparams) and only
# implements the "gan" variant; the previous call `GAN(hparams, generator=)`
# was a syntax error and did not match that signature.
model = GAN("gan", hparams)

trainer = pl.Trainer(gpus=1, log_gpu_memory=True)

trainer.fit(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …




1