In [20]:
import numpy as np
from tqdm import trange
import tensorflow as tf

# from utils import *
# from network import Network
# from statistic import Statistic

tf.set_random_seed(123)
np.random.seed(123)

In [2]:
import pickle

In [8]:
with open('mnist-hw1.pkl', 'rb') as f:
    data = pickle.load(f)

In [17]:
data_trn, data_val = data['train'], data['test']
print(data_trn.shape, data_val.shape)

(60000, 28, 28, 3) (10000, 28, 28, 3)


In [46]:
def conv2d(
    layer_in,
    output_dim,
    kernel_shape, # [kernel_height, kernel_width]
    mask_type, # None, "A" or "B",
    scope, 
    strides=[1, 1], # [column_wise_stride, row_wise_stride]
    activation_fn=None,
    weights_initializer=tf.contrib.layers.xavier_initializer(),
    weights_regularizer=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        mask_type = mask_type.lower()
        batch_size, height, width, channel = layer_in.get_shape().as_list()
        kernel_h, kernel_w = kernel_shape
        stride_h, stride_w = strides

        assert kernel_h % 2 == 1 and kernel_w % 2 == 1

        center_h = kernel_h // 2
        center_w = kernel_w // 2

        weights = tf.get_variable("weights", [kernel_h, kernel_w, channel, output_dim],
                                  tf.float32, weights_initializer, weights_regularizer)

        if mask_type is not None:
            mask = np.ones((kernel_h, kernel_w, channel, output_dim), dtype=np.float32)

            mask[center_h, center_w+1: ,: ,:] = 0.
            mask[center_h+1:, :, :, :] = 0.

            if mask_type == 'a':
                mask[center_h,center_w,:,:] = 0.

            weights *= tf.constant(mask, dtype=tf.float32)

        layer_out = tf.nn.conv2d(layer_in, weights, [1, stride_h, stride_w, 1], 
                                 padding='SAME', name='layer_in_at_weights')
        bias = tf.get_variable("bias", [output_dim,], tf.float32, tf.zeros_initializer())
        layer_out = tf.nn.bias_add(layer_out, bias, name='layer_in_at_weights_plut_bias')

        if activation_fn is not None:
            layer_out = activation_fn(layer_out, name='layer_out_activated')

    return layer_out

In [47]:
class Network():
    def __init__(self, sess, hidden_dim=16, out_hidden_dim=32, recurrent_length=7, out_recurrent_length=2, 
                 input_shape=[28, 28, 3], learning_rate=1e-3, grad_clip=1):

        self.sess = sess
        self.input = tf.placeholder(tf.float32, [None] + input_shape, name="input")

        # input of main reccurent layers
        nn = conv2d(self.input, output_dim=hidden_dim, kernel_shape=[7, 7], mask_type="A", scope="conv_in")
        self.hidden_layers = [nn]
        for idx in range(recurrent_length):
            nn = conv2d(nn, output_dim=3, kernel_shape=[1, 1], mask_type="B", scope="conv_hidden"+str(idx))
            self.hidden_layers.append(nn)
        
        self.output_layers = []
        for idx in range(out_recurrent_length):
            nn = conv2d(nn, output_dim=out_hidden_dim, kernel_shape=[1, 1], mask_type="B", scope="conv_out"+str(idx))
            nn = tf.nn.relu(nn)
            self.output_layers.append(nn)

        self.logits = conv2d(nn, output_dim=1, kernel_shape=[1, 1], mask_type="B", scope="conv_logits")
        self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, 
                                                                           labels=self.input, name='loss'))


        COLOR_DIM = 4
        self.logits = conv2d(nn, output_dim=COLOR_DIM, kernel_shape=[1, 1], mask_type="B", scope="conv_logits")
        self.l
        
#       COLOR_DIM = 4

#       self.l['conv2d_out_logits'] = conv2d(l_hid, COLOR_DIM, [1, 1], "B", scope='conv2d_out_logits')

#       self.l['conv2d_out_logits_flat'] = tf.reshape(
#           self.l['conv2d_out_logits'], [-1, self.height * self.width, COLOR_DIM])
#       self.l['normalized_inputs_flat'] = tf.reshape(
#           self.l['normalized_inputs'], [-1, self.height * self.width, COLOR_DIM])

#       # FIXED pre-1.0 # pred_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
#       pred_pixels = [tf.squeeze(pixel, axis=[1])
#           # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['conv2d_out_logits_flat'])]
#           for pixel in tf.split(self.l['conv2d_out_logits_flat'], self.height * self.width, 1)]
#       # FIXED pre-1.0 # target_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
#       target_pixels = [tf.squeeze(pixel, axis=[1])
#           # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['normalized_inputs_flat'])]
#           for pixel in tf.split(self.l['normalized_inputs_flat'], self.height * self.width, 1)]

#       softmaxed_pixels = [tf.nn.softmax(pixel) for pixel in pred_pixels]

#       losses = [tf.nn.sampled_softmax_loss(
#           pred_pixel, tf.zeros_like(pred_pixel), pred_pixel, target_pixel, 1, COLOR_DIM) \
#               for pred_pixel, target_pixel in zip(pred_pixels, target_pixels)]

#       self.l['output'] = tf.nn.softmax(self.l['conv2d_out_logits'])

#       logger.info("Building loss and optims")
#       # FIXED pre-1.0
#       # self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
#       #     self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss'))
#       self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
#           logits=self.l['conv2d_out_logits'], labels=self.l['normalized_inputs'], name='loss'))

        optimizer = tf.train.RMSPropOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(self.loss)

        new_grads_and_vars = \
            [(tf.clip_by_value(gv[0], -grad_clip, grad_clip), gv[1]) for gv in grads_and_vars]
        self.op = optimizer.apply_gradients(new_grads_and_vars)

#   def predict(self, images):
#     return self.sess.run(self.l['output'], {self.l['inputs']: images})

    def step(self, batch, with_update=False):
        if with_update:
            _, loss = self.sess.run([self.op, self.loss], feed_dict={self.input: batch})
        else:
            loss = self.sess.run(self.loss, feed_dict={self.input: batch})
        return loss

#   def generate(self):
#     samples = np.zeros((100, self.height, self.width, 1), dtype='float32')

#     for i in xrange(self.height):
#       for j in xrange(self.width):
#         for k in xrange(self.channel):
#           next_sample = binarize(self.predict(samples))
#           samples[:, i, j, k] = next_sample[:, i, j, k]

#           if self.data == 'mnist':
#             print "=" * (self.width/2), "(%2d, %2d)" % (i, j), "=" * (self.width/2)
#             mprint(next_sample[0,:,:,:])

#     return samples

In [48]:
def binarize(images):
    return (np.random.uniform(size=images.shape)*3 < images).astype('float32')

In [49]:
def train(data_trn, data_val, batch_size=50, num_epochs=1000, log_per_epoch=50):
    with tf.Session() as sess:
        network = Network(sess)

        iterator = trange(num_epochs, ncols=70, initial=0)
        loss_trn = []
        loss_val = []

        for epoch in iterator:
            loss_trn_batch = []
            for batch in np.array_split(data_trn, np.ceil(len(data_trn)/batch_size)):
                loss = network.step(binarize(batch), with_update=True)
                loss_trn_batch.append(loss)

            if epoch % log_per_epoch == 0:
                loss_trn.append(np.mean(loss_trn_batch))
                loss_val.append(network.step(data_val, with_update=False))
    return loss_trn, loss_val

In [50]:
loss_trn, loss_val = train(data_trn, data_val)



ValueError: logits and labels must have the same shape ((?, 28, 28, 1) vs (?, 28, 28, 3))

In [53]:
data_trn[0]

array([[[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       [[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 3],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       ...,

       [[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 3],
        [3, 2, 3],
        [3, 2, 3]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 3],
        [3, 2, 3]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]]], dtype=uint8)