1. https://github.com/rosinality/glow-pytorch
2. https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html
3. https://blog.evjang.com/2018/01/nf1.html
4. https://blog.evjang.com/2018/01/nf2.html
5. http://akosiorek.github.io/ml/2018/04/03/norm_flows.html
6. https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html
7. https://www.youtube.com/playlist?list=PLdlPlO1QhMiAkedeu0aJixfkknLRxk1nA
8. https://ml4a.github.io/classes/itp-F18/09/#
9. https://openai.com/blog/glow/
10. https://www.youtube.com/watch?v=kIvyex9HH0Y

In [None]:
import os
import sys
import time

from PIL import Image
# import argparse
import imutils
import dlib
import cv2

import numpy as np
import tensorflow as tf
import graphics
from utils import ResultLogger
import matplotlib.pyplot as plt
from recordtype import recordtype

In [None]:
# Hyper-parameter record. recordtype yields a mutable, namedtuple-like class, so
# later cells can assign fields in place (e.g. hps.rnd_crop, hps.train_its).
hps1 = recordtype('hps1', ['problem','seed','logdir','dal','image_size','lr','n_levels','direct_iterator','epochs',
                           'epochs_warmup','epochs_full_valid','epochs_full_sample','n_batch_train','local_batch_train',
                           'local_batch_test','n_test','n_y','n_train','train_its','test_its','full_test_its','inference',
                           'rnd_crop','anchor_size','n_batch_init','local_batch_init','n_bits_x','n_bins',
                           'flow_permutation','flow_coupling','depth','width','top_shape','learntop','ycond',
                           'weight_y','gradient_checkpointing','optimizer','polyak_epochs','beta1',
                           'weight_decay','restore_path','category', 'data_dir','n_batch_test','n_sample',
                           'verbose'])

# Values are passed by keyword so that each setting is visibly tied to its field;
# with 47 positional arguments a single misplaced value would go unnoticed.
# Fields set to None here are filled in later by get_data() / get_its().
hps = hps1(
    problem='mnist',
    seed=1,
    logdir='./logs',
    dal=1,
    image_size=32,
    lr=0.001,
    n_levels=5,
    direct_iterator=False,
    epochs=200,
    epochs_warmup=10,
    epochs_full_valid=50,
    epochs_full_sample=50,
    n_batch_train=64,
    local_batch_train=50,
    local_batch_test=None,
    n_test=100,
    n_y=10,
    n_train=1500,
    train_its=None,
    test_its=None,
    full_test_its=None,
    inference=False,
    rnd_crop=None,
    anchor_size=32,
    n_batch_init=256,
    local_batch_init=None,
    n_bits_x=5,
    n_bins=None,
    flow_permutation=2,
    flow_coupling=0,
    depth=8,
    width=512,
    top_shape=None,
    learntop=True,
    ycond=False,
    weight_y=0.0,
    gradient_checkpointing=1,
    optimizer='adamax',
    polyak_epochs=1,
    beta1=0.9,
    weight_decay=1.0,
    restore_path='',
    category='',
    data_dir='',
    n_batch_test=50,
    n_sample=1,
    verbose=True,
)


# Minimal object exposing only `size` and `rank`, the two attributes the cells
# below read from `hvd`.
# NOTE(review): presumably a single-worker stand-in for Horovod's `hvd` — confirm.
hvd1 = recordtype('hvd1', ['size', 'rank'])
hvd = hvd1(size=1, rank=0)

In [None]:
# def align_images(x,hps):
#     x = imutils.resize(x, width=hps.image_size, height=hps.image_size)
#     x = x.reshape(hps.image_size, hps.image_size, 1)
#     return x

# x_train = np.array([align_images(x,hps) for x in x_train])
# x_test = np.array([align_images(x,hps) for x in x_test])

In [None]:
def tensorflow_session():
    """Create and return a new ``tf.Session`` built with default arguments."""
    return tf.Session()

In [None]:
def get_data(hps, sess):
    """Derive the per-worker batch sizes and build the data iterators.

    Mutates ``hps`` in place: sets ``rnd_crop``, ``local_batch_train``,
    ``local_batch_test``, ``local_batch_init`` and ``direct_iterator``.

    Args:
        hps: hyper-parameter record; ``problem``, ``anchor_size``,
            ``image_size``, ``n_batch_train``, ``n_batch_init``, ``dal`` and
            ``data_dir`` are read here.
        sess: session handed to the ``data_loaders.get_data`` loader (only
            used for the ImageNet / CelebA / LSUN problems).

    Returns:
        The ``(train_iterator, test_iterator, data_init)`` triple produced by
        the selected loader module.

    Raises:
        KeyError: if the derived ``local_batch_train`` has no entry in the
            train-to-test batch size table.
        ValueError: if ``hps.problem`` is not one of the supported names.
    """
    # Random cropping is switched on only for the 'lsun_realnvp' problem.
    hps.rnd_crop = hps.problem == 'lsun_realnvp'

    # Use anchor_size to rescale batch size based on image_size
    s = hps.anchor_size
    pixels = hps.image_size * hps.image_size
    hps.local_batch_train = hps.n_batch_train * s * s // pixels

    # round down to closest divisor of 50
    hps.local_batch_test = {64: 50, 32: 25, 16: 10, 8: 5, 4: 2, 2: 2, 1: 1}[
        hps.local_batch_train]

    hps.local_batch_init = hps.n_batch_init * s * s // pixels

    print("Rank {} Batch sizes Train {} Test {} Init {}".format(
        hvd.rank, hps.local_batch_train, hps.local_batch_test, hps.local_batch_init))

    if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']:
        hps.direct_iterator = True
        import data_loaders.get_data as v
        # The hps record built in this file declares no `pmap` / `fmap` fields,
        # so reading them directly raised AttributeError. Fall back to defaults
        # when they are absent.
        # NOTE(review): 16 / 1 are assumed to match the reference Glow CLI
        # defaults (parallel map threads / file reader threads) — confirm.
        pmap = getattr(hps, 'pmap', 16)
        fmap = getattr(hps, 'fmap', 1)
        train_iterator, test_iterator, data_init = \
            v.get_data(sess, hps.data_dir, hvd.size, hvd.rank, pmap, fmap, hps.local_batch_train,
                       hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop)

    elif hps.problem in ['mnist', 'cifar10']:
        hps.direct_iterator = False
        import data_loaders.get_mnist_cifar as v
        train_iterator, test_iterator, data_init = \
            v.get_data(hps.problem, hvd.size, hvd.rank, hps.dal, hps.local_batch_train,
                       hps.local_batch_test, hps.local_batch_init, hps.image_size)

    else:
        # ValueError is still an Exception, so existing `except Exception`
        # handlers keep working, but the failure now says what went wrong.
        raise ValueError("Unsupported problem: {!r}".format(hps.problem))

    return train_iterator, test_iterator, data_init

In [None]:
def get_its(hps):
    """Compute the iteration counts for training, quick test and full validation.

    Reads ``n_train``, ``n_test``, ``n_batch_train`` and ``local_batch_test``
    from ``hps`` and the worker count / rank from the module-level ``hvd``.

    Returns:
        ``(train_its, test_its, full_test_its)`` as integers.

    Raises:
        AssertionError: if ``n_test`` is not a multiple of
            ``local_batch_test * hvd.size``.
    """
    is_root = hvd.rank == 0

    # These run for a fixed amount of time. As anchored batch is smaller, we've actually seen fewer examples
    anchored_batch = hps.n_batch_train * hvd.size
    train_its = int(np.ceil(hps.n_train / anchored_batch))
    test_its = int(np.ceil(hps.n_test / anchored_batch))
    train_epoch = train_its * anchored_batch

    # Do a full validation run
    if is_root:
        print(hps.n_test, hps.local_batch_test, hvd.size)
    full_test_batch = hps.local_batch_test * hvd.size
    assert hps.n_test % full_test_batch == 0
    full_test_its = hps.n_test // full_test_batch

    if is_root:
        print("Train epoch size: " + str(train_epoch))
    return train_its, test_its, full_test_its

In [None]:
def infer(sess, model, hps, iterator):
    """Encode and reconstruct ``hps.full_test_its`` batches and save the results.

    Example of using the model in inference mode. Load a saved model using
    ``hps.restore_path``. ``x`` and ``y`` can be provided from files instead
    of a dataset iterator. If the model is unconditional, always pass
    ``y = np.zeros([bs], dtype=np.int32)``.

    The concatenated reconstructions and latents are written to ``x.npy`` and
    ``z.npy`` inside ``hps.logdir`` (falling back to ``logs`` when that
    attribute is missing or empty); the directory is created if needed.

    Args:
        sess: session used to evaluate ``iterator`` when
            ``hps.direct_iterator`` is true.
        model: object exposing ``encode(x, y)`` and ``decode(y, z)``.
        hps: hyper-parameter record; reads ``direct_iterator``,
            ``full_test_its`` and ``logdir``.
        iterator: either an object with ``get_next()`` (direct iterator) or a
            zero-argument callable returning ``(x, y)``.

    Returns:
        The list of per-batch latents (not the concatenated array).
    """
    if hps.direct_iterator:
        iterator = iterator.get_next()

    xs = []
    zs = []
    for _ in range(hps.full_test_its):
        if hps.direct_iterator:
            # replace with x, y, attr if you're getting CelebA attributes, also modify get_data
            x, y = sess.run(iterator)
        else:
            x, y = iterator()

        z = model.encode(x, y)
        x = model.decode(y, z)
        xs.append(x)
        zs.append(z)

    x = np.concatenate(xs, axis=0)
    z = np.concatenate(zs, axis=0)

    # Write next to the other run artefacts rather than into a hard-coded
    # 'logs/' folder that may not exist for a non-default hps.logdir.
    out_dir = getattr(hps, 'logdir', None) or 'logs'
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'x.npy'), x)
    np.save(os.path.join(out_dir, 'z.npy'), z)
    return zs

In [None]:
def train(sess, model, hps, logdir, visualise):
    """Run the training loop with periodic validation, checkpointing and sampling.

    Args:
        sess: session whose graph is finalized before the loop starts.
        model: object exposing ``train(lr)``, ``test()`` and ``save(path)``.
        hps: hyper-parameter record; reads ``epochs``, ``train_its``,
            ``full_test_its``, ``lr``, ``n_train``, ``epochs_warmup``,
            ``n_batch_train``, ``local_batch_train``, ``epochs_full_valid``
            and ``epochs_full_sample``.
        logdir: path prefix for log and checkpoint files. File names are
            appended by plain string concatenation, so it must already end
            with a path separator.
        visualise: callable invoked as ``visualise(epoch)`` to draw samples.

    Returns:
        None. Progress is printed and written through ``ResultLogger``.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    # Sentinel "worst" loss; the first full validation below it is checkpointed.
    test_loss_best = 999999

    train_logger = ResultLogger(logdir + "train.txt", hps)
    test_logger = ResultLogger(logdir + "test.txt", hps)


    # Timestamp of the previous summary line; used for the `dcurr` column.
    tcurr = time.time()
    # NOTE: range(1, hps.epochs) runs hps.epochs - 1 epochs (1 .. epochs-1).
    for epoch in range(1, hps.epochs):

        t = time.time()

        train_results = []
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]

            # Per-step trace: images processed so far, step duration, step results.
            _print(n_processed, time.time()-_t, train_results[-1])
            sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hps.local_batch_train

        # Average the per-step results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        # Images per second at the current resolution for this epoch.
        ips = (hps.train_its * hps.local_batch_train) / dtrain
        train_time += dtrain

        # Only rank 0 writes to the log files.
        if hvd.rank == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results))

        # Summary epochs: every epoch below 10, every 10th below 50, and every
        # multiple of hps.epochs_full_valid.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            # Stays an empty list unless the full validation below runs.
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint
                    # test_results[0] is the entry process_results labels 'loss'.
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        # Mark the summary line as a new best.
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            # Wall-clock time since the previous summary line.
            dcurr = time.time() - tcurr
            tcurr = time.time()
            _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)

            # model.polyak_swap()


    _print("Finished!")

In [None]:
def process_results(results):
    """Label a results vector and format every entry with four decimals.

    Args:
        results: array-like with a ``shape`` attribute whose first dimension
            is 4, ordered as loss, bits_x, bits_y, pred_loss.

    Returns:
        Dict mapping each statistic name to its value formatted as ``"%.4f"``.

    Raises:
        AssertionError: if ``results.shape[0]`` is not 4.
    """
    stats = ['loss', 'bits_x', 'bits_y', 'pred_loss']
    assert len(stats) == results.shape[0]
    return {
        name: "{:.4f}".format(results[idx])
        for idx, name in enumerate(stats)
    }

In [None]:
def _print(*args, **kwargs):
    """Forward all positional and keyword arguments to the built-in ``print``."""
    print(*args, **kwargs)

In [None]:
def init_visualizations(hps, model, logdir):
    """Build the callback that draws and saves sample grids during training.

    Args:
        hps: hyper-parameter record; reads ``local_batch_train``,
            ``image_size`` and ``n_y``.
        model: object exposing ``sample(y, eps)``.
        logdir: path prefix for the output images. File names are appended by
            plain string concatenation, so it must end with a path separator.

    Returns:
        ``draw_samples(epoch)``, which writes one PNG per temperature and does
        nothing on workers whose ``hvd.rank`` is not 0.
    """

    def sample_batch(y, eps):
        # Sample in chunks of at most hps.local_batch_train and stitch the
        # chunks back together.
        n_batch = hps.local_batch_train
        n_chunks = int(np.ceil(len(eps) / n_batch))
        xs = []
        for chunk in range(n_chunks):
            lo = chunk * n_batch
            hi = lo + n_batch
            xs.append(model.sample(y[lo:hi], eps[lo:hi]))
        return np.concatenate(xs)

    def draw_samples(epoch):
        if hvd.rank != 0:
            return

        rows = 10 if hps.image_size <= 64 else 4
        cols = rows
        n_batch = rows*cols
        # Labels cycle 0..cols-1 along each row, wrapped into the n_y range.
        y = np.asarray([_y % hps.n_y for _y in (
            list(range(cols)) * rows)], dtype='int32')

        # previously: 0, .25, .5, .625, .75, .875, 1.
        temperatures = [0., .25, .5, .6, .7, .8, .9, 1.]

        # One batch per temperature; the same value is repeated for every
        # sample and handed to model.sample as `eps`. All batches are drawn
        # before any file is written.
        x_samples = [sample_batch(y, [temperature]*n_batch)
                     for temperature in temperatures]

        for i, batch in enumerate(x_samples):
            x_sample = np.reshape(
                batch, (n_batch, hps.image_size, hps.image_size, 3))
            graphics.save_raster(x_sample, logdir +
                                 'epoch_{}_sample_{}.png'.format(epoch, i))

    return draw_samples

In [None]:
# Driver cell: seed the RNGs, build the data pipeline and model, then either
# train or run inference depending on hps.inference.
learn = tf.contrib.learn
# tf.enable_eager_execution()

# Suppress verbose warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

sess = tensorflow_session()
sess.run(tf.global_variables_initializer())
seed = hps.seed
np.random.seed(seed)
tf.set_random_seed(seed)

# Convenience copies of frequently used hyper-parameters.
image_size = hps.image_size
n_train = hps.n_train
n_test = hps.n_test
n_y = hps.n_y

problem = hps.problem
dal = hps.dal
n_levels = hps.n_levels
depth = hps.depth
lr = hps.lr

# get_data() fills in the per-worker batch sizes on hps, which get_its() needs.
train_iterator, test_iterator, data_init = get_data(hps, sess)

hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)

# Create log dir
# The trailing "/" matters: train() and init_visualizations() build file names
# by concatenating directly onto logdir.
logdir = os.path.abspath(hps.logdir) + "/"
# makedirs also creates missing parent directories and tolerates the directory
# already existing.
os.makedirs(logdir, exist_ok=True)

# Note: from here on `model` refers to the model instance, not the module.
import model
model = model.model(sess, hps, train_iterator, test_iterator, data_init)

# Initialize visualization functions
visualise = init_visualizations(hps, model, logdir)

if not hps.inference:
    # Perform training
    train(sess, model, hps, logdir, visualise)
else:
    infer(sess, model, hps, test_iterator)
    # A trailing `print(logp_)` was removed here: `logp_` was never assigned
    # (its `logp_ = model.logp` line was commented out), so the inference
    # branch ended with a NameError.