In [1]:
# imports
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.mlab as mlab
import math
import time
import pickle
%matplotlib inline

  from ._conv import register_converters as _register_converters


In [2]:
# Flow building blocks: actnorm, Glow step, squeeze and preprocessing layers

def actnorm_layer(tf_in, op_init_actnorm, name="", reverse=False):
    """Per-channel affine "actnorm" transform: ``out = in * exp(log_scale) + bias``.

    Variables are created (or reused, via ``tf.AUTO_REUSE``) in the
    ``<name>act_norm`` scope, so a forward and a reverse call built with the
    same ``name`` under the same enclosing scope share parameters.

    Args:
        tf_in: input tensor. Statistics are reduced over axes 0, 1 and 2 and
            the last axis is used as the channel axis, so a 4-D
            [batch, height, width, channels] layout with static height and
            width is expected.
        op_init_actnorm: list that receives, as one ``[scale_op, bias_op]``
            entry, the data-dependent initialization ops (forward pass only).
            If None, no initialization ops are created.
        name: prefix prepended to the variable scope name.
        reverse: when True, apply the inverse transform.

    Returns:
        Forward: ``(tf_out, log_det)`` with
        ``log_det = height * width * sum(log_scale)`` (a scalar tensor).
        Reverse: the inverted tensor only.
    """
    eps = 1e-6  # keeps the initialization finite for zero-variance channels
    zeros_init = tf.zeros_initializer()
    ones_init = tf.ones_initializer()

    with tf.variable_scope(name + 'act_norm', reuse=tf.AUTO_REUSE):
        channels = tf_in.shape[-1]
        # NOTE: the variable is named "scale" but stores the *log* of the scale.
        # Until the initialization op runs it is 1, i.e. the scale is e.
        tf_log_scale = tf.get_variable("scale", shape=[channels], initializer=ones_init)
        tf_scale = tf.exp(tf_log_scale)
        tf_bias = tf.get_variable("bias", shape=[channels], initializer=zeros_init)

        if reverse:
            return (tf_in - tf_bias) / tf_scale

        tf_out = tf_in
        tf_out *= tf_scale
        tf_out += tf_bias

        if op_init_actnorm is not None:
            # Data-dependent initialization: choose scale and bias so that the
            # output of this layer has zero mean and unit variance per channel
            # on the batch fed when the ops run. The targets depend only on the
            # input statistics (not on the current variable values), so the two
            # assign ops do not race with each other and running them again on
            # the same batch is idempotent.
            tf_mean, tf_variance = tf.nn.moments(tf_in, axes=[0, 1, 2])
            tf_new_scale = 1.0 / (tf.math.sqrt(tf_variance) + eps)
            tf_new_bias = -tf_mean * tf_new_scale
            op_init_actnorm.append([
                tf_log_scale.assign(tf.log(tf_new_scale)),
                tf_bias.assign(tf_new_bias)
            ])

        # Every spatial position applies the same per-channel scale.
        tf_log_jacobian_determinant = (
            tf_in.shape[1].value * tf_in.shape[2].value * tf.reduce_sum(tf_log_scale)
        )
        return tf_out, tf_log_jacobian_determinant
        
def glow_coupling_layer(tf_in, op_init_actnorm=None, reverse=False, name=""):
    """One Glow flow step: actnorm -> invertible 1x1 convolution -> affine coupling.

    All variables live under the ``<name>glow_module`` scope opened with
    ``tf.AUTO_REUSE``. With the default empty ``name`` every call made under
    the same enclosing scope therefore reuses the same variables; pass a
    distinct ``name`` (or wrap the call in your own variable scope) to give a
    step its own parameters.

    Args:
        tf_in: input tensor.
        op_init_actnorm: list that collects the actnorm initialization ops
            during the forward pass. When None, a throw-away list is used so
            the forward pass still builds.
        reverse: when True, run the three sub-layers inverted and in the
            opposite order.
        name: optional prefix for the variable scope name.

    Returns:
        Forward: ``(tf_out, log_det)`` where ``log_det`` is the sum of the
        three sub-layers' log-determinants.
        Reverse: the inverted tensor only.
    """
    CHANNEL_WIDTH = 512  # width argument handed to coupling_layer

    with tf.variable_scope(name + "glow_module", reuse=tf.AUTO_REUSE):
        if reverse:
            tf_in = coupling_layer(tf_in, CHANNEL_WIDTH, reverse=True)
            tf_in = invertible_1x1_conv(tf_in, reverse=True)
            tf_in = actnorm_layer(tf_in, op_init_actnorm, reverse=True)
            return tf_in

        if op_init_actnorm is None:
            # The caller does not want the initialization ops; give the
            # actnorm layer a list to append to anyway.
            op_init_actnorm = []

        tf_in, log_det1 = actnorm_layer(tf_in, op_init_actnorm)
        tf_in, log_det2 = invertible_1x1_conv(tf_in)
        tf_in, log_det3 = coupling_layer(tf_in, CHANNEL_WIDTH)
        return tf_in, log_det1 + log_det2 + log_det3

def squeeze_layer(tf_in, reverse=False):
    """Trade spatial resolution for channels with a block size of 2.

    Forward:  [b, h, w, c] => [b, h // 2, w // 2, c * 4]  (space_to_depth)
    Reverse:  the opposite rearrangement                  (depth_to_space)
    """
    block_size = 2
    if reverse:
        return tf.nn.depth_to_space(tf_in, block_size)
    return tf.nn.space_to_depth(tf_in, block_size)
        
def preprocess_layer(tf_in, reverse=False):
    """Map pixel values to logit space (forward) or back to pixels (reverse).

    Forward:
        x   = alpha + (1 - 2 * alpha) * tf_in / max_pixel_value
        out = log(x / (1 - x) + 1e-10)
    With inputs in [0, max_pixel_value] the intermediate ``x`` lies in
    [alpha, 1 - alpha], which keeps the logit finite.

    Args:
        tf_in: 4-D tensor whose dimensions 1..3 are statically known.
        reverse: when True, apply sigmoid and undo the affine rescaling.

    Returns:
        Forward: ``(tf_out, log_det)`` where ``log_det`` holds one value per
        example: the summed log-derivative of the logit plus the constant
        contribution of the affine rescaling.
        Reverse: the reconstructed pixel tensor only.
    """
    alpha = 0.05
    max_pixel_value = 256
    dims = int_shape(tf_in)
    DIM = dims[1] * dims[2] * dims[3]  # number of elements per example

    if reverse:
        # sigmoid inverts the logit, then the affine rescaling is undone
        squashed = tf.math.sigmoid(tf_in)
        return (squashed - alpha) * max_pixel_value / (1.0 - 2.0 * alpha)

    # affine rescaling into [alpha, 1 - alpha]
    tf_x = (alpha + (1.0 - 2.0*alpha) * (tf_in) / max_pixel_value)
    # logit
    tf_out = tf.log(tf_x / (1 - tf_x) + 1e-10)

    # d logit(x) / dx = 1/x + 1/(1 - x), summed in log space per example
    per_element = tf.log((1.0 / tf_x) + (1.0 / (1.0 - tf_x)))
    logit_log_det = tf.reduce_sum(tf.reshape(per_element, (-1, DIM)), axis=-1)
    # each element is scaled by the same constant in the affine step
    affine_log_det = np.log((1.0 - 2.0*alpha) / max_pixel_value) * DIM

    return tf_out, logit_log_det + affine_log_det


In [3]:
# copy pasted from samathas code for coupling
def default_initializer(std=0.05):
    """Return a zero-mean normal initializer with standard deviation ``std``."""
    return tf.random_normal_initializer(mean=0., stddev=std)

def flatten_sum(logps):
    """Sum ``logps`` over every axis except the first.

    Only rank-2 and rank-4 inputs are supported; any other rank raises a bare
    ``Exception``.
    """
    rank = len(logps.get_shape())
    if rank not in (2, 4):
        raise Exception()
    axes = [1] if rank == 2 else [1, 2, 3]
    return tf.reduce_sum(logps, axes)
        
def int_shape(x):
    """Return the static shape of ``x`` as a list of Python ints.

    The leading dimension is reported as -1 when its string form is '?';
    every other dimension is converted with ``int``.
    """
    dims = x.get_shape()
    if str(dims[0]) == '?':
        return [-1] + [int(d) for d in dims[1:]]
    return [int(d) for d in dims]

def add_edge_padding(x, filter_size):
    """Zero-pad ``x`` spatially and append a channel that marks the padding.

    The input is padded by ``(filter_size[i] - 1) // 2`` on both sides of the
    height and width axes. An extra channel is then concatenated whose value
    is 1 on the padded border and 0 elsewhere. The border mask is cached in a
    graph collection keyed by the pad sizes and the padded height/width, so it
    is built once per distinct geometry.

    Args:
        x: 4-D tensor with static height and width; axis 3 is the channel axis.
        filter_size: ``[height, width]`` of the convolution kernel;
            ``filter_size[0]`` must be odd.

    Returns:
        ``x`` unchanged for a 1x1 filter, otherwise the padded tensor with one
        additional channel.
    """
    assert filter_size[0] % 2 == 1
    if filter_size[0] == 1 and filter_size[1] == 1:
        return x
    a = (filter_size[0] - 1) // 2  # vertical padding size
    b = (filter_size[1] - 1) // 2  # horizontal padding size

    x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
    padded_hw = int_shape(x)[1:3]
    name = "_".join([str(dim) for dim in [a, b, *padded_hw]])
    pads = tf.get_collection(name)
    if pads:
        pad = pads[0]
    else:
        pad = np.zeros([1] + padded_hw + [1], dtype='float32')
        # A pad size of 0 must leave that axis unmarked: the slice "-0:" is
        # the same as "0:" and would flag the whole axis as border.
        if a > 0:
            pad[:, :a, :, 0] = 1.
            pad[:, -a:, :, 0] = 1.
        if b > 0:
            pad[:, :, :b, 0] = 1.
            pad[:, :, -b:, 0] = 1.
        pad = tf.convert_to_tensor(pad)
        tf.add_to_collection(name, pad)
    # Repeat the single-example mask for every example in the batch.
    pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1])
    return tf.concat([x, pad], axis=3)

def Z_conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, edge_bias=True):
    """2-D convolution with optional edge channel, weight norm and actnorm.

    Args:
        name: variable scope for the layer's variables ("W", and "b" when
            ``do_actnorm`` is False).
        x: 4-D NHWC input tensor with a static channel dimension.
        width: number of output channels.
        filter_size: ``[height, width]`` of the kernel (list or tuple).
        stride: ``[vertical, horizontal]`` stride (list or tuple); must be
            ``[1, 1]`` when ``skip != 1``.
        pad: padding mode. With ``edge_bias`` and "SAME" the input is padded
            explicitly by ``add_edge_padding`` and the convolution runs "VALID".
        do_weightnorm: L2-normalize the kernel over axes 0, 1, 2.
        do_actnorm: apply ``actnorm`` after the convolution instead of adding
            a bias variable.
        context1d: optional conditioning input; when given, a linear
            projection of it is added to every spatial position.
        skip: dilation rate; 1 means an ordinary convolution.
        edge_bias: see ``pad``.

    Returns:
        The output tensor of the layer.
    """
    # Work on list copies so tuples are accepted and the shared default lists
    # can never be modified.
    filter_size = list(filter_size)
    stride = list(stride)

    with tf.variable_scope(name):
        if edge_bias and pad == "SAME":
            x = add_edge_padding(x, filter_size)
            pad = 'VALID'

        n_in = int(x.get_shape()[3])

        stride_shape = [1] + stride + [1]
        filter_shape = filter_size + [n_in, width]
        w = tf.get_variable("W", filter_shape, tf.float32,
                            initializer=default_initializer())
        if do_weightnorm:
            w = tf.nn.l2_normalize(w, [0, 1, 2])
        if skip == 1:
            x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
        else:
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, w, skip, pad)
        if do_actnorm:
            # NOTE(review): `actnorm` is not defined in the visible notebook
            # cells (only `actnorm_layer`, which has a different signature) --
            # confirm it exists before relying on the do_actnorm=True default.
            x = actnorm("actnorm", x)
        else:
            x += tf.get_variable("b", [1, 1, 1, width],
                                 initializer=tf.zeros_initializer())

        # Identity test: "!= None" is evaluated element-wise by array-like
        # objects and then has no single truth value.
        if context1d is not None:
            # NOTE(review): `linear` is not defined in the visible notebook
            # cells -- confirm it exists before passing context1d.
            x += tf.reshape(linear("context", context1d,
                                   width), [-1, 1, 1, width])
    return x

def Z_conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", logscale_factor=3, skip=1, edge_bias=True):
    """Convolution whose kernel, bias and log-scale all start at zero.

    The layer computes ``(conv(x, W) + b) * exp(logs * logscale_factor)`` with
    "W", "b" and "logs" created in the ``name`` scope, each using a zeros
    initializer.

    Args:
        name: variable scope for the layer's variables.
        x: 4-D NHWC input tensor with a static channel dimension.
        width: number of output channels.
        filter_size: ``[height, width]`` of the kernel.
        stride: ``[vertical, horizontal]`` stride; must be ``[1, 1]`` when
            ``skip != 1``.
        pad: padding mode. With ``edge_bias`` and "SAME" the input is padded
            explicitly by ``add_edge_padding`` and the convolution runs "VALID".
        logscale_factor: multiplier applied to "logs" before exponentiation.
        skip: dilation rate; 1 means an ordinary convolution.
        edge_bias: see ``pad``.

    Returns:
        The output tensor of the layer.
    """
    with tf.variable_scope(name):
        explicit_padding = edge_bias and pad == "SAME"
        if explicit_padding:
            x = add_edge_padding(x, filter_size)
            pad = 'VALID'

        in_channels = int(x.get_shape()[3])
        strides_nhwc = [1] + stride + [1]
        kernel_shape = filter_size + [in_channels, width]
        kernel = tf.get_variable("W", kernel_shape, tf.float32,
                                 initializer=tf.zeros_initializer())

        if skip != 1:
            # dilated convolution only supports unit strides here
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, kernel, skip, pad)
        else:
            x = tf.nn.conv2d(x, kernel, strides_nhwc, pad, data_format='NHWC')

        bias = tf.get_variable("b", [1, 1, 1, width],
                               initializer=tf.zeros_initializer())
        x += bias

        log_scale = tf.get_variable("logs", [1, width],
                                    initializer=tf.zeros_initializer())
        x *= tf.exp(log_scale * logscale_factor)
    return x



In [4]:
def coupling_layer(tf_in, WIDTH, name="", reverse=False):
    """Affine coupling layer operating on the channel halves of ``tf_in``.

    The first half of the channels (``z1``) passes through unchanged and feeds
    a small convolutional network (scopes "f1/l_1", "f1/l_2", "f1/l_last")
    that emits an interleaved shift and pre-sigmoid scale for the second half:

        forward:  z2 <- (z2 + shift) * scale
        reverse:  z2 <- z2 / scale - shift

    Args:
        tf_in: 4-D tensor with an even, static number of channels.
        WIDTH: hidden width of the convolutional network.
        name: prefix for the ``coupling_layer`` variable scope.
        reverse: when True, apply the inverse transform.

    Returns:
        Forward: ``(z, log_det)`` where ``log_det`` is ``sum(log(scale))`` over
        axes 1, 2, 3 (one value per example).
        Reverse: the inverted tensor only.
    """
    with tf.variable_scope(name + 'coupling_layer', reuse=tf.AUTO_REUSE):
        dims = int_shape(tf_in)
        channels = dims[3]
        assert channels % 2 == 0, dims
        half = channels // 2
        z1 = tf_in[:, :, :, :half]
        z2 = tf_in[:, :, :, half:]

        # Conditioning network on z1. actnorm is disabled inside it so that no
        # extra initialization ops need to be tracked.
        with tf.variable_scope("f1", reuse=False):
            hidden = tf.nn.relu(Z_conv2d("l_1", z1, WIDTH, do_actnorm=False))
            hidden = tf.nn.relu(Z_conv2d("l_2", hidden, WIDTH, filter_size=[1, 1], do_actnorm=False))
            net_out = Z_conv2d_zeros("l_last", hidden, channels)

        # even output channels carry the shift, odd ones the scale logits
        shift = net_out[:, :, :, 0::2]
        scale = tf.nn.sigmoid(net_out[:, :, :, 1::2] + 2.)

        if reverse:
            restored = z2 / scale - shift
            return tf.concat([z1, restored], 3)

        transformed = (z2 + shift) * scale
        tf_log_jacobian_determinant = tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
        return tf.concat([z1, transformed], 3), tf_log_jacobian_determinant

In [5]:
def invertible_1x1_conv(z, reverse=False):
    """Invertible 1x1 convolution: a learned channel-mixing matrix ``W``.

    ``W`` lives in the "invertable_1x1" scope (``tf.AUTO_REUSE``) and is
    initialized from the Q factor of the QR decomposition of a Gaussian
    matrix drawn with NumPy's global random state.

    Args:
        z: 4-D NHWC tensor with static height, width and channel dimensions.
        reverse: when True, convolve with the matrix inverse of ``W``.

    Returns:
        Forward: ``(out, dlogdet)`` with
        ``dlogdet = log|det W| * height * width``.
        Reverse: the inverted tensor only.
    """
    with tf.variable_scope("invertable_1x1", reuse=tf.AUTO_REUSE):
        dims = int_shape(z)
        channels = dims[3]
        w_shape = [channels, channels]

        # Q factor of a Gaussian sample, used as the initial value of W
        gaussian = np.random.randn(*w_shape)
        w_init = np.linalg.qr(gaussian)[0].astype('float32')
        w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

        # determinant evaluated in float64, result cast back to float32
        det = tf.matrix_determinant(tf.cast(w, 'float64'))
        log_abs_det = tf.cast(tf.log(abs(det)), 'float32')
        dlogdet = log_abs_det * dims[1] * dims[2]

        kernel = tf.matrix_inverse(w) if reverse else w
        kernel = tf.reshape(kernel, [1, 1] + w_shape)
        out = tf.nn.conv2d(z, kernel, [1, 1, 1, 1],
                           'SAME', data_format='NHWC')
        if reverse:
            return out
        return out, dlogdet

In [1]:
def split(z):
    """Split a 4-D tensor into two equal halves along the channel axis.

    Returns ``(z_main, z_other)``: the first and the second half of the
    channels. The channel count must be even.
    """
    channels = int_shape(z)[3]
    assert channels % 2 == 0
    half = channels // 2
    return z[:, :, :, :half], z[:, :, :, half:]

def split_reverse(z1, z2):
    """Undo ``split``: concatenate ``z1`` and ``z2`` along the last axis."""
    halves = [z1, z2]
    return tf.concat(halves, axis=-1)

def get_vectorize_shape(z):
    """Return the number of elements per example: the product of dimensions 1, 2 and 3."""
    dims = int_shape(z)
    return dims[1] * dims[2] * dims[3]