In [None]:
%load_ext autoreload
%autoreload 2
%pylab inline
from models import *
import tqdm
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
class progGAN(tf.keras.models.Model):
    '''Multi-scale GAN wrapping an NVIDIA-style generator and discriminator.

    The wrapper is a functional Keras model whose inputs/outputs are the
    concatenation of the generator's and the discriminator's inputs/outputs,
    so it owns the layers of both sub-models.
    '''
    def __init__(self):
        '''Build the generator (self.G) and discriminator (self.D) and wire them together.'''
        # Build the sub-models as locals first: newer tf.keras versions raise a
        # RuntimeError if a Layer/Model is assigned as an attribute before
        # Model.__init__ has run.
        G = NVIDIA_generator()
        D = NVIDIA_discriminator()
        # Name this class in super() so that Model.__init__ itself runs
        # (naming tf.keras.models.Model here would skip it).
        super(progGAN, self).__init__(
                [i for m in (G, D) for i in m.inputs],
                [o for m in (G, D) for o in m.outputs])
        self.G = G
        self.D = D

    def blend(self, step, period=100.):
        '''Triangular ("hat") mixing weights over 8 levels.

        Level k gets weight max(0, 1 - |s - k|) with s = min(step / period, 7),
        so at most two adjacent levels are non-zero: index 0 has full weight at
        step 0 and index 7 has full weight once step >= 7 * period.

        Args:
            step: scalar training step (cast to float32).
            period: number of steps spent moving from one level to the next.

        Returns:
            float32 tensor of shape (8,).
        '''
        scale = tf.range(8, dtype=tf.float32)
        state = tf.minimum(tf.cast(step, tf.float32) / tf.cast(period, tf.float32), 7)
        return tf.maximum(tf.ones_like(scale) - tf.abs(state - scale), tf.zeros_like(scale))

    def grad_penalty(self, xs, xs_hat, blend, eps=1e-8):
        '''WGAN-GP gradient penalty (arxiv.org/pdf/1704.00028.pdf) for multi-scale inputs.

        A single interpolation coefficient per sample is shared by every scale,
        so each interpolated pyramid lies on the straight line between the real
        pyramid `xs` and the generated pyramid `xs_hat`.

        Args:
            xs: list of real image tensors, one per scale (4-D, batch first,
                as implied by the [batch, 1, 1, 1] alpha broadcast).
            xs_hat: list of generated tensors matching `xs` scale-for-scale.
            blend: per-level weights multiplied into the discriminator output.
            eps: small constant added inside the square root. When a scale's
                input gradient is exactly zero (e.g. its blend weight is 0),
                d sqrt(u)/du at u=0 would give 0/0 = NaN gradients.

        Returns:
            scalar tensor: batch mean of (summed per-scale gradient norms - 1)^2.
        '''
        # One alpha per sample, shared across all scales.
        alpha = tf.random_uniform(shape=[tf.shape(xs[0])[0], 1, 1, 1])
        interps = [(1 - alpha) * xs[i] + alpha * xs_hat[i] for i in range(len(xs))]
        # NOTE(review): the gradient w.r.t. the last input is dropped, presumably
        # because that input is not connected to the weighted output (tf.gradients
        # would return None for it) — TODO confirm against NVIDIA_discriminator.
        grads = tf.gradients(self.D(interps) * blend, interps)[:-1]
        slopes = tf.reduce_sum(
            [tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + eps) for grad in grads],
            axis=0)
        return tf.reduce_mean((slopes - 1) ** 2)

    def train(self, x, batch_size=1, period=100., lambda_GP=10, lr=1e-3, n_steps=None):
        '''Build the WGAN-GP training graph and run alternating G/D updates.

        Args:
            x: real images, array or tensor of shape (N, H, W, 3) with
                N >= batch_size and H, W >= 1024; each step takes a random
                batch_size x 1024 x 1024 x 3 crop and resizes it to 9 scales.
            batch_size: number of samples per update.
            period: steps per blend level (see `blend`).
            lambda_GP: gradient-penalty weight.
            lr: Adam learning rate (beta1=0, beta2=0.99) for both networks.
            n_steps: number of G/D update pairs to run; None loops forever.

        Prints the generator and discriminator losses after every update.
        '''
        z = tf.random_normal(shape=(batch_size, 512))
        x = tf.random_crop(x, [batch_size, 1024, 1024, 3])
        xs = [tf.image.resize_bilinear(x, [s, s]) for s in [
            1024, 512, 256, 128, 64, 32, 16, 8, 4]]
        xs_hat = self.G(z)

        # Counts discriminator updates (incremented via global_step below).
        step = tf.Variable(0, name='step', trainable=False)
        blend = self.blend(step, period=period)[::-1]  # reverse (hi-res last)
        # Average over the batch so the critic terms share the scale of the
        # batch-averaged gradient penalty (identical when batch_size == 1).
        real = tf.reduce_sum(self.D(xs) * blend) / batch_size
        fake = tf.reduce_sum(self.D(xs_hat) * blend) / batch_size
        GP = self.grad_penalty(xs, xs_hat, blend)

        L_G = fake
        L_D = real - fake + lambda_GP * GP

        G_opt = tf.train.AdamOptimizer(lr, beta1=0, beta2=0.99).minimize(
            L_G, var_list=self.G.trainable_weights)
        D_opt = tf.train.AdamOptimizer(lr, beta1=0, beta2=0.99).minimize(
            L_D, var_list=self.D.trainable_weights, global_step=step)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            done = 0
            while n_steps is None or done < n_steps:
                print('L_G: ' + str(sess.run([G_opt, L_G])[1]))
                print('L_D: ' + str(sess.run([D_opt, L_D])[1]))
                done += 1

In [None]:
# Build the generator/discriminator pair once; the later training cell uses `pg`.
pg = progGAN()

In [None]:
# Hard-coded dataset location; adjust for your machine.
IMAGE_PATH = '/data/datasets/Hi_res_stills/Cosmic/Cosmic_003.jpg'
image_raw = imread(IMAGE_PATH)
if image_raw.ndim == 2:
    # Grayscale: replicate to three channels so the RGB crop in train() fits.
    image_raw = stack([image_raw] * 3, axis=-1)
image_rgb = image_raw[..., :3]  # drop an alpha channel if present
# matplotlib's imread yields uint8 in [0, 255] for JPEG (via Pillow) but
# float in [0, 1] for PNG; map either one to [-1, 1].
half_range = 0.5 if image_rgb.dtype.kind == 'f' else 127.5
x = expand_dims(image_rgb, 0).astype('float32') / half_range - 1

In [None]:
# Starts the training loop, which prints both losses every step and does not
# return on its own; interrupt the kernel to stop it.
pg.train(x=x, period=1000, batch_size=1)