In [122]:
import pickle
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

In [26]:
# Load the homework dataset from a local pickle file.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# files from a trusted source.
with open('hw2_q2.pkl', 'rb') as f:
    data = pickle.load(f)

In [27]:
# The pickle holds a dict with 'train' and 'test' entries; the 'test' split is
# used as the validation set during training below.
data_trn, data_val = data['train'], data['test']
print(data_trn.shape, data_val.shape)

(20000, 32, 32, 3) (6838, 32, 32, 3)


In [129]:
def resnet(layer_in, num_channels, output_dim, num_filters=256, num_blocks=8, scope="resnet"):
    """Residual conv net used as the scale/translation network of a coupling layer.

    Architecture: a 3x3 input conv to `num_filters` channels, then `num_blocks`
    residual blocks (1x1 conv -> ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added back
    onto the trunk and followed by a ReLU), then a 3x3 output conv to `output_dim`
    channels. Every conv is stride 1 with SAME padding, so the spatial size of
    `layer_in` is preserved.

    Args:
        layer_in: 4-D input tensor (conv2d uses NHWC-style strides).
        num_channels: channel count of `layer_in`.
        output_dim: channel count of the returned tensor.
        num_filters: width (channel count) of the residual trunk.
        num_blocks: number of residual blocks.
        scope: variable scope owning all weights of this network.

    Returns:
        Tensor with the batch/spatial dims of `layer_in` and `output_dim` channels.

    TODO: BatchNorm & WeightNormalization.
    """
    with tf.variable_scope(scope):
        h = conv2d(layer_in, scope="conv2d", kernel=(3, 3), stride=(1, 1),
                   in_channels=num_channels, out_channels=num_filters)
        for idx in range(num_blocks):
            _h = conv2d(h, scope="conv2d_"+str(idx)+"_0", kernel=(1, 1), stride=(1, 1),
                        in_channels=num_filters, out_channels=num_filters)
            _h = tf.nn.relu(_h)
            _h = conv2d(_h, scope="conv2d_"+str(idx)+"_1", kernel=(3, 3), stride=(1, 1),
                        in_channels=num_filters, out_channels=num_filters)
            _h = tf.nn.relu(_h)
            _h = conv2d(_h, scope="conv2d_"+str(idx)+"_2", kernel=(1, 1), stride=(1, 1),
                        in_channels=num_filters, out_channels=num_filters)
            # Residual connection: add the block output to the trunk, then
            # activate. Applying relu to `_h` alone would discard the skip path.
            h = tf.nn.relu(h + _h)
        layer_out = conv2d(h, scope="resnet_layer_out", kernel=(3, 3), stride=(1, 1),
                           in_channels=num_filters, out_channels=output_dim)
    return layer_out
        
def conv2d(layer_in, scope, kernel, stride, in_channels, out_channels):
    """Bias-free 2-D convolution with SAME padding and a Xavier-initialised filter.

    Args:
        layer_in: 4-D input tensor; strides are applied as [1, sh, sw, 1].
        scope: variable scope that owns the filter variable "weights".
        kernel: (height, width) of the filter.
        stride: (height, width) stride.
        in_channels: channel count of `layer_in`.
        out_channels: channel count of the output.

    Returns:
        The convolution output tensor.
    """
    kh, kw = kernel
    sh, sw = stride
    with tf.variable_scope(scope):
        filt = tf.get_variable(
            "weights",
            shape=[kh, kw, in_channels, out_channels],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
        )
        return tf.nn.conv2d(
            input=layer_in,
            filter=filt,
            strides=[1, sh, sw, 1],
            padding='SAME',
            name='conv2d_layer_out',
        )

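Affine coupling layer (Real NVP). With a binary mask $b$ and scale/translation networks $s$ and $t$ that only see the masked input, the forward map is

$$
y=b \odot x+(1-b) \odot(x \odot \exp (s(b \odot x))+t(b \odot x))
$$

The entries selected by $b$ pass through unchanged. The inverse is therefore $x=b \odot y+(1-b) \odot\left((y-t(b \odot y)) \odot \exp (-s(b \odot y))\right)$, and the log-determinant is $\log\left|\det \frac{\partial y}{\partial x}\right| = \sum_i \left[(1-b) \odot s(b \odot x)\right]_i$.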

In [130]:
class Layer(object):
    """Abstract interface for an invertible flow layer.

    Subclasses implement the forward map `x2y` (data -> latent), which also
    accumulates the running log-determinant of its Jacobian, and the inverse
    map `y2x` (latent -> data).
    """

    def x2y(self, x, sum_log_jacobian):
        """Forward pass; subclasses return `(y, updated sum_log_jacobian)`."""
        raise NotImplementedError

    def y2x(self, y, z):
        """Inverse pass; subclasses return `(x, z)`."""
        raise NotImplementedError

In [131]:
class CouplingLayer(Layer):
    """Affine coupling layer (Real NVP).

    With a fixed binary mask b, the forward map is
        y = b*x + (1 - b) * (x * exp(log_s(b*x)) + t(b*x)),
    where log_s and t are produced by a ResNet applied to the masked input.
    Entries with b == 1 pass through unchanged, so the inverse is closed-form
    and log|det dy/dx| is the sum of log_s over the entries with b == 0.

    Args:
        scope: variable-scope name owning this layer's ResNet weights; `y2x`
            reuses the weights created by `x2y` under the same scope, so `x2y`
            must be built first.
        mask_type: "checkerboard0"/"checkerboard1" (alternating 2x2 spatial
            pattern) or "channel0"/"channel1" (first/second half of channels).

    Raises:
        ValueError: if `mask_type` is not one of the supported values.
    """

    MASK_TYPES = ("checkerboard0", "checkerboard1", "channel0", "channel1")

    def __init__(self, scope, mask_type):
        if mask_type not in self.MASK_TYPES:
            raise ValueError("unknown mask_type %r; expected one of %s"
                             % (mask_type, ", ".join(self.MASK_TYPES)))
        self.scope = scope
        self.mask_type = mask_type

    def _get_mask(self, shape):
        """Build a float32 {0, 1} mask for an input whose 1-D shape tensor is `shape`.

        Checkerboard masks tile a 2x2 pattern, so H and W must be even; channel
        masks split the last axis in half, so C must be even.
        """
        if self.mask_type == "checkerboard0":
            pattern = [[0.0, 1.0], [1.0, 0.0]]
        elif self.mask_type == "checkerboard1":
            pattern = [[1.0, 0.0], [0.0, 1.0]]
        elif self.mask_type in ("channel0", "channel1"):
            half = [shape[0], shape[1], shape[2], shape[3] // 2]
            ones, zeros = tf.ones(half), tf.zeros(half)
            parts = [ones, zeros] if self.mask_type == "channel0" else [zeros, ones]
            return tf.concat(parts, axis=-1, name="mask_" + self.mask_type)
        else:
            raise ValueError("unknown mask_type %r" % (self.mask_type,))
        mask = tf.constant(pattern, dtype=tf.float32)
        mask = tf.reshape(mask, [1, 2, 2, 1], name="mask_" + self.mask_type)
        # Repeat the 2x2 pattern over the batch, H/2 x W/2 positions and all channels.
        return tf.tile(mask, [shape[0], shape[1] // 2, shape[2] // 2, shape[3]])

    def _build_log_s_t(self, masked_in, num_channels, output_dim, scope="_build_log_s_t"):
        """Run the ResNet on `masked_in` and split its output channels into (log_s, t)."""
        with tf.variable_scope(scope):
            resnet_out = resnet(masked_in, num_channels, output_dim)
            log_s, t = tf.split(resnet_out, 2, axis=-1)
        return log_s, t

    def x2y(self, x, sum_log_jacobian):
        """Forward (data -> latent) pass.

        Args:
            x: 4-D input tensor with a statically known channel count.
            sum_log_jacobian: running per-example log-determinant, shape (batch,).

        Returns:
            (y, sum_log_jacobian + this layer's log|det dy/dx|).
        """
        with tf.variable_scope(self.scope):
            mask = self._get_mask(tf.shape(x))
            masked_x = mask * x
            num_channels = int(x.get_shape()[-1])
            log_s, t = self._build_log_s_t(masked_x, num_channels, num_channels * 2)
            # Only un-masked entries are transformed: zero their log-scale on the
            # masked entries so the log-det excludes them and exp() there can't
            # overflow into inf * 0 = NaN.
            # NOTE(review): log_s is unbounded; Real NVP bounds it (e.g. tanh) for stability.
            log_s = (1.0 - mask) * log_s
            s = tf.check_numerics(tf.exp(log_s), "exp has NaN")
            y = masked_x + (1.0 - mask) * (x * s + t)
            sum_log_jacobian += tf.reduce_sum(log_s, axis=[1, 2, 3])
        return y, sum_log_jacobian

    def y2x(self, y, z):
        """Inverse (latent -> data) pass; reuses the weights built by `x2y`.

        Args:
            y: 4-D latent tensor with a statically known channel count.
            z: passed through unchanged.

        Returns:
            (x, z) with x = b*y + (1 - b) * (y - t) * exp(-log_s).
        """
        with tf.variable_scope(self.scope, reuse=True):
            mask = self._get_mask(tf.shape(y))
            masked_y = mask * y  # equals b*x, since masked entries were untouched
            num_channels = int(y.get_shape()[-1])
            log_s, t = self._build_log_s_t(masked_y, num_channels, num_channels * 2)
            log_s = (1.0 - mask) * log_s
            inv_s = tf.check_numerics(tf.exp(-log_s), "exp has NaN")
            x = masked_y + (1.0 - mask) * (y - t) * inv_s
        return x, z

In [132]:
class SqueezingLayer(Layer):
    """Squeeze: trade spatial resolution for channels.

    Every 2x2 spatial patch is moved into the channel axis (space_to_depth with
    block size 2), so (H, W, C) becomes (H/2, W/2, 4C). The op only rearranges
    entries, so it adds nothing to the log-determinant.
    """

    _BLOCK_SIZE = 2

    def __init__(self):
        """No parameters; the layer has no weights."""
        pass

    def x2y(self, x, sum_log_det_jacobians):
        """Squeeze `x`; the running log-det sum is returned untouched."""
        return tf.space_to_depth(x, self._BLOCK_SIZE), sum_log_det_jacobians

    def y2x(self, y, z):
        """Unsqueeze `y`, and `z` as well when one is given."""
        x = tf.depth_to_space(y, self._BLOCK_SIZE)
        z = None if z is None else tf.depth_to_space(z, self._BLOCK_SIZE)
        return x, z

In [None]:
class RealNVP():
    """Real NVP normalizing flow over (H, W, C) inputs, built as a TF1 graph.

    The forward pass maps a batch x through `self.layers` to a latent y that is
    scored under a standard normal. Training minimises the mean negative
    log-likelihood in nats per example:
        loss = -mean_batch[ log N(y; 0, I) + sum log|det dy/dx| ].

    Args:
        sess: `tf.Session` used by `step` and `sample`.
        input_shape: shape (H, W, C) of one example. Two 2x squeezes followed by
            2x2 checkerboard masks at the coarsest scale mean H and W should be
            divisible by 8.
        learning_rate: Adam learning rate.

    NOTE(review): inputs are fed to the flow as-is (no dequantisation, scaling or
    logit transform happens here) — confirm the data is preprocessed upstream.
    """

    def __init__(self, sess, input_shape=(32, 32, 3), learning_rate=1e-4):
        self.sess = sess
        self.input_shape = tuple(input_shape)
        # The sampling graph is built lazily on the first `sample` call.
        self.sample_out = None
        self._build_ph()
        self._build_layers()
        self._build_loss()
        self._build_op(learning_rate)

    def _build_ph(self):
        """Placeholders: `x` for data batches, `z` for image-shaped N(0, I) noise."""
        self.x = tf.placeholder(tf.float32, (None,) + self.input_shape, name="x")
        self.z = tf.placeholder(tf.float32, (None,) + self.input_shape, name="z")

    def _build_layers(self):
        """Multi-scale stack: couplings at full resolution, then after each squeeze."""
        self.layers = []
        self.layers.extend([CouplingLayer(
            "0_checkerboard_"+str(idx)+"_1", "checkerboard" + str(idx % 2)
        ) for idx in range(4)])
        self.layers.append(SqueezingLayer())
        self.layers.extend([CouplingLayer(
            "1_channel" + str(idx) + "_2", "channel" + str(idx % 2)
        ) for idx in range(3)])
        self.layers.extend([CouplingLayer(
            "2_checkerboard_"+str(idx)+"_1", "checkerboard" + str(idx % 2)
        ) for idx in range(3)])
        self.layers.append(SqueezingLayer())
        self.layers.extend([CouplingLayer(
            "3_channel" + str(idx) + "_2", "channel" + str(idx % 2)
        ) for idx in range(3)])
        self.layers.extend([CouplingLayer(
            "4_checkerboard_"+str(idx)+"_1", "checkerboard" + str(idx % 2)
        ) for idx in range(3)])

    def _build_loss(self):
        """Build the forward pass, per-example log-likelihood and the NLL loss."""
        y = self.x
        # One zero per example, so the running log-det sum is always a (batch,)
        # tensor regardless of which layer comes first.
        sum_log_jacobian = tf.zeros(tf.shape(self.x)[:1], dtype=tf.float32)
        for layer in self.layers:
            y, sum_log_jacobian = layer.x2y(y, sum_log_jacobian)
        self.y = y

        # Standard-normal prior scored element-wise in log space and summed over
        # H, W, C. (A 2-D MultivariateNormalDiag cannot score the 4-D latent,
        # and log(prob(.)) underflows for thousands of dimensions.)
        base_dist = tfp.distributions.Normal(loc=0.0, scale=1.0)
        log_prob_y = tf.reduce_sum(base_dist.log_prob(y), axis=[1, 2, 3])

        self.log_likelihood = log_prob_y + sum_log_jacobian
        # The optimiser minimises, so negate to *maximise* the likelihood.
        self.loss = -tf.reduce_mean(self.log_likelihood)

    def _build_op(self, learning_rate):
        """Adam step minimising `self.loss`."""
        self.op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)

    def _build_sample(self):
        """Map noise fed to `self.z` back through the inverted flow to data space."""
        # `self.z` is image-shaped. Move it into the latent layout by applying
        # only the squeezes, which rearrange entries without rescaling them, so
        # i.i.d. N(0, I) noise stays i.i.d. N(0, I).
        y = self.z
        for layer in self.layers:
            if isinstance(layer, SqueezingLayer):
                y, _ = layer.x2y(y, None)
        z = None
        for layer in reversed(self.layers):
            y, z = layer.y2x(y, z)
        return y

    def step(self, batch, with_update=False):
        """Return the batch loss; also apply one Adam update when `with_update`."""
        if with_update:
            loss, _ = self.sess.run([self.loss, self.op], feed_dict={self.x: batch})
        else:
            loss = self.sess.run(self.loss, feed_dict={self.x: batch})
        return loss

    def sample(self, batch):
        """Decode a batch of image-shaped standard-normal noise into samples."""
        if self.sample_out is None:
            self.sample_out = self._build_sample()
        return self.sess.run(self.sample_out, feed_dict={self.z: batch})

In [None]:
def train(sess, data_trn, data_val, batch_size=256, num_epochs=2,
          log_per_epoch=1, print_per_epoch=10, seed=None):
    """Build a RealNVP in `sess`, initialise its variables and train it with Adam.

    Args:
        sess: `tf.Session` in which the model is built and run.
        data_trn: training array fed batch-wise to the model's `x` placeholder.
        data_val: validation array, evaluated batch-wise without updates.
        batch_size: examples per step, for both training and validation.
        num_epochs: number of passes over `data_trn`.
        log_per_epoch: record losses every this many epochs.
        print_per_epoch: print the recorded losses every this many epochs
            (only on epochs that are also recorded).
        seed: seed for the per-epoch shuffle; None for a nondeterministic order.

    Returns:
        (loss_trn, loss_val, model): mean training loss and validation loss for
        each recorded epoch, and the trained model.

    NOTE: validation batches are combined with a size-weighted average, which
    assumes `model.step` returns the *mean* per-example loss of a batch.
    """
    model = RealNVP(sess)
    sess.run(tf.initializers.global_variables())
    rng = np.random.RandomState(seed)

    def _num_sections(n):
        # array_split needs at least one section, even for an empty array.
        return max(1, int(np.ceil(n / float(batch_size))))

    loss_trn = []
    loss_val = []
    for epoch in range(num_epochs):
        # Fresh random order every epoch; index per batch to avoid copying the
        # whole training array.
        order = rng.permutation(len(data_trn))
        batch_losses = [model.step(data_trn[idx], with_update=True)
                        for idx in np.array_split(order, _num_sections(len(data_trn)))]

        if epoch % log_per_epoch == 0:
            loss_trn.append(np.mean(batch_losses))
            # Evaluate in batches: pushing the full validation set through the
            # model in a single run can exhaust memory.
            val_batches = np.array_split(data_val, _num_sections(len(data_val)))
            weighted = sum(model.step(b) * len(b) for b in val_batches)
            loss_val.append(weighted / max(1, len(data_val)))
            # Print only on recorded epochs, so the values shown are current.
            if epoch % print_per_epoch == 0:
                print("at epoch", epoch, loss_trn[-1], loss_val[-1])
    return loss_trn, loss_val, model

In [None]:
# Start from an empty graph so re-running this cell does not hit variable-name
# collisions, then train with train()'s default hyper-parameters.
tf.reset_default_graph()
sess = tf.Session()
loss_trn, loss_val, network = train(sess, data_trn, data_val)