In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import scipy.misc
import os

slim = tf.contrib.slim

In [3]:
# Load the MNIST data sets from the 'MNIST_data' directory with one-hot labels.
mnist_data_dir = 'MNIST_data'
mnist = input_data.read_data_sets(mnist_data_dir, one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
def conv2d_A(inputs,
             num_outputs,
             kernel_size,
             strides=(1, 1),
             padding='SAME',
             activation_fn=None,
             weights_regularizer=None,
             biases_regularizer=None,
             scope="conv2d_A"):
    """2-D convolution with a spatially masked kernel, plus bias and optional activation.

    The kernel is multiplied by a constant 0/1 mask that zeroes the center
    position, every position to the right of the center on the center row, and
    every row below the center row. (This appears to be the type-'A' mask used
    by PixelCNN-style models.)

    Args:
        inputs: rank-4 tensor; its last dimension is used as the number of
            input filters.
        num_outputs: number of output filters.
        kernel_size: pair (kernel_h, kernel_w).
        strides: pair (stride_h, stride_w).
        padding: padding mode forwarded to tf.nn.conv2d.
        activation_fn: optional callable applied to the biased output.
        weights_regularizer: optional regularizer for the 'weights' variable.
        biases_regularizer: optional regularizer for the 'biases' variable.
        scope: variable scope under which 'weights' and 'biases' are created.

    Returns:
        The convolution output after bias addition and, if given, activation_fn.
    """
    with tf.variable_scope(scope):
        # Unpacking exactly four values also rejects inputs that are not rank 4.
        _, _, _, num_filters_in = inputs.get_shape().as_list()
        kernel_h, kernel_w = kernel_size
        stride_h, stride_w = strides

        weights_shape = [kernel_h, kernel_w, num_filters_in, num_outputs]
        weights = tf.get_variable('weights', shape=weights_shape,
                                  dtype=tf.float32,
                                  initializer=tf.contrib.layers.xavier_initializer(),
                                  regularizer=weights_regularizer)

        center_h = kernel_h // 2
        center_w = kernel_w // 2
        mask = np.ones(weights_shape, dtype=np.float32)
        # Center row: drop the center position and everything to its right.
        mask[center_h, center_w:, :, :] = 0.
        # Drop every row below the center row.
        mask[center_h + 1:, :, :, :] = 0.

        # Keep the variable and the masked tensor under separate names: the
        # variable is what must be registered as a model variable below, the
        # masked product is what the convolution uses.
        masked_weights = weights * mask

        outputs = tf.nn.conv2d(inputs, masked_weights,
                               [1, stride_h, stride_w, 1], padding=padding)

        biases = tf.get_variable('biases', shape=[num_outputs],
                                 dtype=tf.float32,
                                 initializer=tf.zeros_initializer,
                                 regularizer=biases_regularizer)
        outputs = tf.nn.bias_add(outputs, biases)
        if activation_fn is not None:
            outputs = activation_fn(outputs)

        slim.add_model_variable(weights)
        slim.add_model_variable(biases)
        return outputs

In [4]:
# Input geometry used for the placeholder below.
batch_size = 100
height, width = 28, 28
channel = 1

images = tf.placeholder(tf.float32, shape=[batch_size, height, width, channel])

# Inside this scope slim.conv2d defaults to no activation; individual calls may
# still override it (the 'output_reccurent_layers' stack uses ReLU).
with slim.arg_scope([slim.conv2d], activation_fn=None):
    net = conv2d_A(images, 64, [7, 7])
    net = slim.repeat(net, 2, slim.conv2d, 3, [1, 1],
                      scope='main_reccurent_layers')
    net = slim.repeat(net, 2, slim.conv2d, 64, [1, 1],
                      scope='output_reccurent_layers',
                      activation_fn=tf.nn.relu)
    logits = slim.conv2d(net, 1, [1, 1], scope='conv2d_out_logits')

# Sigmoid of the logits, and the mean sigmoid cross-entropy of the logits
# against the input images themselves.
outputs = tf.nn.sigmoid(logits)
elementwise_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits, images)
loss = tf.reduce_mean(elementwise_xent)

In [5]:
# RMSProp (learning rate 1e-3) with every gradient clipped element-wise to
# [-1, 1] before it is applied.
optimizer = tf.train.RMSPropOptimizer(1e-3)
grads_and_vars = optimizer.compute_gradients(loss)

# compute_gradients() pairs a variable with None when the loss does not depend
# on it. tf.clip_by_value cannot take None, so such pairs are passed through
# unchanged and left for apply_gradients to handle.
new_grads_and_vars = [
    (grad if grad is None else tf.clip_by_value(grad, -1, 1), var)
    for grad, var in grads_and_vars
]
optim = optimizer.apply_gradients(new_grads_and_vars)

In [6]:
def binarize(images):
    """Return a float32 0/1 array shaped like `images`.

    An entry is 1.0 where an independent np.random.uniform draw is strictly
    below the corresponding entry of `images`, and 0.0 otherwise.
    """
    noise = np.random.uniform(size=images.shape)
    is_on = np.less(noise, images)
    return is_on.astype('float32')

In [7]:
# Open a session, run the variable initializer, then restore the variables
# from the saved checkpoint.
sess = tf.Session()
saver = tf.train.Saver()

init_op = tf.initialize_all_variables()
sess.run(init_op)

checkpoint_path = "./model_checkpoint/model.ckpt"
saver.restore(sess, checkpoint_path)

In [8]:
def save_images(images, height, width, n_row, n_col,epoch,
                cmin=0.0, cmax=1.0, directory="./", prefix="sample"):
    """Tile a batch of images into one grid and save it as '<prefix>_<epoch>.jpg'.

    Images are laid out row-major: image i lands in grid row i // n_col and
    grid column i % n_col.

    Args:
        images: array holding n_row * n_col images of height * width values
            each; it is reshaped to (n_row, n_col, height, width).
        height: height of one image.
        width: width of one image.
        n_row: number of grid rows.
        n_col: number of grid columns.
        epoch: value formatted into the file name.
        cmin: forwarded to scipy.misc.toimage.
        cmax: forwarded to scipy.misc.toimage.
        directory: directory the file is written to.
        prefix: file name prefix.
    """
    grid = images.reshape((n_row, n_col, height, width))
    # (row, col, h, w) -> (row, h, col, w). Only this axis order makes the
    # reshape below produce a (n_row * height, n_col * width) mosaic with each
    # image intact; any other order scrambles non-square grids.
    grid = grid.transpose(0, 2, 1, 3)
    grid = grid.reshape((n_row * height, n_col * width))

    filename = '%s_%s.jpg' % (prefix, epoch)
    # NOTE: scipy.misc.toimage was deprecated in SciPy 1.0 and removed in 1.2,
    # so this call requires an older SciPy.
    scipy.misc.toimage(grid, cmin=cmin, cmax=cmax).save(os.path.join(directory, filename))

In [9]:
# Take the image part of one test batch, reshape it to
# [batch_size, 28, 28, 1] and binarize it.
test_pixels = mnist.test.next_batch(batch_size)[0]
batch_images = binarize(test_pixels.reshape([batch_size, 28, 28, 1]))

In [10]:
# Save the binarized test batch as a 10x10 grid in ./partial_1.jpg.
save_images(batch_images, height=28, width=28, n_row=10, n_col=10, epoch=1,
            prefix='partial')

In [11]:
# Keep rows 0-13 of every test image, blank the rows from 14 down, and refill
# them one pixel at a time from the model's output.
first_sampled_row = 14
samples = np.copy(batch_images)
samples[:, first_sampled_row:, :, :] = 0.

# Loop bounds come from the array itself instead of repeating 28 / 1 literals.
num_rows, num_cols, num_channels = samples.shape[1:]

# range rather than xrange so the cell runs on both Python 2 and Python 3.
for i in range(first_sampled_row, num_rows):
    for j in range(num_cols):
        for k in range(num_channels):
            # One forward pass per pixel, fed with everything filled in so far.
            probs = sess.run(outputs, feed_dict={images: samples})
            # Only the pixel being filled is needed, so only that slice is sampled.
            samples[:, i, j, k] = binarize(probs[:, i, j, k])

save_images(samples, num_rows, num_cols, 10, 10, 2, prefix='partial')