In [1]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import OneHotEncoder

In [2]:
# Record which TensorFlow release produced the outputs below (TF 1.x APIs are used).
print("{}".format(tf.__version__))

1.3.0


In [3]:
from IPython.display import clear_output

In [4]:
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split

def load_dataset(train_size=500, n_val=300, random_state=None):
    """Fetch MNIST and return a small train / validation / test split.

    All MNIST images are first split by ``train_test_split`` into ``train_size``
    examples and a test set holding the remainder; the last ``n_val`` of those
    training examples are then held out for validation.

    Args:
        train_size: forwarded to ``train_test_split``; number of examples used for
            training + validation (default 500, as before).
        n_val: number of the training examples reserved for validation
            (default 300, as before).
        random_state: forwarded to ``train_test_split``; pass an int for a
            reproducible split (default None keeps the previous unseeded behaviour).

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test). Image arrays have shape
        (n, 28, 28, 1); labels are the raw ``mnist['target']`` values.

    Raises:
        ValueError: if ``n_val`` is smaller than 1 (``X[:-0]`` would be empty).
    """
    if n_val < 1:
        raise ValueError("n_val must be >= 1")
    mnist = fetch_mldata('MNIST original')
    # -1 instead of a hard-coded 70000 so the reshape follows the data size.
    data = mnist['data'].reshape((-1, 28, 28, 1))
    target = mnist['target']
    X_train, X_test, y_train, y_test = train_test_split(
        data, target, train_size=train_size, random_state=random_state)
    # Hold out the last n_val training examples for validation.
    X_train, X_val = X_train[:-n_val], X_train[-n_val:]
    y_train, y_val = y_train[:-n_val], y_train[-n_val:]

    return X_train, y_train, X_val, y_val, X_test, y_test

print("Loading data...")
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = load_dataset()
# Number of training examples (used for per-example KL scaling).
total_size = len(X_train)

Loading data...




In [5]:
# One-hot encode the labels. The encoder is fitted on the full set of 10 digit
# classes instead of on y_train: with only a couple hundred training examples a
# class can be absent, which makes transform() fail on y_val/y_test (unknown
# category) or yield fewer than the 10 columns the y_inp placeholder expects.
NUM_CLASSES = 10
try:
    y_oh = OneHotEncoder(sparse_output=False)  # scikit-learn >= 1.2
except TypeError:
    y_oh = OneHotEncoder(sparse=False)  # older scikit-learn
y_oh.fit(np.arange(NUM_CLASSES).reshape((-1, 1)))
# Labels arrive as floats; cast to int so they match the fitted integer classes.
y_train, y_val, y_test = [y_oh.transform(labels.astype(int).reshape((-1, 1))).astype('float32')
                          for labels in (y_train, y_val, y_test)]

In [6]:
# Shift raw pixel values by 122 and divide by 255; the same transform is applied
# to every split so they stay comparable.
PIXEL_OFFSET, PIXEL_SCALE = 122, 255
X_train, X_val, X_test = [(images.astype('float32') - PIXEL_OFFSET) / PIXEL_SCALE
                          for images in (X_train, X_val, X_test)]

In [7]:
# Sanity check: validation labels are one-hot with 10 columns (matches y_inp).
assert y_val.ndim == 2 and y_val.shape[1] == 10, y_val.shape
y_val.shape

(300, 10)

In [8]:
# Sanity check: inputs fed to the tf.float32 placeholder are already float32.
assert X_val.dtype == np.float32, X_val.dtype
X_val.dtype

dtype('float32')

In [9]:
# Default weight of the narrow (std 0.2) component of the scale-mixture prior in
# create_priorkl_mixture_prior; non-trainable so the optimiser never changes it.
dscale = tf.Variable(initial_value=0.5, trainable=False)

def create_priorkl_gauss_prior(pairs, prior_std):
    """Register KL(q || p) for factorised Gaussian posteriors against N(0, prior_std^2).

    For each (mu, logsigma) pair, q = N(mu, exp(logsigma)^2) elementwise, and the
    closed-form divergence

        log(prior_std) - logsigma + (sigma^2 + mu^2) / (2 * prior_std^2) - 1/2

    is summed over all elements and appended to the 'KLS' graph collection.

    Args:
        pairs: iterable of (mu, logsigma) tensors of matching shape.
        prior_std: standard deviation of the zero-mean Gaussian prior (a scalar
            number or scalar tensor).
    """
    with tf.name_scope('KL'):
        for mu, logsigma in pairs:
            # log(prior_std) and -1/2 are constant w.r.t. the parameters, so they do
            # not change gradients; including them makes the logged value the
            # actual (non-negative) KL instead of an offset version of it.
            log_prior_std = tf.log(tf.cast(prior_std, logsigma.dtype))
            kl = (log_prior_std - logsigma
                  + (tf.exp(2 * logsigma) + mu ** 2) / (2 * prior_std ** 2)
                  - 0.5)
            tf.add_to_collection('KLS', tf.reduce_sum(kl))
            
def create_priorkl_mixture_prior(pairs, prior_std=None, num_samples=10,
                                 mixture_stds=(0.2, 5.), mixture_weights=None):
    """Register a Monte-Carlo estimate of KL(q || p) against a scale-mixture prior.

    q is the factorised Gaussian N(mu, exp(logsigma)^2); p is the zero-mean
    mixture sum_k w_k * N(0, s_k^2). KL = E_q[log q(w) - log p(w)] is estimated
    from ``num_samples`` reparameterised draws. The estimate is averaged over
    draws, summed over elements and appended to the 'KLS' graph collection.

    Args:
        pairs: iterable of (mu, logsigma) tensors with fully-defined static shape.
        prior_std: unused; accepted so the signature matches
            create_priorkl_gauss_prior.
        num_samples: Monte-Carlo samples per parameter tensor (default 10).
        mixture_stds: standard deviations s_k of the mixture components.
        mixture_weights: weights w_k matching ``mixture_stds``; defaults to
            [dscale, 1 - dscale] using the module-level ``dscale`` variable.
    """
    import math

    half_log_2pi = 0.5 * math.log(2 * math.pi)

    def log_gauss(x, mu, sigma):
        # Log-density of N(mu, sigma^2). The normaliser is -log(sigma) - 0.5*log(2*pi);
        # the previous version subtracted sqrt(2*pi)*sigma instead, which biased the
        # entropy term and gave a constant (wrong) gradient with respect to sigma.
        return -(x - mu) ** 2 / (2 * sigma ** 2) - tf.log(sigma) - half_log_2pi

    if mixture_weights is None:
        mixture_weights = [dscale, 1 - dscale]

    with tf.name_scope('KL'):
        for mu, logsigma in pairs:
            sigma = tf.exp(logsigma)
            weight_shape = list(map(int, mu.shape))

            weight_sample = (tf.random_normal([num_samples] + weight_shape) * sigma[tf.newaxis, ...]
                             + mu[tf.newaxis, ...])
            # log q(w) for every sample.
            log_q = log_gauss(weight_sample, mu[tf.newaxis, ...], sigma[tf.newaxis, ...])

            # log p(w) evaluated in log space (logsumexp over components), so very
            # small mixture densities cannot underflow to 0 and make -log(0) = inf.
            log_p = tf.reduce_logsumexp(
                tf.stack([tf.log(w) + log_gauss(weight_sample, 0., s)
                          for w, s in zip(mixture_weights, mixture_stds)], axis=0),
                axis=0)

            kls = tf.reduce_mean(log_q - log_p, axis=0)
            kl = tf.reduce_sum(kls)

            tf.add_to_collection('KLS', kl)
            
def create_priorkl(pairs, prior_std, prior='mixture'):
    """Register the KL term(s) for ``pairs`` of (mu, logsigma) posterior tensors.

    Args:
        pairs: iterable of (mu, logsigma) tensors.
        prior_std: std of the Gaussian prior; only used when ``prior == 'gauss'``.
        prior: 'mixture' (default, the previously hard-wired choice) for the
            scale-mixture prior, or 'gauss' for a single N(0, prior_std^2) prior.

    Raises:
        ValueError: if ``prior`` is not one of the supported names.
    """
    if prior == 'mixture':
        create_priorkl_mixture_prior(pairs)
    elif prior == 'gauss':
        create_priorkl_gauss_prior(pairs, prior_std)
    else:
        raise ValueError("unknown prior: %r" % (prior,))
            
def conv(x, nbfilter, filtersize, name, lrep=True):
    """Bayesian 2-D convolution (stride 1, SAME padding) followed by ReLU.

    Kernel and bias have factorised Gaussian posteriors parameterised by
    ``*_mu`` and ``*_logsigma`` variables; their KL terms are registered through
    ``create_priorkl``. A histogram summary of the kernel sigma is also added.

    Args:
        x: 4-D input tensor with statically known spatial and channel sizes
            (assumed NHWC, matching the [1, 1, 1, 1] strides used below).
        nbfilter: number of output channels.
        filtersize: side length of the square kernel.
        name: variable scope / name scope of this layer.
        lrep: if True, use local reparameterisation (sample the pre-activations
            from their Gaussian marginal); otherwise sample one kernel per run
            and convolve with it.

    Returns:
        ReLU activations with static shape [None, height, width, nbfilter].
    """
    prior_std = 1

    # Stride 1 with SAME padding preserves the spatial size, so the output's static
    # shape can be read from the input. Height and width are read separately; the
    # previous version used the width for both, so set_shape failed on
    # non-square inputs.
    height, width = int(x.shape[1]), int(x.shape[2])

    input_f = int(x.shape[-1])
    kernelshape = [filtersize, filtersize, input_f, nbfilter]

    with tf.variable_scope(name, initializer=tf.random_normal_initializer(stddev=0.05)):
        with tf.name_scope(name+'/'):
            kernel_mu = tf.get_variable('kernel_mu', shape=kernelshape)
            # The -3 offset makes the N(0, 0.05)-initialised variable start near
            # logsigma = -3, i.e. a small initial sigma.
            kernel_logsigma = tf.get_variable('kernel_logsigma', shape=kernelshape) - 3
            kernel_sigma = tf.exp(kernel_logsigma)
            tf.summary.histogram('kernel_sigma', kernel_sigma)

            if lrep:
                # Local reparameterisation: pre-activation ~ N(conv(x, mu), conv(x^2, sigma^2));
                # the 0.001 keeps the variance strictly positive before the sqrt.
                pmu = tf.nn.conv2d(x, kernel_mu, [1,1,1,1], padding='SAME')
                pvar = tf.nn.conv2d(x**2, kernel_sigma**2, [1,1,1,1], padding='SAME') + 0.001
                p = tf.random_normal(tf.shape(pmu))*tf.sqrt(pvar) + pmu
            else:
                # One kernel sample per run, shared by the whole batch.
                kernel = tf.random_normal(tf.shape(kernel_mu))*kernel_sigma + kernel_mu
                p = tf.nn.conv2d(x, kernel, [1,1,1,1], padding='SAME')

            bias_mu = tf.get_variable('bias_mu', shape=[1,1,1,nbfilter])
            bias_logsigma = tf.get_variable('bias_logsigma', shape=[1,1,1,nbfilter]) - 3
            bias_sigma = tf.exp(bias_logsigma)

            create_priorkl([[kernel_mu, kernel_logsigma], [bias_mu, bias_logsigma]], prior_std)

            p += tf.random_normal(tf.shape(bias_mu))*bias_sigma + bias_mu

            p = tf.nn.relu(p)
            p.set_shape([None, height, width, nbfilter])
            return p
        
def dense(x, nneurons, name, act=tf.nn.relu, lrep=True):
    """Bayesian fully-connected layer computing act(x @ W + b) with Gaussian W and b.

    W and b have factorised Gaussian posteriors (``*_mu`` / ``*_logsigma``
    variables, logsigma offset by -3); their KL terms are registered through
    ``create_priorkl``. A histogram summary of W's sigma is also added.

    Args:
        x: 2-D input tensor whose last dimension is statically known.
        nneurons: number of output units.
        name: variable scope / name scope of this layer.
        act: activation applied to the sampled pre-activation (default ReLU).
        lrep: if True, sample the pre-activation with the local
            reparameterisation trick; otherwise sample W once and multiply.

    Returns:
        Tensor with static shape [None, nneurons].
    """
    prior_std = 1
    weight_shape = [int(x.shape[-1]), nneurons]
    bias_shape = [1, nneurons]
    init = tf.random_normal_initializer(stddev=0.05)

    with tf.variable_scope(name, initializer=init), tf.name_scope(name + '/'):
        w_mean = tf.get_variable('kernel_mu', shape=weight_shape)
        w_log_std = tf.get_variable('kernel_logsigma', shape=weight_shape) - 3
        w_std = tf.exp(w_log_std)
        tf.summary.histogram('W_sigma', w_std)

        b_mean = tf.get_variable('bias_mu', shape=bias_shape)
        b_log_std = tf.get_variable('bias_logsigma', shape=bias_shape) - 3
        b_std = tf.exp(b_log_std)

        create_priorkl([[w_mean, w_log_std], [b_mean, b_log_std]], prior_std)

        if not lrep:
            sampled_w = tf.random_normal(tf.shape(w_mean)) * w_std + w_mean
            out = tf.matmul(x, sampled_w)
        else:
            # Local reparameterisation: pre-activation ~ N(x W_mu, x^2 W_sigma^2);
            # the 0.001 keeps the variance strictly positive before the sqrt.
            pre_mean = tf.matmul(x, w_mean)
            pre_std = tf.sqrt(tf.matmul(x ** 2, w_std ** 2) + 0.001)
            out = tf.random_normal(tf.shape(pre_mean)) * pre_std + pre_mean

        out += tf.random_normal(tf.shape(b_mean)) * b_std + b_mean
        out = act(out)
        out.set_shape([None, nneurons])
        return out

In [10]:
x_inp = tf.placeholder(tf.float32, [None, 28, 28, 1])
y_inp = tf.placeholder(tf.float32, [None, 10])
global_step = tf.get_variable('global_step', initializer=0, dtype=tf.int32)

# Divisor applied to the KL term: 1 + 1000 * 0.1 ** (global_step / 10000),
# i.e. it starts at 1001 and decays towards 1 as global_step grows.
kl_scaler = 1 + tf.train.exponential_decay(1000., global_step, 10000, 0.1)

lrep = True


def _max_pool_2x2(t):
    """2x2 max pooling with stride 2 and SAME padding."""
    return tf.nn.max_pool(t, (1, 2, 2, 1), [1, 2, 2, 1], 'SAME')


# (scope name, number of filters, max-pool afterwards?) for each 3x3 conv layer.
conv_stack = [
    ('c1', 40, False),
    ('c2', 40, True),
    ('c3', 40, False),
    ('c4', 40, True),
    ('c5', 20, False),
]

x = x_inp
for layer_name, n_filters, pool_after in conv_stack:
    x = conv(x, n_filters, 3, layer_name, lrep=lrep)
    if pool_after:
        x = _max_pool_2x2(x)

# Flatten; like the original formula, this assumes a square spatial map.
side, channels = int(x.shape[-2]), int(x.shape[-1])
x = tf.reshape(x, [-1, side * side * channels])

x = dense(x, 20, 'd1', lrep=lrep)
x = dense(x, 10, 'd2', act=lambda t: t, lrep=lrep)

In [11]:
# Class probabilities for the (stochastic) network logits `x`.
output = tf.nn.softmax(logits=x)

In [12]:
batchsize = 30
kls = tf.get_collection('KLS')

# Allows feeding logits directly; defaults to the network output x.
logit = tf.placeholder_with_default(x, shape=[None,10])

# Per-example negative ELBO for a minibatch:
#     mean cross-entropy over the batch + KL / N_train.
# l1 is a batch *mean*, so the KL is divided by the training-set size only. The
# previous version also multiplied it by `batchsize`, weighting the KL 30x more
# than the likelihood; that factor only belongs with a *summed* likelihood term.
# kl_scaler additionally down-weights the KL early in training.
l1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y_inp))
l2 = (sum(kls) / len(X_train)) / kl_scaler
loss = l1 + l2

In [13]:
#loss /= 500

In [14]:
def gtop(loss):
    """Build an Adam training op with elementwise gradient clipping and numeric checks.

    Creates a non-trainable 'learning_rate' variable (initial value 0.001) and
    clips every gradient to [-20, 20]. The update depends on a tf.check_numerics
    op per clipped gradient, so a NaN/Inf gradient makes the training step fail.
    Variables without a gradient are skipped.

    Args:
        loss: scalar loss tensor to minimise.

    Returns:
        The op returned by apply_gradients.
    """
    with tf.name_scope('optimizer'):
        learning_rate = tf.get_variable('learning_rate', initializer=0.001, trainable=False)
        adam = tf.train.AdamOptimizer(learning_rate=learning_rate)

        clipped = []
        for gradient, variable in adam.compute_gradients(loss):
            if gradient is None:
                continue
            clipped.append((tf.clip_by_value(gradient, -20., 20.), variable))

        numeric_guards = [tf.check_numerics(gradient, message=variable.name)
                          for gradient, variable in clipped]
        with tf.control_dependencies(numeric_guards):
            return adam.apply_gradients(clipped)


train_op = gtop(loss)

In [15]:
def train(X, y):
    """Run one optimisation step on the batch (X, y) and return the batch loss value."""
    loss_value, _ = sess.run([loss, train_op], {x_inp: X, y_inp: y})
    return loss_value


def evaluate(X, y, nsamples=20):
    """Accuracy of the Monte-Carlo-averaged predictive distribution on (X, y).

    Runs the stochastic network ``nsamples`` times on X, accumulates the softmax
    outputs, and compares their argmax with the argmax of the one-hot labels y.

    Args:
        X: input images fed to ``x_inp`` (all at once, in a single run).
        y: one-hot labels, one column per output class.
        nsamples: number of stochastic forward passes to average (default 20).

    Returns:
        Fraction of examples whose predicted class matches the label.

    Raises:
        ValueError: if ``nsamples`` is smaller than 1.
    """
    if nsamples < 1:
        raise ValueError("nsamples must be >= 1")
    # Derive the class count from the labels instead of hard-coding 10.
    pred = np.zeros((len(y), np.shape(y)[-1]))
    for _ in range(nsamples):
        pred += sess.run(output, {x_inp: X})
    return np.mean(np.argmax(pred, axis=-1) == np.argmax(y, axis=-1))


def eval_loss(X, y):
    """Evaluate the merged scalar loss summaries on (X, y).

    Returns the serialized Summary produced by ``loss_sum`` (suitable for
    ``writer.add_summary``), not a numeric loss.
    """
    return loss_sum.eval({x_inp: X, y_inp: y})

In [16]:
# Merge every summary registered so far; at this point in the notebook those are
# the per-layer sigma histograms created inside conv()/dense().
hist_sum = tf.summary.merge_all(key=tf.GraphKeys.SUMMARIES)

In [17]:
# Scalar loss summaries, merged separately so they can be fed a batch.
loss_sum = tf.summary.merge([
    tf.summary.scalar('logprob', l1),
    tf.summary.scalar('kl_scaler', kl_scaler),
    tf.summary.scalar('kl', l2),
    tf.summary.scalar('ELBO', loss),
])

In [18]:
# Start each run with an empty TensorBoard log directory. Unlike `rm -R`,
# shutil.rmtree(ignore_errors=True) does not error when the directory does not
# exist yet (e.g. on a fresh machine), and it needs no POSIX shell.
import shutil
shutil.rmtree('/tmp/bayes_mnist/', ignore_errors=True)

In [19]:
# Event-file writer for TensorBoard.
writer = tf.summary.FileWriter(logdir='/tmp/bayes_mnist')

In [20]:
# Store the model graph so it can be inspected in TensorBoard.
writer.add_graph(graph=tf.get_default_graph())

In [21]:
from tensorflow.python import debug as tf_debug

In [22]:
# InteractiveSession makes itself the default session, which the bare
# .eval()/.run() calls in later cells rely on.
sess = tf.InteractiveSession(target='')

In [23]:
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
#sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

In [24]:
# Initialise all variables in the session created above.
sess.run(tf.global_variables_initializer())

In [25]:
#sess.run([loss], {x_inp:X_val[:5], y_inp:y_val[:5]})

In [26]:
#%time evaluate(X_val, y_val)

In [27]:
%time evaluate(X_train, y_train)

CPU times: user 1.08 s, sys: 236 ms, total: 1.31 s
Wall time: 1.24 s


0.080000000000000002

In [28]:
%time evaluate(X_val, y_val)

CPU times: user 864 ms, sys: 120 ms, total: 984 ms
Wall time: 896 ms


0.080000000000000002

In [41]:
# Wrap the Python-side accuracy computation as graph ops so it can be logged as a
# summary. stateful=True keeps TensorFlow from treating the calls as pure
# (e.g. deduplicating them).
# NOTE(review): evaluate() itself calls sess.run, i.e. a nested run from inside a
# py_func; confirm this cannot stall the session's thread pool.
def _eval_val():
    return evaluate(X_val, y_val)


def _eval_train():
    return evaluate(X_train, y_train)


evaluate_val = tf.py_func(_eval_val, [], tf.float64, stateful=True, name='eval_val')
evaluate_train = tf.py_func(_eval_train, [], tf.float64, stateful=True, name='eval_train')

In [42]:
# Accuracy summaries for validation and training data.
evaluate_sum_val, evaluate_sum_train = (
    tf.summary.scalar('eval_results_val', evaluate_val),
    tf.summary.scalar('eval_results_train', evaluate_train),
)

In [43]:
# Single op that evaluates and serializes both accuracy summaries.
evaluate_sum = tf.summary.merge(inputs=[evaluate_sum_train, evaluate_sum_val])

In [31]:
#!rm -R /tmp/esave/
#!mkdir /tmp/esave

In [32]:
# Checkpoint saver over all saveable variables (var_list=None is the default).
saver = tf.train.Saver(var_list=None)

In [33]:
#saver.restore(sess, '/tmp/esave')

In [34]:
#saver.save(sess, '/tmp/esave')

In [35]:
#%time evaluate(X_train, y_train)

In [36]:
# Scratch containers for training history.
losses, accs, nans = [], [], []

In [37]:
#plt.plot(range(len(nans)), nans)
#plt.show()

In [38]:
# Op that advances global_step by one (run once per epoch in the training loop).
gs_increment = tf.assign_add(global_step, 1)

In [None]:
# Train (effectively) indefinitely, reshuffling every epoch. Every 10 epochs, log
# the sigma histograms and the loss summaries of the last minibatch; every 30
# epochs, also log train/validation accuracy.
n_batches = max(1, len(X_train) // batchsize)
for epoch in range(1000000):
    # Fresh order each epoch. Previously the same fixed slices were reused every
    # epoch, and the `- 1` in the range skipped the last full batch, so some
    # examples were never trained on.
    order = np.random.permutation(len(X_train))
    for ix in range(n_batches):
        batch = order[ix * batchsize:(ix + 1) * batchsize]
        X, y = X_train[batch], y_train[batch]
        train(X, y)

    gs_increment.eval()

    if epoch % 10 == 0:
        step = global_step.eval()
        writer.add_summary(hist_sum.eval(), global_step=step)
        writer.add_summary(eval_loss(X, y), global_step=step)

        # Was `if epoch % 30:`, which fired on every logged epoch that is NOT a
        # multiple of 30.
        if epoch % 30 == 0:
            writer.add_summary(evaluate_sum.eval(), global_step=step)

In [None]:
# (Removed a stray scratch expression that only echoed the literal 123.)