### 使用昨天的卡通資料集來練GAN

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import random
from PIL import Image
from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Conv2D, UpSampling2D, Dense, Flatten, Input, BatchNormalization, Reshape, LeakyReLU, Conv2DTranspose, Dropout

### 讀資料

In [2]:
class DataLoader:
    """Serve images matching a glob pattern as normalized batches.

    Every image is converted to RGB, resized to ``img_size`` and scaled
    from [0, 255] to [-1, 1].
    """

    def __init__(self, folder_path, img_size):
        """Collect the paths of all images matching ``folder_path``.

        Args:
            folder_path: Glob pattern handed to ``glob`` (e.g. ``'dir/*.png'``).
            img_size: Shape of one image array, ``(height, width, channels)``.
                Images are converted to RGB, so ``channels`` has to be 3
                for them to fit the batch buffer.

        Raises:
            AssertionError: If the pattern matches no file.
        """
        self.folder_path = folder_path
        # Stored as a tuple so it can be concatenated with the batch
        # dimension even when the caller passes a list.
        self.img_size = tuple(img_size)

        self.path_list = glob(folder_path)  # every path matching the pattern
        assert len(self.path_list) > 0, 'path not existed!'

    def __imread(self, img_path):
        """Read one image as an RGB array of shape ``(height, width, 3)``."""
        height, width = self.img_size[0], self.img_size[1]
        # The context manager closes the underlying file handle promptly.
        with Image.open(img_path) as img:
            # PIL's resize() expects (width, height), the reverse of the
            # numpy (height, width) order used by img_size.
            # Image.LANCZOS is the filter formerly exposed as
            # Image.ANTIALIAS, which was removed in Pillow 10.
            resized = img.convert('RGB').resize((width, height), Image.LANCZOS)
        return np.array(resized)

    def sampling_data(self, batch_size, shuffle=True):
        """Yield batches of images scaled to [-1, 1].

        Args:
            batch_size: Maximum number of images per batch; the final
                batch is smaller when the image count is not a multiple.
            shuffle: When true, visit the images in a random order.

        Yields:
            Float arrays of shape ``(n,) + img_size`` with ``n <= batch_size``.
        """
        # Work on a copy so shuffling never reorders self.path_list.
        img_path_list = list(self.path_list)

        if shuffle:
            random.shuffle(img_path_list)

        for batch_idx in range(0, len(img_path_list), batch_size):
            path_set = img_path_list[batch_idx:batch_idx + batch_size]

            # Pre-allocate the batch instead of appending image by image.
            img_set = np.zeros((len(path_set),) + self.img_size)
            for img_idx, path in enumerate(path_set):
                img_set[img_idx] = self.__imread(path)

            # 127.5 is half of 255, so this maps [0, 255] onto [-1, 1].
            yield img_set / 127.5 - 1

### 建model

In [9]:
class GAN:
    """Fully-connected GAN trained on the images served by ``DataLoader``.

    Typical use: construct, call ``connect()`` to build and compile the
    models, then call ``train()``.
    """

    def __init__(self, noise_dim, img_size=(64, 64, 3),
                 data_path='../0709/Preview/cartoon/*.png'):
        """Store the hyper-parameters and open the dataset.

        Args:
            noise_dim: Length of the noise vector fed to the generator.
            img_size: Shape of one image, ``(height, width, channels)``.
            data_path: Glob pattern of the training images.
        """
        self.noise_dim = noise_dim
        self.img_size = tuple(img_size)
        self.dataloader = DataLoader(data_path, self.img_size)

    def build_generator(self):
        """Build the generator: noise vector -> image with values in [-1, 1]."""
        noise_input = Input(shape=(self.noise_dim,))

        x = noise_input
        for units in (128, 128, 64):
            x = Dense(units)(x)
            x = BatchNormalization(momentum=0.8)(x)
            x = LeakyReLU(.2)(x)

        # Size the output from img_size rather than a fixed 64*64*3 so the
        # generator always matches the discriminator's input shape.
        # tanh keeps the output in the same [-1, 1] range as the real data.
        x = Dense(int(np.prod(self.img_size)), activation='tanh')(x)
        img = Reshape(self.img_size)(x)

        return Model(noise_input, img)

    def build_discriminator(self):
        """Build the discriminator: image -> probability that it is real."""
        img_input = Input(shape=self.img_size)
        # Flatten the image so it can go through Dense layers.
        x = Flatten()(img_input)

        # No BatchNormalization here: the original author noted that
        # adding it made training fail.
        for _ in range(2):
            x = Dense(32)(x)
            x = LeakyReLU(.2)(x)
            x = Dropout(0.5)(x)

        # Classification layer.
        validity = Dense(1, activation='sigmoid')(x)

        return Model(img_input, validity)

    def connect(self):
        """Build both networks and compile the discriminator and stacked model.

        Creates ``self.generator``, ``self.discriminator`` and
        ``self.combined`` (generator followed by the frozen
        discriminator), and prints the parameter count of each network.
        """
        self.generator = self.build_generator()
        print(self.generator.count_params())
        self.discriminator = self.build_discriminator()
        print(self.discriminator.count_params())

        # Adam with a learning rate around 0.0001-0.0002; do not raise it.
        # Each compiled model gets its own optimizer instance so the two
        # models do not share optimizer state.
        self.optimizer = Adam(.0002, .5)
        self.combined_optimizer = Adam(.0002, .5)
        self.discriminator.compile(optimizer=self.optimizer, loss='binary_crossentropy', metrics=['acc'])

        noise = Input(shape=(self.noise_dim,))
        img = self.generator(noise)
        self.discriminator.trainable = False  # freeze D while training G
        validity = self.discriminator(img)

        self.combined = Model(noise, validity)
        self.combined.compile(optimizer=self.combined_optimizer, loss='binary_crossentropy')

    def train(self, epochs, batch_size, sample_interval=200):
        """Alternate discriminator and generator updates over the dataset.

        ``connect()`` must have been called first.

        Args:
            epochs: Number of passes over the dataset.
            batch_size: Number of images per training batch.
            sample_interval: Log the losses every this many iterations.

        Returns:
            ``self.history``: one ``[epoch, iter, d_loss, d_acc, g_loss]``
            row per logged iteration, with ``d_acc`` in percent.
        """
        self.history = []
        valid = np.ones((batch_size, 1))  # 1 = real image
        fake = np.zeros((batch_size, 1))  # 0 = generated image

        for e in range(epochs):
            for i, real_img in enumerate(self.dataloader.sampling_data(batch_size)):
                # Train D
                noise = np.random.standard_normal((batch_size, self.noise_dim))
                fake_img = self.generator.predict(noise)

                # The last batch of an epoch can be short, so trim the labels.
                d_loss_real, real_acc = self.discriminator.train_on_batch(real_img, valid[:len(real_img)])
                d_loss_fake, fake_acc = self.discriminator.train_on_batch(fake_img, fake)
                d_loss = .5 * (d_loss_real + d_loss_fake)
                d_acc = .5 * (real_acc + fake_acc)

                # Train G: label the generated images as real so the
                # generator is pushed towards fooling the frozen D.
                noise = np.random.standard_normal((batch_size, self.noise_dim))
                g_loss = self.combined.train_on_batch(noise, valid)

                if i % sample_interval == 0:
                    info = {
                        'epoch': e,
                        'iter': i,
                        'd_loss': d_loss,
                        'd_acc': d_acc*100,
                        'g_loss': g_loss
                    }
                    self.history.append(list(info.values()))
                    print('[Epoch %(epoch)d][Iteration %(iter)d][D loss: %(d_loss).6f, acc: %(d_acc).2f%%][G loss: %(g_loss).6f]' % info)
            self.__sample_image(e)
        return self.history

    def __sample_image(self, epoch):
        """Save an 8x8 grid of generated images to ``./Image/<epoch>.png``."""
        import os

        r, c = 8, 8  # rows, columns
        noise = np.random.standard_normal((r*c, self.noise_dim))
        img = self.generator.predict(noise).reshape((r, c) + self.img_size)
        img = img * .5 + .5  # map [-1, 1] back to [0, 1] for imshow
        fig = plt.figure(figsize=(20, 20))
        axs = fig.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(img[i, j])
                axs[i, j].axis('off')
        # savefig does not create missing directories.
        os.makedirs('./Image', exist_ok=True)
        # Zero-padded epoch number: '%5d' would pad with spaces and put
        # blanks in the file name.
        fig.savefig('./Image/%05d.png' % epoch)
        # Close this specific figure so figures do not pile up over epochs.
        plt.close(fig)

In [11]:
%%time
# Build a GAN with a 128-dimensional noise vector and 64x64 RGB images,
# compile its models, then train for 20 epochs with batches of 64,
# logging the losses every 10 iterations.
gan = GAN(noise_dim=128, img_size=(64, 64, 3))
gan.connect()
gan.train(epochs=20, batch_size=64, sample_interval=10)

1652608
394337
[Epoch 0][Iteration 0][D loss: 0.581307, acc: 60.94%][G loss: 0.742023]


  'Discrepancy between trainable weights and collected trainable'


[Epoch 0][Iteration 10][D loss: 1.162078, acc: 52.34%][G loss: 0.252068]
[Epoch 0][Iteration 20][D loss: 0.722191, acc: 50.00%][G loss: 0.392950]
[Epoch 0][Iteration 30][D loss: 0.632635, acc: 64.06%][G loss: 0.691841]
[Epoch 0][Iteration 40][D loss: 0.542671, acc: 65.62%][G loss: 0.804728]
[Epoch 0][Iteration 50][D loss: 0.497347, acc: 64.84%][G loss: 0.798103]
[Epoch 0][Iteration 60][D loss: 0.403466, acc: 74.22%][G loss: 0.866657]
[Epoch 0][Iteration 70][D loss: 0.457039, acc: 69.53%][G loss: 0.771101]
[Epoch 0][Iteration 80][D loss: 0.695810, acc: 56.25%][G loss: 0.723735]
[Epoch 0][Iteration 90][D loss: 0.604644, acc: 57.03%][G loss: 0.699438]
[Epoch 0][Iteration 100][D loss: 0.723906, acc: 53.12%][G loss: 0.736585]
[Epoch 0][Iteration 110][D loss: 0.696265, acc: 51.56%][G loss: 0.675631]
[Epoch 0][Iteration 120][D loss: 0.794337, acc: 50.78%][G loss: 0.670343]
[Epoch 0][Iteration 130][D loss: 0.654213, acc: 55.47%][G loss: 0.519264]
[Epoch 0][Iteration 140][D loss: 0.669023, acc: