In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend
from tensorflow.keras.callbacks import EarlyStopping

from matplotlib import pyplot
from math import sqrt
from PIL import Image
import os
from models import *

In [None]:
# Number of convolution filters per progressive-growing level, indexed by
# depth (index 0 is used by the initial 4x4 generator/discriminator blocks).
FILTERS = [512, 512, 512, 512, 256, 128, 64]

In [None]:
# Restrict TensorFlow to a single visible GPU when at least one is present.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # gpus[#]: index of the GPU to use
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    except RuntimeError as e:
        # Report the failure and continue with the default device visibility.
        print(e)

In [None]:
# Load the two pre-trained segmentation networks from HDF5 files.
# In PGAN.segment_module, G1 is applied to 256x256 RGB images and G2 is applied
# to the mask derived from G1's output.
G1 = models.load_model('g1.h5')
G2 = models.load_model('g2.h5')

In [None]:
def create_mask(pred_mask, n_clusters):
    """Convert per-class scores into a single-channel, index-scaled mask.

    The class index with the highest score is taken along the last axis, a
    trailing channel axis is appended, and the indices are divided by
    ``n_clusters - 1`` (so index ``n_clusters - 1`` maps to 1).

    Args:
        pred_mask: tensor of per-class scores with classes on the last axis.
        n_clusters: number of classes used for the scaling.

    Returns:
        Tensor with the class axis replaced by a single scaled channel.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)[..., tf.newaxis]
    return class_ids / (n_clusters - 1)

In [None]:
class PixelNormalization(Layer):
    """Scales every feature vector by the inverse RMS over the last axis."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Mean of squares over the last axis, kept broadcastable to `inputs`.
        channel_mean_sq = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        # The small epsilon keeps rsqrt finite for all-zero vectors.
        inv_rms = tf.math.rsqrt(channel_mean_sq + 1.0e-8)
        return inputs * inv_rms

    def compute_output_shape(self, input_shape):
        # Normalisation never changes the shape.
        return input_shape
    
class MinibatchStdev(Layer):
    """Appends one feature map filled with the mean batch-wise std deviation.

    The tiling in ``call`` uses exactly four multiples, so inputs are expected
    to be rank-4 tensors with the batch on axis 0.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Standard deviation of every position/channel across the batch axis.
        batch_mean = tf.reduce_mean(inputs, axis=0, keepdims=True)
        batch_var = tf.reduce_mean(tf.square(inputs - batch_mean), axis=0, keepdims=True)
        batch_std = tf.sqrt(batch_var + 1e-8)

        # Collapse to a single scalar (all axes kept with size 1) ...
        scalar_std = tf.reduce_mean(batch_std, keepdims=True)

        # ... and broadcast it into one extra channel per sample.
        dims = tf.shape(inputs)
        extra_map = tf.tile(scalar_std, (dims[0], dims[1], dims[2], 1))
        return tf.concat([inputs, extra_map], axis=-1)

    def compute_output_shape(self, input_shape):
        # Same shape as the input with one additional channel.
        *leading, channels = input_shape
        return tuple(leading) + (channels + 1,)

class WeightedSum(Add):
    """Merge layer returning ``(1 - alpha) * inputs[0] + alpha * inputs[1]``.

    ``alpha`` is stored as a backend variable so that it can be changed at
    run time (e.g. with ``backend.set_value``) without rebuilding the model.
    """

    def __init__(self, alpha=0.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = backend.variable(alpha, name='ws_alpha')

    def _merge_function(self, inputs):
        # Exactly two inputs are blended.
        assert (len(inputs) == 2)
        first, second = inputs[0], inputs[1]
        blended = (1.0 - self.alpha) * first + self.alpha * second
        return blended

class WeightScaling(Layer):
    """Multiplies its input by the constant ``gain / sqrt(prod(shape))``.

    Args:
        shape: scalar or sequence whose product is used as the fan-in.
        gain: numerator of the scaling constant (defaults to sqrt(2)).
    """

    def __init__(self, shape, gain = np.sqrt(2), **kwargs):
        super().__init__(**kwargs)
        dims = tf.constant(np.asarray(shape), dtype=tf.float32)
        fan_in = tf.math.reduce_prod(dims)
        # Pre-computed once; `call` only applies it.
        self.wscale = gain*tf.math.rsqrt(fan_in)

    def call(self, inputs, **kwargs):
        as_float = tf.cast(inputs, tf.float32)
        return as_float * self.wscale

    def compute_output_shape(self, input_shape):
        # Element-wise scaling keeps the shape.
        return input_shape

class Bias(Layer):
    """Adds a trainable, zero-initialised bias vector along the last axis."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One bias value per entry of the last input dimension.
        n_channels = input_shape[-1]
        zeros = tf.zeros_initializer()(shape=(n_channels,), dtype='float32')
        self.bias = tf.Variable(initial_value = zeros, trainable=True)

    def call(self, inputs, **kwargs):
        return inputs + self.bias

    def compute_output_shape(self, input_shape):
        # Adding a bias keeps the shape.
        return input_shape

def WeightScalingDense(x, filters, gain, use_pixelnorm=False, activate=None):
    """Dense block: Dense (no bias) -> WeightScaling -> Bias -> activation -> pixel norm.

    Args:
        x: input tensor; its last dimension is used as the fan-in.
        filters: number of Dense units.
        gain: gain passed to ``WeightScaling``.
        use_pixelnorm: append ``PixelNormalization`` when True.
        activate: ``'LeakyReLU'``, ``'tanh'`` or anything else for no activation.

    Returns:
        The output tensor of the block.
    """
    fan_in = backend.int_shape(x)[-1]
    dense = layers.Dense(filters, use_bias=False,
                         kernel_initializer=RandomNormal(mean=0., stddev=1.),
                         dtype='float32')
    x = dense(x)
    # The N(0, 1) kernel is rescaled at run time instead of at initialisation.
    x = WeightScaling(shape=(fan_in), gain=gain)(x)
    x = Bias(input_shape=x.shape)(x)

    if activate == 'LeakyReLU':
        x = layers.LeakyReLU(0.2)(x)
    elif activate == 'tanh':
        x = layers.Activation('tanh')(x)

    return PixelNormalization()(x) if use_pixelnorm else x

def WeightScalingConv(x, filters, kernel_size, gain, use_pixelnorm=False, activate=None, strides=(1,1)):
    """Conv block: Conv2D (no bias) -> WeightScaling -> Bias -> activation -> pixel norm.

    Args:
        x: input tensor; its last dimension is the number of input channels.
        filters: number of output channels.
        kernel_size: ``(height, width)`` of the convolution kernel.
        gain: gain passed to ``WeightScaling``.
        use_pixelnorm: append ``PixelNormalization`` when True.
        activate: ``'LeakyReLU'``, ``'tanh'`` or anything else for no activation.
        strides: convolution strides.

    Returns:
        The output tensor of the block.
    """
    in_channels = backend.int_shape(x)[-1]
    conv = layers.Conv2D(filters, kernel_size, strides=strides, use_bias=False,
                         padding="same",
                         kernel_initializer=RandomNormal(mean=0., stddev=1.),
                         dtype='float32')
    x = conv(x)
    # Fan-in is kernel height * kernel width * input channels.
    fan_in_shape = (kernel_size[0], kernel_size[1], in_channels)
    x = WeightScaling(shape=fan_in_shape, gain=gain)(x)
    x = Bias(input_shape=x.shape)(x)

    if activate == 'LeakyReLU':
        x = layers.LeakyReLU(0.2)(x)
    elif activate == 'tanh':
        x = layers.Activation('tanh')(x)

    return PixelNormalization()(x) if use_pixelnorm else x

In [None]:
def WeightScalingSeparableConv(x, filters, kernel_size, gain, use_pixelnorm=False, activate=None, strides=(1,1)):
    """Separable-conv block: SeparableConv2D (no bias) -> WeightScaling -> Bias -> activation -> pixel norm.

    Args:
        x: input tensor; its last dimension is the number of input channels.
        filters: number of output channels.
        kernel_size: ``(height, width)`` of the depthwise kernel.
        gain: gain passed to ``WeightScaling``.
        use_pixelnorm: append ``PixelNormalization`` when True.
        activate: ``'LeakyReLU'``, ``'tanh'`` or anything else for no activation.
        strides: convolution strides.

    Returns:
        The output tensor of the block.
    """
    in_filters = backend.int_shape(x)[-1]
    # SeparableConv2D has two kernels (depthwise and pointwise), each with its
    # own initializer argument; it does not take `kernel_initializer`.
    # Separate initializer instances are used so the two kernels are drawn
    # independently.
    x = layers.SeparableConv2D(
        filters, kernel_size, strides=strides, use_bias=False, padding="same",
        depthwise_initializer=RandomNormal(mean=0., stddev=1.),
        pointwise_initializer=RandomNormal(mean=0., stddev=1.),
        dtype='float32')(x)
    # NOTE(review): the scale below uses the fan-in of a regular convolution
    # (kh * kw * in_channels); confirm this is the intended constant for a
    # depthwise + pointwise pair.
    x = WeightScaling(shape=(kernel_size[0], kernel_size[1], in_filters), gain=gain)(x)
    x = Bias(input_shape=x.shape)(x)
    if activate=='LeakyReLU':
        x = layers.LeakyReLU(0.2)(x)
    elif activate=='tanh':
        x = layers.Activation('tanh')(x)

    if use_pixelnorm:
        x = PixelNormalization()(x)
    return x

In [None]:
def WeightScalingResConvBlock(x, filters, kernel_size, gain, use_pixelnorm=False, activate=None):
    """Residual block: two weight-scaled convolutions plus a 1x1 projection skip.

    Args:
        x: input tensor.
        filters: number of output channels of every convolution in the block.
        kernel_size: kernel size of the two main-path convolutions.
        gain: gain forwarded to every ``WeightScalingConv``.
        use_pixelnorm: forwarded to the two main-path convolutions.
        activate: activation name forwarded to the two main-path convolutions.

    Returns:
        Element-wise sum of the main path and the skip path.
    """
    x_in = x
    # Keyword arguments are required here: WeightScalingConv declares
    # `use_pixelnorm` before `activate`, so positional passing would swap them.
    x = WeightScalingConv(x, filters, kernel_size, gain, use_pixelnorm=use_pixelnorm, activate=activate)
    x = WeightScalingConv(x, filters, kernel_size, gain, use_pixelnorm=use_pixelnorm, activate=activate)

    # 1x1 projection so the skip path has `filters` channels.
    x_skip = WeightScalingConv(x_in, filters, kernel_size=(1,1), gain=gain, activate='LeakyReLU', use_pixelnorm=True)

    x = layers.Add()([x, x_skip])

    return x

In [None]:
class PGAN(Model):
    """Progressively grown GAN trained with a Wasserstein loss and gradient penalty.

    The generator and discriminator start at 4x4 and are grown one resolution
    level at a time with ``fade_in_*`` followed by ``stabilize_*``.  For some
    levels the new generator block is additionally fed a feature map produced
    by the segmentation pipeline ``G1 -> create_mask -> G2``.

    Attributes:
        n_depth: current level; index into ``FILTERS`` for newly added blocks.
        stage: -1 while the generator takes only the latent vector; 0..2 once
            ``fade_in_generator_with_seg`` has been used, in which case the
            generator takes ``[latent, 256x256 images]``.
    """

    def __init__(
        self,
        latent_dim,
        G1,
        G2,
        d_steps=1,
        gp_weight=10.0,
        drift_weight=0.001        
    ):
        """
        Args:
            latent_dim: size of the generator's latent vector.
            G1: model applied to 256x256 RGB images in ``segment_module``.
            G2: model applied to the mask derived from ``G1``'s output.
            d_steps: discriminator updates per generator update.
            gp_weight: weight of the gradient penalty term.
            drift_weight: weight of the drift term on the real logits.
        """
        super(PGAN, self).__init__()
        self.latent_dim = latent_dim
        self.d_steps = d_steps
        self.gp_weight = gp_weight
        self.drift_weight = drift_weight
        self.n_depth = 0
        self.stage = -1
        self.discriminator = self.init_discriminator()
        self.discriminator_wt_fade = None
        self.generator = self.init_generator()
        self.generator_wt_fade = None
        # n_clusters and G1/G2 must be set before segment_module() reads them.
        self.n_clusters = 4
        self.G1 = G1
        self.G2 = G2
        self.segmodule = self.segment_module()

    def call(self, inputs):
        # The model is only driven through train_step; there is no forward pass.
        return

    def init_discriminator(self):
        """Build the base 4x4 discriminator and return it."""
        img_input = layers.Input(shape = (4,4,3))
        img_input_cast = tf.cast(img_input, tf.float32)
        
        # fromRGB
        x = WeightScalingConv(img_input_cast, filters=FILTERS[0], kernel_size=(1,1), gain=np.sqrt(2), activate='LeakyReLU')
        
        # Minibatch standard deviation is added near the end of the discriminator
        x = MinibatchStdev()(x)

        x = WeightScalingConv(x, filters=FILTERS[0], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU')
        x = WeightScalingConv(x, filters=FILTERS[0], kernel_size=(4,4), gain=np.sqrt(2), activate='LeakyReLU', strides=(4,4))

        x = layers.Flatten()(x)
        # Gain is 1 because this is the last layer
        x = WeightScalingDense(x, filters=1, gain=1.)

        d_model = Model(img_input, x, name='discriminator')
        d_model.summary()
        
        return d_model

    def fade_in_discriminator(self):
        """Grow the discriminator by one resolution level.

        Sets ``self.discriminator`` to the fade-in model (old and new input
        paths blended by ``WeightedSum``) and ``self.discriminator_stabilize``
        to the model that uses only the new path.
        """
        input_shape = list(self.discriminator.input.shape)
        # 1. Double the input resolution.
        input_shape = (input_shape[1]*2, input_shape[2]*2, input_shape[3])
        img_input = layers.Input(shape = input_shape)
        img_input_cast = tf.cast(img_input, tf.float32)

        # 2. Downsample the input and reuse the existing "fromRGB" block as x1.
        #    The indices below rely on the layer order of the current
        #    discriminator (input, Conv2D, WeightScaling, Bias, LeakyReLU, ...).
        x1 = layers.AveragePooling2D()(img_input_cast)
        x1 = self.discriminator.layers[1](x1, dtype=tf.float32) # Conv2D FromRGB
        x1 = self.discriminator.layers[2](x1) # WeightScalingLayer
        x1 = self.discriminator.layers[3](x1) # Bias
        x1 = self.discriminator.layers[4](x1) # LeakyReLU
        print(x1.shape)

        # 3. Define the new block (x2): a new "fromRGB", two 3x3 convolutions
        #    and an AveragePooling2D layer.
        x2 = WeightScalingConv(img_input_cast, filters=FILTERS[self.n_depth], kernel_size=(1,1), gain=np.sqrt(2), activate='LeakyReLU')

        x2 = WeightScalingConv(x2, filters=FILTERS[self.n_depth], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU')
        x2 = WeightScalingConv(x2, filters=FILTERS[self.n_depth-1], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU')

        x2 = layers.AveragePooling2D()(x2)
        print(x2.shape)

        # 4. Blend x1 and x2 so the new block is faded in smoothly.
        x = WeightedSum()([x1, x2])

        # Stabilized discriminator: the remaining existing layers on x2 only.
        for i in range(5, len(self.discriminator.layers)):
            x2 = self.discriminator.layers[i](x2)
        self.discriminator_stabilize = Model(img_input, x2, name='discriminator')

        # 5. Fade-in discriminator: the remaining existing layers on the blend.
        for i in range(5, len(self.discriminator.layers)):
            x = self.discriminator.layers[i](x)
        self.discriminator = Model(img_input, x, name='discriminator')

        self.discriminator.summary()

    def stabilize_discriminator(self):
        """Switch to the stabilized discriminator built by ``fade_in_discriminator``."""
        self.discriminator = self.discriminator_stabilize
        self.discriminator.summary()

    def segment_module(self):
        """Build the model mapping 256x256 RGB images to G2 feature maps.

        Returns:
            A model whose outputs are all outputs of ``G2`` except the last.
        """
        img_input = layers.Input(shape = (256,256,3))
        img_input_cast = tf.cast(img_input, tf.float32)
        
        seg_output = self.G1(img_input_cast)
        mask_output = create_mask(seg_output, self.n_clusters)
        mask_output = tf.reshape(mask_output, (-1, 256, 256, 1))
        # NOTE(review): `x` is never used; G2 below receives `mask_output`,
        # not this float32 cast. Confirm which one is intended.
        x = tf.cast(mask_output, dtype=tf.float32)
        
        # Keep every G2 output except the last one.
        *feature, _ = self.G2(mask_output)
        
        seg_model = Model(img_input, outputs = [*feature], name='segmodule')
        seg_model.summary()
        
        return seg_model

    def init_generator(self):
        """Build the base 4x4 generator and return it."""
        noise = layers.Input(shape=(self.latent_dim,))
        x = PixelNormalization()(noise)
        x = WeightScalingDense(x, filters=4*4*FILTERS[0], gain=np.sqrt(2)/4, activate='LeakyReLU', use_pixelnorm=True)
        x = layers.Reshape((4, 4, FILTERS[0]))(x)
        x = WeightScalingConv(x, filters=FILTERS[0], kernel_size=(4,4), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)
        x = WeightScalingConv(x, filters=FILTERS[0], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)
        # toRGB: Conv2D, WeightScaling, Bias, tanh (the last four layers).
        x = WeightScalingConv(x, filters=3, kernel_size=(1,1), gain=1., activate='tanh', use_pixelnorm=False)

        g_model = Model(noise, x, name='generator')
        g_model.summary()
        return g_model

    def fade_in_generator(self):
        """Grow the generator by one resolution level (no segmentation features).

        Sets ``self.generator`` to the fade-in model and
        ``self.generator_stabilize`` to the model using only the new block.
        """
        # 1. Get the node above the "toRGB" block (toRGB is the last 4 layers).
        block_end = self.generator.layers[-5].output
        # 2. Double its resolution.
        block_end = layers.UpSampling2D((2,2))(block_end)
        # 3. Reuse the existing "toRGB" block as x1.
        x1 = self.generator.layers[-4](block_end) # Conv2d
        x1 = self.generator.layers[-3](x1) # WeightScalingLayer
        x1 = self.generator.layers[-2](x1) # Bias
        x1 = self.generator.layers[-1](x1) #tanh
        # 4. Define the new block (x2): two 3x3 convolutions and a new "toRGB".
        x2 = WeightScalingConv(block_end, filters=FILTERS[self.n_depth], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)
        x2 = WeightScalingConv(x2, filters=FILTERS[self.n_depth], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)             
        x2 = WeightScalingConv(x2, filters=3, kernel_size=(1,1), gain=1., activate='tanh', use_pixelnorm=False)    
        # Stabilized generator: new block only.
        self.generator_stabilize = Model(self.generator.input, x2, name='generator')
        # 5. Blend x1 and x2 so the new block is faded in smoothly.
        x = WeightedSum()([x1, x2])
        self.generator = Model(self.generator.input, x, name='generator')
        self.generator.summary()

    def fade_in_generator_with_seg(self, n_depth):
        """Grow the generator by one level, concatenating a segmentation feature.

        The feature is ``self.segmodule.outputs[n_depth - 2]``; it is
        concatenated with the upsampled generator features before the new
        convolutions.

        Args:
            n_depth: current level; ``n_depth - 2`` must be 0, 1 or 2.

        Raises:
            ValueError: if ``n_depth - 2`` is outside 0..2.
        """
        stage = n_depth - 2
        # Validate before touching any state so a bad call leaves the model intact.
        if stage not in (0, 1, 2):
            raise ValueError(
                f'fade_in_generator_with_seg supports n_depth 2..4 (stage 0..2), '
                f'got n_depth={n_depth} (stage={stage})')
        self.stage = stage

        feature = self.segmodule.outputs[stage]
        feature_cast = tf.cast(feature, tf.float32)
        
        # Node above the "toRGB" block, upsampled to the new resolution.
        block_end = self.generator.layers[-5].output
        block_end = layers.UpSampling2D((2,2))(block_end)
        
        # Existing "toRGB" block reused on the upsampled features.
        x1 = self.generator.layers[-4](block_end) # Conv2d
        x1 = self.generator.layers[-3](x1) # WeightScalingLayer
        x1 = self.generator.layers[-2](x1) # Bias
        x1 = self.generator.layers[-1](x1) #tanh
        
        # New block: segmentation feature + two 3x3 convolutions + new "toRGB".
        x2 = layers.Concatenate()([block_end, feature_cast])
        x2 = WeightScalingConv(x2, filters=FILTERS[self.n_depth], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)
        x2 = WeightScalingConv(x2, filters=FILTERS[self.n_depth], kernel_size=(3,3), gain=np.sqrt(2), activate='LeakyReLU', use_pixelnorm=True)
        
        x2 = WeightScalingConv(x2, filters=3, kernel_size=(1,1), gain=1., activate='tanh', use_pixelnorm=False)
        
        if self.stage == 0:
            # First conditioned level: the image input of the segmentation
            # module becomes a second generator input.
            model_inputs = [self.generator.input, self.segmodule.input]
        else:
            # Later levels: the current generator's input is reused as-is.
            model_inputs = self.generator.input

        self.generator_stabilize = Model(inputs=model_inputs, outputs=x2, name='generator')

        x = WeightedSum()([x1, x2])
        self.generator = Model(inputs=model_inputs, outputs=x, name='generator')
            
        self.generator.summary()

    def stabilize_generator(self):
        """Switch to the stabilized generator built by a ``fade_in_generator*`` call."""
        self.generator = self.generator_stabilize
        self.generator.summary()

    def compile(self, d_optimizer, g_optimizer):
        """Compile the model and store the discriminator/generator optimizers."""
        super(PGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """ Calculates the gradient penalty.
        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Random per-sample interpolation between real and fake images.
        alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def _generate_images(self, latent_vectors, images_256):
        """Run the generator with the inputs its current stage expects."""
        if self.stage >= 0:
            # Segmentation-conditioned generator: latent vector + 256x256 images.
            return self.generator([latent_vectors, images_256], training=True)
        return self.generator(latent_vectors, training=True)

    def train_step(self, images):
        """Run ``d_steps`` discriminator updates followed by one generator update.

        Args:
            images: pair ``(real_images, images_256)``; ``images_256`` is only
                used once the generator is segmentation-conditioned.

        Returns:
            Dict with the last discriminator loss (``d_loss``) and the
            generator loss (``g_loss``).
        """
        real_images, images_256 = images
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]
        
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self._generate_images(random_latent_vectors, images_256)
                
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)

                # Calculate the drift for regularization
                drift = tf.reduce_mean(tf.square(real_logits))

                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + self.gp_weight * gp + self.drift_weight * drift

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self._generate_images(random_latent_vectors, images_256)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = -tf.reduce_mean(gen_img_logits)
        # Get the gradients w.r.t the generator loss
        g_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(zip(g_gradient, self.generator.trainable_variables))
        return {'d_loss': d_loss, 'g_loss': g_loss}

In [None]:
# Create a Keras callback that periodically saves generated images and updates alpha in WeightedSum layers
class GANMonitor(keras.callbacks.Callback):
    """Saves a grid of generated samples every epoch and drives the fade-in alpha.

    Args:
        imgs: images passed to the generator as its second input once
            ``n_depth`` is greater than 1.
        num_img: number of latent vectors sampled for the preview grid.
        latent_dim: size of each latent vector.
        prefix: tag inserted into the saved file names.
    """

    def __init__(self, imgs, num_img=16, latent_dim=512, prefix=''):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.val_imgs = imgs
        # Fixed latent vectors so that grids from different epochs are comparable.
        self.random_latent_vectors = tf.random.normal(shape=[num_img, self.latent_dim], seed=9434)
        self.steps_per_epoch = 0
        self.epochs = 0
        self.steps = self.steps_per_epoch * self.epochs
        self.n_epoch = 0
        self.n_depth = 0
        self.prefix = prefix
  
    def set_prefix(self, prefix=''):
        """Set the tag used in saved file names."""
        self.prefix = prefix
  
    def set_n_depth(self, n_depth):
        """Set the current level; decides which inputs the generator is given."""
        self.n_depth = n_depth

    def set_steps(self, steps_per_epoch, epochs):
        """Set the schedule length over which alpha goes from 0 to 1."""
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.steps = self.steps_per_epoch * self.epochs

    def on_epoch_begin(self, epoch, logs=None):
        # Remembered so on_batch_begin can compute the global step.
        self.n_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        """Generate samples, arrange them in a square grid and save the figure."""
        if self.n_depth <= 1:  # 4x4, 8x8: latent vector only
            samples = self.model.generator(self.random_latent_vectors)
        else:
            # Use the images given to the constructor, not a notebook global.
            samples = self.model.generator([self.random_latent_vectors, self.val_imgs])
        
        # Map generator output from [-1, 1] (tanh) to [0, 1].
        samples = (samples * 0.5) + 0.5
        n_grid = int(sqrt(self.num_img))

        fig, axes = pyplot.subplots(n_grid, n_grid, figsize=(4*n_grid, 4*n_grid))
        sample_grid = np.reshape(samples[:n_grid * n_grid], (n_grid, n_grid, samples.shape[1], samples.shape[2], samples.shape[3]))

        for i in range(n_grid):
            for j in range(n_grid):
                axes[i][j].set_axis_off()
                samples_grid_i_j = Image.fromarray((sample_grid[i][j] * 255).astype(np.uint8))
                samples_grid_i_j = samples_grid_i_j.resize((128,128))
                axes[i][j].imshow(np.array(samples_grid_i_j))
        title = f'images_120k/plot_{self.prefix}_{epoch:05d}.png'
        # Make sure the output directory exists so a long epoch is not lost
        # to a missing folder.
        os.makedirs(os.path.dirname(title), exist_ok=True)
        pyplot.savefig(title, bbox_inches='tight')
        print(f'\n saved {title}')
        pyplot.close(fig)

    def on_batch_begin(self, batch, logs=None):
        """Update alpha in every WeightedSum layer from the global step."""
        global_step = (self.n_epoch * self.steps_per_epoch) + batch
        # Guard the denominator (steps may be 0 or 1) and keep alpha in [0, 1]
        # even if training runs longer than the configured schedule.
        alpha = global_step / float(max(self.steps - 1, 1))
        alpha = min(max(alpha, 0.0), 1.0)
        for network in (self.model.generator, self.model.discriminator):
            for layer in network.layers:
                if isinstance(layer, WeightedSum):
                    backend.set_value(layer.alpha, alpha)

In [None]:
# Size of the generator's latent vector.
NOISE_DIM = 512
# Set the batch sizes, number of epochs and steps per epoch for training.
# BATCH_SIZE is indexed by depth level (index 0 is the initial 4x4 stage).
BATCH_SIZE = [128, 64, 64, 32, 32, 16, 8]
# Base epoch count; the training loop scales it by BATCH_SIZE[0] / BATCH_SIZE[n_depth].
EPOCHS = 50
STEPS_PER_EPOCH = 1000

In [None]:
def parse_tfrecord_tf(record):
    """Decode one serialized example into a float32 image tensor.

    Each example carries a 3-element int64 ``shape`` and the raw uint8 bytes
    in ``data``.  The bytes are reshaped to ``shape``, cast to float32 and
    rescaled with ``value / 128.5 - 1``.

    Args:
        record: a serialized example.

    Returns:
        The rescaled float32 tensor.
    """
    feature_spec = {
        'shape': tf.io.FixedLenFeature([3], tf.int64),
        'data': tf.io.FixedLenFeature([], tf.string)}
    example = tf.io.parse_single_example(record, features=feature_spec)
    pixels = tf.reshape(tf.io.decode_raw(example['data'], tf.uint8), example['shape'])
    # NOTE(review): 255 / 128.5 - 1 is about 0.98, so the maximum never reaches
    # 1; confirm whether 127.5 was intended before changing (existing
    # checkpoints were trained with this scaling).
    return tf.cast(pixels, dtype=tf.float32) / 128.5 -1

In [None]:
# List of tfrecords file paths, one per resolution, from 4x4 up to 256x256
# (indexed by depth level; index 6 is the 256x256 file).
tfr_file = ['../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r02.tfrecords', '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r03.tfrecords',
            '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r04.tfrecords', '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r05.tfrecords',
            '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r06.tfrecords', '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r07.tfrecords',
            '../Molpaxbio/ODG-tfrec-3000/ODG-tfrec-3000-r08.tfrecords']
# Read-buffer size in MiB; converted to bytes with `<< 20` where it is used.
buffer_mb       = 256

In [None]:
# Training data for the initial stage: first tfrecords file, batched with
# BATCH_SIZE[0] and repeated indefinitely.
train_dataset = tf.data.TFRecordDataset(tfr_file[0], compression_type='', buffer_size=buffer_mb<<20)
train_dataset = train_dataset.map(parse_tfrecord_tf)#, num_parallel_calls=num_threads)
train_dataset = train_dataset.batch(BATCH_SIZE[0]).repeat()

In [None]:
# 256x256 images that accompany every training batch (second element of the
# zipped dataset).
train_256 = tf.data.TFRecordDataset(tfr_file[6], compression_type='', buffer_size=buffer_mb<<20)
train_256 = train_256.map(parse_tfrecord_tf)#, num_parallel_calls=num_threads)
train_256 = train_256.take(10000) # .take(): use only the first 10000 images
train_256 = train_256.batch(BATCH_SIZE[0]).repeat()

In [None]:
# Fixed set of 64 images from the 256x256 file, used by the monitoring callback.
val_256 = tf.data.TFRecordDataset(tfr_file[6], compression_type='', buffer_size=buffer_mb<<20)
val_256 = val_256.map(parse_tfrecord_tf)#, num_parallel_calls=num_threads)
val_256 = val_256.take(64).batch(64)
# Materialised as a list with one numpy array per batch (a single batch here).
val_imgs = list(val_256.as_numpy_iterator())

In [None]:
# Pair every training batch with a 256x256 batch; PGAN.train_step unpacks the pair.
concat_train_dataset = tf.data.Dataset.zip((train_dataset,train_256))

In [None]:
# Separate Adam optimizers with identical settings for generator and discriminator.
generator_optimizer = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

In [None]:
# Monitoring callback: saves sample grids each epoch and updates the fade-in alpha.
cbk = GANMonitor(imgs=val_imgs, num_img=64, latent_dim=NOISE_DIM, prefix='0_init')
# Schedule length used for the alpha computation.
cbk.set_steps(steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)

In [None]:
# Instantiate the PGAN (PG-GAN) model.
# The constructor builds the base 4x4 generator/discriminator and the
# segmentation module from G1 and G2.
pgan = PGAN(
    latent_dim = NOISE_DIM,
    G1=G1,
    G2=G2,
    d_steps = 1,
)

In [None]:
# Set the checkpoint path (named after the callback's current prefix)
checkpoint_path = f"ckpts_120k/pgan_{cbk.prefix}.ckpt"

# Compile models
pgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
)

In [None]:
# Train the first-stage model (4x4) for a single epoch, then save its weights.
pgan.fit(concat_train_dataset, steps_per_epoch = STEPS_PER_EPOCH, epochs = 1, callbacks=[cbk])
pgan.save_weights(checkpoint_path)

In [None]:
# Train progressively, one resolution level at a time.
# Every level has two phases: "fade_in" (new block blended in through
# WeightedSum) and "stabilize" (new block only).
for n_depth in range(1, 7):
    # Current level (depth)
    pgan.n_depth = n_depth

    # Schedule: fewer images per batch -> proportionally more epochs.
    steps_per_epoch = STEPS_PER_EPOCH
    epochs = int(EPOCHS*(BATCH_SIZE[0]/BATCH_SIZE[n_depth]))

    # Images at this level's resolution.
    train_dataset = (
        tf.data.TFRecordDataset(tfr_file[n_depth], compression_type='', buffer_size=buffer_mb<<20)
        .map(parse_tfrecord_tf)
        .take(120000)
        .batch(BATCH_SIZE[n_depth])
        .repeat()
    )

    # Matching 256x256 images, batched identically.
    train_256 = (
        tf.data.TFRecordDataset(tfr_file[6], compression_type='', buffer_size=buffer_mb<<20)
        .map(parse_tfrecord_tf)
        .take(120000)
        .batch(BATCH_SIZE[n_depth])
        .repeat()
    )

    concat_train_dataset = tf.data.Dataset.zip((train_dataset, train_256))

    cbk.set_prefix(prefix=f'{n_depth}_fade_in')
    cbk.set_n_depth(n_depth=n_depth)
    cbk.set_steps(steps_per_epoch=steps_per_epoch, epochs=epochs)

    # ---- Fade-in phase ----
    # Levels 2..4 (16x16, 32x32, 64x64) add a segmentation-conditioned block;
    # every other level adds a plain block.
    if 2 <= n_depth <= 4:
        pgan.fade_in_generator_with_seg(n_depth = n_depth)
    else:
        pgan.fade_in_generator()
    pgan.fade_in_discriminator()

    pgan.compile(
        d_optimizer=discriminator_optimizer,
        g_optimizer=generator_optimizer,
    )
    pgan.fit(concat_train_dataset, steps_per_epoch = steps_per_epoch, epochs = epochs, callbacks=[cbk])

    # Save the fade-in weights
    checkpoint_path = f"ckpts_120k/pgan_{cbk.prefix}.ckpt"
    pgan.save_weights(checkpoint_path)

    # ---- Stabilize phase ----
    cbk.set_prefix(prefix=f'{n_depth}_stabilize')
    pgan.stabilize_generator()
    pgan.stabilize_discriminator()

    # Draw the stabilized generator and discriminator
    tf.keras.utils.plot_model(pgan.generator, to_file=f'generator_{n_depth}_stabilize.png', show_shapes=True)
    tf.keras.utils.plot_model(pgan.discriminator, to_file=f'discriminator_{n_depth}_stabilize.png', show_shapes=True)

    pgan.compile(
        d_optimizer=discriminator_optimizer,
        g_optimizer=generator_optimizer,
    )
    pgan.fit(concat_train_dataset, steps_per_epoch = steps_per_epoch, epochs = epochs, callbacks=[cbk])

    # Save the stabilized weights
    checkpoint_path = f"ckpts_120k/pgan_{cbk.prefix}.ckpt"
    pgan.save_weights(checkpoint_path)

이미지 저장 (Saving sample images)

In [None]:
def saveSample(generator, random_latent_vectors, val_imgs, n_depth, prefix):
    """Generate samples and save them as a square image grid.

    Args:
        generator: callable producing images from its inputs.
        random_latent_vectors: latent vectors; their count sets the grid size.
        val_imgs: images passed as second generator input when ``n_depth > 1``.
        n_depth: current level; decides which inputs the generator is given.
        prefix: tag inserted into the output file name.
    """
    # Up to depth 1 (8x8) the generator takes only the latent vectors.
    if n_depth <= 1:
        generator_inputs = random_latent_vectors
    else:
        generator_inputs = [random_latent_vectors, val_imgs]
    samples = generator(generator_inputs)

    # Rescale from [-1, 1] to [0, 1].
    samples = (samples * 0.5) + 0.5
    n_grid = int(sqrt(random_latent_vectors.shape[0]))

    fig, axes = pyplot.subplots(n_grid, n_grid, figsize=(8*n_grid, 8*n_grid))
    grid_shape = (n_grid, n_grid, samples.shape[1], samples.shape[2], samples.shape[3])
    sample_grid = np.reshape(samples[:n_grid * n_grid], grid_shape)

    for row in range(n_grid):
        for col in range(n_grid):
            cell = axes[row][col]
            cell.set_axis_off()
            # Convert to 8-bit and enlarge every tile to 256x256 for display.
            tile = Image.fromarray((sample_grid[row][col] * 255).astype(np.uint8))
            tile = tile.resize((256,256))
            cell.imshow(np.array(tile))

    pyplot.subplots_adjust(wspace=0.1, hspace=0.1)

    title = f'test_120k/plot_{prefix}_{0:05d}.png'
    pyplot.savefig(title, bbox_inches='tight')
    print(f'\n saved {title}')
    pyplot.close(fig)

In [None]:
# Load the trained model (used for inference)
NOISE_DIM = 512
NUM_SAMPLE = 64
n_depth = 0
random_latent_vectors = tf.random.normal(shape=[NUM_SAMPLE, NOISE_DIM])#, seed=9434)
# NUM_SAMPLE images from the 256x256 file, materialised as a list holding a
# single batch.
val_imgs = tf.data.TFRecordDataset(tfr_file[6], compression_type='', buffer_size=buffer_mb<<20)
val_imgs = val_imgs.map(parse_tfrecord_tf)#, num_parallel_calls=num_threads)
val_imgs = val_imgs.take(NUM_SAMPLE).batch(NUM_SAMPLE)
val_imgs = list(val_imgs.as_numpy_iterator())

# Instantiate the PGAN(PG-GAN) model.
pgan = PGAN(
    latent_dim = NOISE_DIM,
    G1=G1,
    G2=G2,
    d_steps = 1,
)

# Load weight and generate samples per each level. 
prefix='0_init'
pgan.load_weights(f"ckpts_120k/pgan_{prefix}.ckpt")
#saveSample(pgan.generator, random_latent_vectors, val_imgs, n_depth, prefix)

# Inference: rebuild the architecture level by level, loading the matching
# checkpoint after each structural change.
# NOTE(review): this loop stops at depth 5 while training runs range(1, 7);
# confirm whether the depth-6 checkpoint is skipped on purpose.
for n_depth in range(1,6):
    pgan.n_depth = n_depth
    prefix=f'{n_depth}_fade_in'
    
    if n_depth <= 1:  # 8X8
        pgan.fade_in_generator()
        pgan.fade_in_discriminator()
    elif n_depth <= 4: # 16x16, 32x32, 64x64
        pgan.fade_in_generator_with_seg(n_depth = n_depth)
        pgan.fade_in_discriminator()
    else:
        pgan.fade_in_generator()
        pgan.fade_in_discriminator()
  
    pgan.load_weights(f"ckpts_120k/pgan_{prefix}.ckpt")
    #saveSample(pgan.generator, random_latent_vectors, val_imgs, n_depth, prefix)
  
    prefix=f'{n_depth}_stabilize'
    pgan.stabilize_generator()
    pgan.stabilize_discriminator()
  
    pgan.load_weights(f"ckpts_120k/pgan_{prefix}.ckpt")
    #saveSample(pgan.generator, random_latent_vectors, val_imgs, n_depth, prefix)
print('###############')
# Reloads the same checkpoint that the last loop iteration already loaded.
pgan.load_weights(f"ckpts_120k/pgan_{prefix}.ckpt")

In [None]:
# Draw fresh latent vectors and save a sample grid from the restored generator.
random_latent_vectors = tf.random.normal(shape=[NUM_SAMPLE, NOISE_DIM])
prefix=f'{n_depth}_100_stabilize'
# val_imgs holds one array per batch. Use the 100th batch when it exists and
# otherwise fall back to the last available batch instead of raising IndexError.
val_batch = val_imgs[99] if len(val_imgs) > 99 else val_imgs[-1]
saveSample(pgan.generator, random_latent_vectors, val_batch, n_depth, prefix)

In [None]:
# NOTE(review): scratch cell — overwrites random_latent_vectors with a (5, 30)
# tensor whose width does not match NOISE_DIM (512); confirm it is still needed.
random_latent_vectors = tf.random.normal(shape=[5, 30])