In [1]:
from os import path, listdir, makedirs

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-g3w2iaek because the default path (/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [6]:
import tensorflow as tf
from tensorflow.keras.layers import Activation, Conv2D, Conv2DTranspose, Dense, MaxPool2D, LeakyReLU, \
BatchNormalization, Dropout, Reshape, Flatten, RepeatVector, Add, ReLU, GlobalAveragePooling2D, AveragePooling2D, \
UpSampling2D, Layer
from tensorflow.keras import Model, Input, Sequential
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.constraints import  max_norm
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import backend

In [9]:
# --- Paths ---
FOLDER = 'train_resized'        # directory listed to build the training file list
PROGRESS_FOLDER = 'progress'    # per-epoch preview grids are saved here

# --- Data / model geometry ---
TARGET_SHAPE = (28, 28, 3)      # discriminator input shape; generator output must match
RANDOM_FEATURES_SIZE = 98       # length of the latent noise vector fed to the generator
FILTERS = 128                   # NOTE(review): not referenced in the visible cells
SIZE = 4                        # NOTE(review): not referenced in the visible cells

# --- Training schedule ---
EPOCHS = 130
STEPS_PER_EPOCH = 500           # discriminator + generator updates per epoch
BATCH_SIZE = 256                # real images per step (an equal number of fakes is generated)
SHOW_EVERY = 1                  # save a preview every N epochs; a falsy value disables previews

In [10]:
# Paths of every entry in FOLDER, in the order reported by listdir.
train = [path.join(FOLDER, entry) for entry in listdir(FOLDER)]

In [8]:
class PixelNorm(Layer):
    """Normalise each feature vector along the last axis to unit RMS.

    Every vector taken along the last axis is divided by
    ``sqrt(mean(v ** 2) + 1e-8)``; the small constant keeps the division
    finite for all-zero vectors. The layer has no weights.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, data):
        """Return `data` scaled by the inverse RMS computed over the last axis."""
        mean_of_squares = backend.mean(data**2, axis=-1, keepdims=True)
        return data / backend.sqrt(mean_of_squares + 1.0e-8)

    def compute_output_shape(self, input_shape):
        """The normalisation is element-wise, so the shape is unchanged."""
        return input_shape

In [9]:
class WeightedSum(Add):
    """Merge layer computing ``(1 - alpha) * a + alpha * b`` for two inputs.

    With ``alpha == 0`` the output equals the first input, with
    ``alpha == 1`` it equals the second.

    Args:
        alpha: initial blend factor, stored in a backend variable named
            ``'WS_alpha'`` so it can be updated after the layer is built.
        **kwargs: forwarded to the ``Add`` base class.
    """

    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = backend.variable(alpha, name='WS_alpha')

    def _merge_function(self, inputs):
        # Keras merge layers dispatch to `_merge_function`. Without this
        # override the base class implementation runs and the layer acts
        # as a plain `Add`, never reading `alpha`.
        return self._blend(inputs)

    def _blend(self, data):
        """Return the alpha-weighted combination of exactly two tensors."""
        assert (len(data) == 2), 'WeightedSum expects exactly two inputs'
        blended = (1.0 - self.alpha) * data[0] + self.alpha * data[1]
        return blended

In [11]:
# Adam with learning rate 2e-4 and beta_1 = 0.5; referenced as `opt` by the cells below.
opt = Adam(learning_rate=0.0002, beta_1=0.5)

In [12]:
### GENERATOR
# Maps a latent vector of length RANDOM_FEATURES_SIZE to an image of TARGET_SHAPE
# with values in [-1, 1] (tanh output).
#
# Output-size formulas used when planning the layer stack:
#   Conv2D:          (SHAPE - KERNEL + 2*PADDING) / STRIDE + 1
#   Conv2DTranspose: (SHAPE - 1) * STRIDE + KERNEL - 2*PADDING

generator_input = Input(shape=(RANDOM_FEATURES_SIZE,))

dropout = 0.4
depth = 64+64+64+64
dim = 7
# In:  RANDOM_FEATURES_SIZE
# Out: dim x dim x depth
X = Dense(dim*dim*depth)(generator_input)
X = BatchNormalization(momentum=0.9)(X)
X = Activation('relu')(X)
X = Reshape((dim, dim, depth))(X)
X = Dropout(dropout)(X)

# dim -> 2*dim, channels halved
X = UpSampling2D()(X)
X = Conv2DTranspose(depth // 2, 5, padding='same')(X)
X = BatchNormalization(momentum=0.9)(X)
X = Activation('relu')(X)

# 2*dim -> 4*dim, channels halved again
X = UpSampling2D()(X)
X = Conv2DTranspose(depth // 4, 5, padding='same')(X)
X = BatchNormalization(momentum=0.9)(X)
X = Activation('relu')(X)

# Same resolution, fewer channels
X = Conv2DTranspose(depth // 8, 5, padding='same')(X)
X = BatchNormalization(momentum=0.9)(X)
X = Activation('relu')(X)

# Project to 3 channels and squash to [-1, 1]
X = Conv2DTranspose(3, 5, padding='same')(X)
X = Activation('tanh')(X)

print(X.shape)
# Check height, width AND channels: the discriminator input is built from
# TARGET_SHAPE, so any mismatch would only surface later when the models are stacked.
assert tuple(X.shape[1:]) == TARGET_SHAPE, f'generator emits {X.shape[1:]}, expected {TARGET_SHAPE}'

# The generator is not compiled on its own: in the cells below it is only
# trained through the stacked `gan` model and otherwise used for inference.
generator = Model(generator_input, X)

(None, 28, 28, 3)


In [13]:
### DISCRIMINATOR
# Convolutional binary classifier: image of TARGET_SHAPE -> single sigmoid score.
# The training loop labels real images 1 and generated images 0.

discriminator_input = Input(TARGET_SHAPE)

depth = 64
dropout = 0.4

# Four conv stages of (width multiplier, stride); each is Conv2D -> LeakyReLU -> Dropout.
X = discriminator_input
for _width, _stride in ((1, 2), (2, 2), (4, 2), (8, 1)):
    X = Conv2D(depth*_width, 5, strides=_stride, padding='same')(X)
    X = LeakyReLU(alpha=0.2)(X)
    X = Dropout(dropout)(X)

# Classification head
X = Flatten()(X)
X = Dense(1)(X)
X = Activation('sigmoid')(X)

print(X.shape)

discriminator = Model(discriminator_input, X)
# The string literal below is a disabled alternative (EfficientNetB0 backbone); it has no effect.
"""
### LETS TRY EFFICIENT NET B0 WITH RANDOM WEIGHTS

discriminator = Sequential()

discriminator.add(EfficientNetB0(include_top=False, input_tensor=discriminator_input, pooling='avg'))
discriminator.add(Dense(1, activation='sigmoid'))
"""
#discriminator.summary()
# With the accuracy metric attached, train_on_batch returns [loss, accuracy].
discriminator.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])

(None, 1)


In [14]:
### STACKED GAN (generator -> frozen discriminator)
# Freeze the discriminator for the stacked model so that gan.train_on_batch
# only updates generator weights. The standalone `discriminator` was compiled
# before this flag was flipped.
# NOTE(review): this relies on Keras capturing `trainable` at compile time, so
# the standalone model keeps training. Confirm for the Keras version in use.
discriminator.trainable = False

gan_input = Input(shape=(RANDOM_FEATURES_SIZE,))
generated_img = generator(gan_input)
gan_out = discriminator(generated_img)

gan = Model(gan_input, gan_out)
#gan.summary()

# Give the stacked model its own optimizer instead of reusing `opt`, which is
# already attached to the discriminator. A shared instance mixes the step
# counter of both models, and newer Keras optimizers refuse variables they
# were not built with. Cloning the config keeps the hyperparameters in sync.
gan_opt = type(opt).from_config(opt.get_config())
gan.compile(gan_opt, 'binary_crossentropy')

In [17]:
def load_images(count):
    """Load `count` distinct, randomly chosen arrays from the `train` paths.

    Paths are drawn without replacement, so `count` must not exceed
    ``len(train)``. Each path is read with ``np.load`` and the results are
    combined into a single array.
    # assumes every file holds an array of the same shape - TODO confirm
    """
    picked_paths = np.random.choice(train, size=count, replace=False)
    return np.array([np.load(file_path) for file_path in picked_paths])

In [26]:
### TRAINING LOOP
# Each step: (1) train the discriminator on BATCH_SIZE real + BATCH_SIZE generated
# images, (2) train the generator through the stacked `gan` model.

# Fixed latent vectors so the preview grids of different epochs are comparable.
demo_noise = np.random.normal(0, 1, (9, RANDOM_FEATURES_SIZE))

# Discriminator targets follow the order of `x` below: real first (1), generated second (0).
d_y = np.zeros((2*BATCH_SIZE, 1))
d_y[:BATCH_SIZE, :] = 1
# Generator targets: all ones, i.e. the generator is rewarded when fakes are scored as real.
g_y = np.ones((BATCH_SIZE, 1))

# savefig does not create missing directories; make sure the folder exists up
# front instead of failing after the first epoch has finished.
makedirs(PROGRESS_FOLDER, exist_ok=True)

for epoch in range(EPOCHS):
    for step in tqdm(range(STEPS_PER_EPOCH)):
        noise = np.random.normal(0, 1, (BATCH_SIZE, RANDOM_FEATURES_SIZE))
        # Single-batch inference: predict_on_batch avoids the per-call
        # dataset setup that predict() performs.
        fake_imgs = generator.predict_on_batch(noise)
        # NOTE(review): the stored arrays are presumably already scaled to the
        # generator's tanh range [-1, 1]. Verify against the preprocessing step.
        true_imgs = load_images(BATCH_SIZE)

        x = np.concatenate((true_imgs, fake_imgs))

        d_loss = discriminator.train_on_batch(x, d_y)

        noise = np.random.normal(0, 1, (BATCH_SIZE, RANDOM_FEATURES_SIZE))
        g_loss = gan.train_on_batch(noise, g_y)
    # Reports the losses of the last step of the epoch only.
    print(f'Epoch: {epoch}; Discriminator loss: {d_loss[0]}; GAN loss: {g_loss}')
    if SHOW_EVERY and epoch % SHOW_EVERY == 0:
        preds = generator.predict_on_batch(demo_noise)
        # Map the tanh range [-1, 1] to [0, 1] exactly once. The earlier
        # min-max normalisation followed by *0.5+0.5 squeezed the images into
        # [0.5, 1] (washed-out previews) and divided by zero for a constant output.
        preds = np.clip(preds * 0.5 + 0.5, 0.0, 1.0)
        plt.figure(figsize=(16, 16))
        for i, pred in enumerate(preds):
            plt.subplot(3, 3, i+1)
            plt.imshow(pred)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(path.join(PROGRESS_FOLDER, f'EPOCH_{epoch}.png'))
        # Close the figure so figures do not accumulate across epochs.
        plt.close('all')

100%|██████████| 500/500 [04:58<00:00,  1.68it/s]


Epoch: 0; Discriminator loss: 0.7051420211791992; GAN loss: 2.9062678813934326


100%|██████████| 500/500 [04:50<00:00,  1.72it/s]


Epoch: 1; Discriminator loss: 0.08527487516403198; GAN loss: 0.19073942303657532


100%|██████████| 500/500 [04:49<00:00,  1.73it/s]


Epoch: 2; Discriminator loss: 0.24089208245277405; GAN loss: 0.5318255424499512


100%|██████████| 500/500 [04:48<00:00,  1.73it/s]


Epoch: 3; Discriminator loss: 0.47159886360168457; GAN loss: 0.5887347459793091


100%|██████████| 500/500 [04:56<00:00,  1.69it/s]


Epoch: 4; Discriminator loss: 0.44175827503204346; GAN loss: 0.8627838492393494


100%|██████████| 500/500 [05:07<00:00,  1.63it/s]


Epoch: 5; Discriminator loss: 0.42020270228385925; GAN loss: 1.6112477779388428


100%|██████████| 500/500 [05:03<00:00,  1.65it/s]


Epoch: 6; Discriminator loss: 0.48684796690940857; GAN loss: 1.1800661087036133


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]


Epoch: 7; Discriminator loss: 0.3895328938961029; GAN loss: 0.6036587357521057


 45%|████▌     | 227/500 [02:09<02:35,  1.76it/s]


KeyboardInterrupt: 