In [None]:
%load_ext autoreload
%autoreload 2
%pylab inline
from models import *
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
!nvidia-smi | head -19

In [None]:
class progGAN(BaseModel):
    '''GAN that pairs an ``NVIDIA_generator`` with an ``NVIDIA_discriminator``.

    Both sub-networks are built for the same ``size`` and exposed as one model
    whose inputs/outputs are those of ``G`` followed by those of ``D``.
    '''
    def __init__(self, size):
        '''Create the generator ``G`` and discriminator ``D``.

        Args:
            size: forwarded to ``NVIDIA_generator`` / ``NVIDIA_discriminator``.
                ``train`` uses it as the largest image side and builds scales
                4, 8, ..., ``size`` from it, so a power of two >= 4 is
                expected there.
        '''
        self.size = size
        self.G = NVIDIA_generator(size)
        self.D = NVIDIA_discriminator(size)
        # NOTE(review): ``super(BaseModel, self)`` starts the MRO lookup *after*
        # BaseModel, so a ``BaseModel.__init__`` (if one exists) is skipped.
        # Kept unchanged; confirm against models.py that this is intended.
        super(BaseModel, self).__init__(
            [i for m in (self.G, self.D) for i in m.inputs],
            [o for m in (self.G, self.D) for o in m.outputs])

    def train(self, input_dir, output_dir='./out', batch_size=1, period=100., lambda_GP=10,
              lr=1e-3, n_critic=1, gamma=750, epsilon_drift=0.001, summary_every=np.inf):
        '''Build the training graph and run the training loop until stopped.

        The loop only ends when the coordinator is asked to stop or an
        exception (including ``KeyboardInterrupt``) is raised; the exception
        is re-raised after the coordinator has been told to stop.

        Args:
            input_dir: passed, wrapped in a list, to ``self.stream_input``.
            output_dir: passed to ``self.make_summary``.
            batch_size: number of latent vectors drawn per step; also passed
                to ``self.stream_input``.
            period: forwarded to ``self.blend`` together with the step counter.
            lambda_GP: weight of the gradient-penalty term in the
                discriminator loss.
            lr: learning rate of both Adam optimizers.
            n_critic: discriminator updates per generator update.
            gamma: currently unused by this method.
            epsilon_drift: weight of the drift term (blend-weighted squared
                discriminator output on real images).
            summary_every: a summary is written whenever the last fetched
                step value is a multiple of this. With the default ``np.inf``
                that is only true for the initial value 0.
        '''
        print('Building inputs')
        with tf.variable_scope('Inputs'):
            coord = tf.train.Coordinator()
            z = tf.random_normal(shape=(batch_size, 512))
            x = self.stream_input([input_dir], self.size, batch_size)
            #x = tf.stack([tf.random_crop(x, [self.size, self.size, 3]) for _ in range(batch_size)], 0)
            # Real images resized to every scale 4, 8, ..., self.size.
            xs = [tf.image.resize_bilinear(x, [s, s]) for s in [
                2**(2+i) for i in range(int(np.log2(self.size//4))+1)]]
            xs_hat = self.G(z)

        print('Building losses')
        with tf.variable_scope('Losses'):

            # Step counter incremented by the generator optimizer. It is not a
            # model weight, so keep it out of the trainable-variables collection.
            step = tf.Variable(0, name='step', trainable=False)
            # Only ``blend`` is used below; it multiplies the per-scale
            # outputs before they are summed over the last axis.
            tooth, blend, quantize = self.blend(step, period=period, n_scales=len(xs))

            #xs_hat = self.residual(xs_hat, tooth)
            real = self.D(xs)
            fake = self.D(xs_hat)
            GP = self.grad_penalty(xs, xs_hat)

            real_mask = tf.reduce_mean(tf.reduce_sum(real * blend, axis=-1))
            fake_mask = tf.reduce_mean(tf.reduce_sum(fake * blend, axis=-1))

            GP_mask = tf.reduce_mean(tf.reduce_sum(GP * blend, axis=-1))
            drift_mask = tf.reduce_mean(tf.reduce_sum(real**2 * blend, axis=-1))

            # Generator raises D's score on fakes; discriminator widens the
            # real-vs-fake gap, regularized by the drift and GP terms.
            L_G = -fake_mask
            L_D = -real_mask + fake_mask
            L_D_tot = L_D + epsilon_drift*drift_mask + lambda_GP*GP_mask
            # TODO: add feature regularizer

        print('Building optimizers')
        with tf.variable_scope('Optimizers'):
            # Only the generator step advances ``step``.
            G_opt = tf.train.AdamOptimizer(lr, beta1=0, beta2=0.99).minimize(
                L_G, var_list=self.G.trainable_weights, global_step=step)
            D_opt = tf.train.AdamOptimizer(lr, beta1=0, beta2=0.99).minimize(
                L_D_tot, var_list=self.D.trainable_weights)

        print('Building summary')
        with tf.variable_scope('Summary'):
            # Image summaries are keyed by the generator's output size at each scale.
            img_dict={
                'real_'+str(self.G.output_shape[i][1]): self.postproc_img(x) \
                for i, x in enumerate(xs)}
            fake_img_dict={
                'fake_'+str(self.G.output_shape[i][1]): self.postproc_img(x) \
                for i, x in enumerate(xs_hat)}
            img_dict.update(fake_img_dict)
            scalar_dict={
                'L_D_tot': L_D_tot,
                'L_D': L_D,
                'L_G': L_G,
                'GP_mask': GP_mask,
                'drift_mask': drift_mask}
        summary, writer = self.make_summary(output_dir,
            img_dict=img_dict, scalar_dict=scalar_dict, graph=None)

        print('Initializing')
        with tf.Session() as sess:
            # TODO: scale weights at runtime (section 4.1)
            sess.run(tf.global_variables_initializer())
            tf.train.start_queue_runners(sess=sess, coord=coord)
            # Freeze the graph so the loop cannot add ops by accident.
            self.graph.finalize()
            print('Start of training'); time.sleep(1); n = 0
            try:
                while not coord.should_stop():
                    # The progress bar is disabled; this just batches 10000
                    # iterations between checks of the coordinator.
                    for i in tqdm.trange(10000, disable=True):
                        for _ in range(n_critic):
                            sess.run(D_opt)
                        if n % summary_every:
                            n = sess.run([G_opt, step])[1]
                        else:
                            s, n = sess.run([G_opt, summary, step])[1:]
                            writer.add_summary(s, n)
                            writer.flush()
            except BaseException:
                # Deliberately broad (covers KeyboardInterrupt): stop the
                # input threads, give them a moment, then propagate the error.
                coord.request_stop()
                time.sleep(1)
                raise

In [None]:
# Instantiate the model with size 32 and report the parameter counts of its
# generator and discriminator.
pg = progGAN(32)
print('Gen params ~ {:0.1e}\nDisc params ~ {:0.1e}'.format(
    *(net.count_params() for net in (pg.G, pg.D))))

In [None]:
# Start training; keyword arguments are listed in the order of train()'s signature.
pg.train(
    input_dir='/data/datasets/celeba/img_align_celeba/',
    output_dir='out/tooth',  # alt: './out/celeba_32_period100k_edrift0.01_QUANTIZE'
    batch_size=8,
    period=1000,
    lambda_GP=10,
    lr=1e-3,
    n_critic=1,
    gamma=750,
    epsilon_drift=0.01,
    summary_every=10,
)

In [None]:
rm -r out/tooth

In [None]:
def tri(s, n_scales=8):
    """Return per-scale weights for a continuous position ``s``.

    Element 0 is a ramp: 1 while ``s <= 1``, falling linearly to 0 at
    ``s = 2``. Element ``k >= 1`` is a triangle of height 1 centred at
    ``s = k + 1`` that reaches 0 at ``s = k`` and ``s = k + 2``.

    Args:
        s: scalar position.
        n_scales: number of weights to return (default 8, the previously
            hard-coded value).

    Returns:
        1-D array of length ``n_scales`` with values in [0, 1].
    """
    scale = arange(n_scales)
    # Only the first scale uses the ramp, so compute it for that element alone.
    head = clip(scale[0:1] + 2 - s, 0, 1)
    # Remaining scales: 1 - |distance to the centre k + 1|, floored at 0.
    bumps = 1 - clip(abs(scale[1:] + 1 - s), 0, 1)
    return concatenate([head, bumps])
# Heat-map of the weights: one row per value of s swept over [0, 8], one column per scale.
imshow(list(map(tri, linspace(0, 8, 100))), aspect='auto')