In [20]:
import numpy as np
import tensorflow as tf
import pickle
from tqdm import trange

tf.set_random_seed(123)
np.random.seed(123)

### Data Preparation

In [8]:
with open('mnist-hw1.pkl', 'rb') as f:
    data = pickle.load(f)

In [17]:
data_trn, data_val = data['train'], data['test']
print(data_trn.shape, data_val.shape)

(60000, 28, 28, 3) (10000, 28, 28, 3)


### Utils

In [48]:
def binarize(images):
    return (np.random.uniform(size=images.shape)*3 < images).astype('float32')

In [86]:
def conv2d(
    layer_in,
    output_dim,
    kernel_shape, # [kernel_height, kernel_width]
    mask_type, # None, "A" or "B",
    scope, 
    strides=[1, 1], # [column_wise_stride, row_wise_stride]
    activation_fn=None):
    with tf.variable_scope(scope):
        mask_type = mask_type.lower()
        batch_size, height, width, channel = layer_in.get_shape().as_list()
        print("building within scope", scope, batch_size, height, width, channel)
        kernel_h, kernel_w = kernel_shape
        stride_h, stride_w = strides

        assert kernel_h % 2 == 1 and kernel_w % 2 == 1

        center_h = kernel_h // 2
        center_w = kernel_w // 2

        weights = tf.get_variable("weights", [kernel_h, kernel_w, channel, output_dim],
                                  tf.float32, tf.contrib.layers.xavier_initializer())

        if mask_type is not None:
            mask = np.ones((kernel_h, kernel_w, channel, output_dim), dtype=np.float32)

            mask[center_h, center_w+1:, :, :] = 0.
            mask[center_h+1:, :, :, :] = 0.

            if mask_type == 'a':
                mask[center_h, center_w, :, :] = 0.

            weights.assign(weights * tf.constant(mask, dtype=tf.float32))

        layer_out = tf.nn.conv2d(input=layer_in, filter=weights, strides=[1, stride_h, stride_w, 1], 
                                 padding='SAME', name='layer_in_at_weights')
        biases = tf.get_variable("biases", [output_dim,], tf.float32, tf.zeros_initializer())
        layer_out = tf.nn.bias_add(layer_out, biases, name='layer_in_at_weights_plus_biases')

        if activation_fn is not None:
            layer_out = activation_fn(layer_out, name='layer_out_activated')

    return layer_out

### Architecture

In [117]:
class PixelCNN():
    def __init__(self, sess, color_dim=4, hidden_dim=16, out_hidden_dim=32, recurrent_length=12, out_recurrent_length=2, 
                 input_shape=[28, 28, 3], learning_rate=1e-3, grad_clip=1):

        self.sess = sess
        self.height, self.width, self.num_channels = input_shape
        self.input = tf.placeholder(tf.float32, [None] + input_shape, name="input")
#         self.input_one_hot = tf.one_hot(tf.cast(self.input, tf.int32), color_dim, axis=-1)
#         self.input_one_hot_flatten = tf.reshape(self.input_one_hot, [-1, self.height*self.width, color_dim]) #???

        '''
        build layers
        '''
        nn = conv2d(self.input, output_dim=hidden_dim, kernel_shape=[7, 7], mask_type="A", scope="conv_in")
        self.hidden_layers = [nn]
        for idx in range(recurrent_length):
            nn = conv2d(nn, output_dim=3, kernel_shape=[1, 1], mask_type="B", scope="conv_hidden"+str(idx))
            self.hidden_layers.append(nn)
        
        self.output_layers = []
        for idx in range(out_recurrent_length):
            nn = conv2d(nn, output_dim=out_hidden_dim, kernel_shape=[1, 1], mask_type="B", scope="conv_out"+str(idx))
            nn = tf.nn.relu(nn)
            self.output_layers.append(nn)

#         self.logits = conv2d(nn, output_dim=color_dim, kernel_shape=[1, 1], mask_type="B", scope="conv_logits")
#         self.input_flattened = tf.reshape(self.input, [-1, self.height*self.width, color_dim])
#         print("input_flattened has shape", tf.shape(self.input_flattened))
#         target_pixels = [tf.squeeze(pixel, axis=[1]) 
#                          for pixel in tf.split(self.input_flattened, self.height*self.width, 1)]
#         print("target_pixel has shape", tf.shape(target_pixels[0]))
        
#         self.logits_flattened = tf.reshape(self.logits, [-1, self.height*self.width, color_dim])
#         print("logits_flattened has shape", tf.shape(self.logits_flattened))
#         pred_pixels = [tf.squeeze(pixel, axis=[1]) 
#                        for pixel in tf.split(self.logits_flattened, self.height*self.width, 1)]
#         print("pred_pixel has shape", tf.shape(pred_pixels[0]))
#         losses = [tf.nn.sampled_softmax_loss(pred_pixel, tf.zeros_like(pred_pixel), pred_pixel, target_pixel, 1, color_dim)
#                   for pred_pixel, target_pixel in zip(pred_pixels, target_pixels)]
        '''
        compute loss
        
        self.input = [None, 28, 28, 3]
        self.input_flattened = [None, 28*28*3, 1]
        target_pixel = [None, 1, 1] -> [None, 1]
        
        self.logits = [None, 28, 28, 3*4]
        self.logits_flattened = [None, 28*28*3, 4]
        pred_pixel = [None, 1, 4] -> [None, 4]
        
        num_classes = color_dim = 4, dim = ?, num_true = 1
        weights = [4, ?]
        biases = [4,]
        labels = [None, 1]
        inputs = [None, ?]
        '''
        dim = 8
        self.logits = conv2d(nn, output_dim=self.num_channels*dim, kernel_shape=[1, 1], mask_type="B", scope="conv_logits")
        
        self.input_flattened = tf.expand_dims(tf.squeeze(self.input), 1)
        self.logits_flattened = tf.reshape(self.logits, [-1, dim])
        self.loss = tf.nn.sampled_softmax_loss(
            weights=tf.Variable(tf.random.normal([color_dim, dim]), name='softmax_weights'), 
            biases=tf.zeros([color_dim,]), 
            labels=self.input_flattened, 
            inputs=self.logits_flattened, 
            num_sampled=1, 
            num_classes=color_dim, 
            name="sampled_softmax_loss")
        

#         self.input_flattened = tf.reshape(self.input, [-1, self.height*self.width*self.num_channels, 1])
#         target_pixels = [tf.squeeze(pixel, axis=[1]) 
#                          for pixel in tf.split(self.input_flattened, self.height*self.width*self.num_channels, 1)]
#         print("target_pixel has shape", tf.shape(target_pixels[0]))
        
#         self.logits_flattened = tf.reshape(self.logits, [-1, self.height*self.width*self.num_channels, color_dim])
#         pred_pixels = [tf.squeeze(pixel, axis=[1]) 
#                        for pixel in tf.split(self.logits_flattened, self.height*self.width, 1)]

#         self.w = tf.Variable()
#         losses = [tf.nn.sampled_softmax_loss(weights=pred_pixel, biases=tf.zeros_like(pred_pixel), 
#                                              labels=target_pixel, inputs=pred_pixel, num_sampled=1, 
#                                              num_classes=color_dim, partition_strategy='mod', 
#                                              name="sampled_softmax_loss") \
#                   for pred_pixel, target_pixel in zip(pred_pixels, target_pixels)]

#         self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, 
#                                                                            labels=self.input_one_hot, name='loss'))
        optimizer = tf.train.RMSPropOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(self.loss)

        new_grads_and_vars = \
            [(tf.clip_by_value(gv[0], -grad_clip, grad_clip), gv[1]) for gv in grads_and_vars]
        self.op = optimizer.apply_gradients(new_grads_and_vars)
        
    def step(self, batch, with_update=False):
        if with_update:
            _, loss = self.sess.run([self.op, self.loss], feed_dict={self.input: batch})
        else:
            loss = self.sess.run(self.loss, feed_dict={self.input: batch})
        return loss

In [123]:
def train(data_trn, data_val, batch_size=128, num_epochs=50, log_per_epoch=1):
    with tf.Session() as sess:
        sess.run(tf.initializers.global_variables())

        network = PixelCNN(sess)

        iterator = trange(num_epochs, ncols=70, initial=0)
        loss_trn = []
        loss_val = []

        for epoch in iterator:
            loss_trn_batch = []
            for batch in np.array_split(data_trn, np.ceil(len(data_trn)/batch_size)):
                loss = network.step(batch, with_update=True)
                loss_trn_batch.append(loss)

            if epoch % log_per_epoch == 0:
                loss_trn.append(np.mean(loss_trn_batch))
                loss_val.append(network.step(data_val, with_update=False))
    return loss_trn, loss_val

### Training

In [124]:
tf.reset_default_graph()
loss_trn, loss_val = train(data_trn, data_val)

building within scope conv_in None 28 28 3
building within scope conv_hidden0 None 28 28 16
building within scope conv_hidden1 None 28 28 3
building within scope conv_hidden2 None 28 28 3
building within scope conv_hidden3 None 28 28 3
building within scope conv_hidden4 None 28 28 3
building within scope conv_hidden5 None 28 28 3
building within scope conv_hidden6 None 28 28 3
building within scope conv_hidden7 None 28 28 3
building within scope conv_hidden8 None 28 28 3
building within scope conv_hidden9 None 28 28 3
building within scope conv_hidden10 None 28 28 3
building within scope conv_hidden11 None 28 28 3
building within scope conv_out0 None 28 28 3
building within scope conv_out1 None 28 28 32
building within scope conv_logits None 28 28 32






  0%|                                          | 0/50 [00:00<?, ?it/s][A[A[A[A

FailedPreconditionError: Attempting to use uninitialized value conv_logits/biases
	 [[node conv_logits/biases/read (defined at <ipython-input-86-5754a020ffc7>:37) ]]

Caused by op 'conv_logits/biases/read', defined at:
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 505, in start
    self.io_loop.start()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/asyncio/base_events.py", line 438, in run_forever
    self._run_once()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/asyncio/base_events.py", line 1451, in _run_once
    handle._run()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/gen.py", line 1233, in inner
    self.run()
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/gen.py", line 1147, in run
    yielded = self.gen.send(value)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 357, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 267, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 534, in execute_request
    user_expressions, allow_stdin,
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2819, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2845, in _run_cell
    return runner(coro)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/async_helpers.py", line 67, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3020, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3185, in run_ast_nodes
    if (yield from self.run_code(code, result)):
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3267, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-124-f9a2799cd377>", line 2, in <module>
    loss_trn, loss_val = train(data_trn, data_val)
  File "<ipython-input-123-a58429fc1ebb>", line 5, in train
    network = PixelCNN(sess)
  File "<ipython-input-117-21750c7cab3e>", line 58, in __init__
    self.logits = conv2d(nn, output_dim=self.num_channels*dim, kernel_shape=[1, 1], mask_type="B", scope="conv_logits")
  File "<ipython-input-86-5754a020ffc7>", line 37, in conv2d
    biases = tf.get_variable("biases", [output_dim,], tf.float32, tf.zeros_initializer())
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1479, in get_variable
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1220, in get_variable
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 547, in get_variable
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 499, in _true_getter
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 911, in _get_single_variable
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 213, in __call__
    return cls._variable_v1_call(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 176, in _variable_v1_call
    aggregation=aggregation)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 155, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 2495, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 217, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1395, in __init__
    constraint=constraint)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1557, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 81, in identity
    ret = gen_array_ops.identity(input, name=name)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3890, in identity
    "Identity", input=input, name=name)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/Users/ZhangYunzhi/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value conv_logits/biases
	 [[node conv_logits/biases/read (defined at <ipython-input-86-5754a020ffc7>:37) ]]


In [53]:
data_trn[0]

array([[[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       [[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 3],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]],

       ...,

       [[3, 2, 2],
        [3, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 3],
        [3, 2, 3],
        [3, 2, 3]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 3],
        [3, 2, 3]],

       [[2, 2, 2],
        [2, 2, 2],
        [3, 2, 2],
        ...,
        [3, 2, 2],
        [3, 2, 2],
        [3, 2, 2]]], dtype=uint8)