## Import

In [49]:
import tensorflow as tf
import numpy as np

## Configure

In [50]:
DATA_SIZE = 10000
DATA_DIM = 128
BATCH_SIZE = 64
NUM_EPOCHS = 20
NUM_HIDDEN = 128
N_VISIBLE = IMAGE_SIZE*IMAGE_SIZE*NUM_CHANNELS
SAMPLE_HIDDEN = False
SAMPLE_VISIBLE = False
N_GIBBS_STEPS = 10
LR = 0.09
MOMENTUM = 0.1
PERSISTENT_CHAIN = False

training_data = np.random.randn(DATA_SIZE,DATA_DIM)

## Setup persistent variables

In [51]:
with tf.device("/cpu:0"):
    visible = tf.placeholder(tf.float32,shape=(BATCH_SIZE,DATA_DIM))
    weights = tf.Variable(
      tf.truncated_normal([NUM_HIDDEN, DATA_DIM],
                          stddev=0.01,
                          seed=SEED))
    bias_h = tf.Variable(tf.zeros([NUM_HIDDEN]))
    bias_v = tf.Variable(tf.zeros([DATA_DIM]))

## define helper functions

In [52]:
def v_h(v, sample=True):
    p_h = tf.sigmoid(tf.matmul(v,weights,transpose_b=True) +bias_h)
    if sample:
        thresh = tf.random_uniform([BATCH_SIZE, NUM_HIDDEN])
        h = tf.to_float(p_h > thresh)
    else:
        h = p_h
    return p_h,h
  
def h_v(h, sample=True):
    p_r = tf.sigmoid(tf.matmul(h,weights) + bias_v)
    if sample:
        thresh = tf.random_uniform([BATCH_SIZE, N_VISIBLE])
        r = tf.to_float(p_r > thresh)
    else:
        r = p_r
    return p_r,r

def Energy(v,h):
    hwvr = tf.reduce_sum(h*tf.matmul(v,weights, transpose_b=True),reduction_indices=1)
    bh = tf.matmul(h,tf.expand_dims(bias_h,-1))
    bv = tf.matmul(v,tf.expand_dims(bias_v,-1))
    return -(hwvr + bh + bv)
 
def FreeEnergy(v):
    e = tf.reduce_sum(tf.log(1 + tf.exp( tf.matmul(v,weights,transpose_b=True)+bias_h)),reduction_indices=1)
    bv = tf.matmul(v,tf.expand_dims(bias_v,-1)) 
    return -(bv + e)

## Build a model graph

In [53]:
def model(data):
    """The Model definition."""
    # Positive Phase
    p_h_p,h_p = v_h(data, SAMPLE_HIDDEN)
     
    # Get positive phase energy
    F_p = FreeEnergy(data)
    if PERSISTENT_CHAIN:
      _,h_t = v_h(persistent_recon, SAMPLE_HIDDEN)
    else:
      h_t = h_p
    # Gibbs chain  
    for i in range(N_GIBBS_STEPS-1):
      _,r = h_v(h_t,SAMPLE_VISIBLE)
      _,h_t = v_h(r, SAMPLE_HIDDEN)
    p_r,r = h_v(h_t,SAMPLE_VISIBLE)
    if PERSISTENT_CHAIN:
      r = persistent_recon.assign(r) 
    p_h_n,h_n = v_h(r, SAMPLE_HIDDEN)
   
    # Get negative phase energy
    F_n = FreeEnergy(r)   
    return F_p - F_n, p_h_p, p_h_n, r

# Build computation graph
F,h_p, h_n, recon = model(visible)

## Create a gradient descent optimizer

In [54]:
loss = tf.reduce_mean(F)
optimizer = tf.train.MomentumOptimizer(LR,MOMENTUM).minimize(
                      loss,
                      var_list=[weights, bias_h,bias_v])

## Run

In [55]:
with tf.Session() as s:
    tf.initialize_all_variables().run()
    for step in xrange(NUM_EPOCHS * DATA_SIZE // BATCH_SIZE):
        offset = (step * BATCH_SIZE) % (DATA_SIZE - BATCH_SIZE)
        batch_data = training_data[offset:(offset + BATCH_SIZE), :]
        feed_dict = {visible: batch_data}
        _, l = s.run([optimizer, loss],feed_dict=feed_dict)
        print("Step {}, Mean Free Energy {} ".format(step,l))

Step 0, Mean Free Energy 0.547564923763 
Step 1, Mean Free Energy -42.4305877686 
Step 2, Mean Free Energy -53.8366165161 
Step 3, Mean Free Energy -48.8881645203 
Step 4, Mean Free Energy -54.1671447754 
Step 5, Mean Free Energy -59.2446670532 
Step 6, Mean Free Energy -64.2504272461 
Step 7, Mean Free Energy -76.163848877 
Step 8, Mean Free Energy -73.0971374512 
Step 9, Mean Free Energy -68.0412750244 
Step 10, Mean Free Energy -64.0356140137 
Step 11, Mean Free Energy -72.3898391724 
Step 12, Mean Free Energy -68.5406188965 
Step 13, Mean Free Energy -65.2944869995 
Step 14, Mean Free Energy -69.4001541138 
Step 15, Mean Free Energy -67.0279846191 
Step 16, Mean Free Energy -71.4243392944 
Step 17, Mean Free Energy -68.0723114014 
Step 18, Mean Free Energy -70.7547454834 
Step 19, Mean Free Energy -69.087890625 
Step 20, Mean Free Energy -65.4024734497 
Step 21, Mean Free Energy -72.6742935181 
Step 22, Mean Free Energy -71.0883789062 
Step 23, Mean Free Energy -62.0577316284 
Step

KeyboardInterrupt: 