In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from keras.utils import np_utils

# Apply one global matplotlib style to every figure drawn in this notebook.
plt.style.use("fivethirtyeight")
# Render figures inline beneath each cell (default in modern Jupyter; kept explicit).
%matplotlib inline

In [None]:
# Fetch MNIST into the local MNIST_data/ directory, with labels one-hot encoded.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Problem size, read off the first training example and its label.
n_dimensions, n_classes = (
    mnist.train.images[0].shape[0],  # input features per example
    mnist.train.labels[0].shape[0],  # width of a one-hot label vector
)

# Momentum optimizer hyperparameters.
beta = 0.9  # friction: fraction of the previous momentum vector kept at each step
lr = 1e-2   # step size, i.e. the learning rate

In [None]:
# construct graph: a single dense layer (softmax classifier) whose kernel is trained
# by a hand-rolled momentum update rather than a tf.train optimizer.

tf.reset_default_graph()
tf.set_random_seed(42)  # graph-level seed for reproducible init; must come after the reset

with tf.variable_scope("inputs"):
    X_ = tf.placeholder(tf.float32, [None, n_dimensions], name="X")
    y_ = tf.placeholder(tf.float32, [None, n_classes], name="y")

with tf.variable_scope("model"):
    # No activation: this layer produces the logits that feed softmax. A ReLU here
    # would clip every negative class score to 0 (and zero its gradient).
    fc1 = tf.layers.dense(inputs=X_, units=n_classes, name='fc1', activation=None)

# Re-open the layer's scope to fetch its kernel variable, shape (n_dimensions, n_classes).
with tf.variable_scope('model/fc1', reuse=True):
    w = tf.get_variable('kernel')
# NOTE(review): only the kernel is stepped below; the layer's bias is never updated.

probs = tf.nn.softmax(fc1)
# tf.losses.log_loss is the element-wise binary log loss averaged over all entries
# of the one-hot target (not categorical cross-entropy).
loss = tf.losses.log_loss(labels=y_, predictions=probs)
grad = tf.gradients(loss, w)[0]  # raw gradient of the loss wrt the kernel

# Momentum (velocity) buffer, same shape as the kernel; persists across steps.
# Not trainable: it is optimizer state, not a model parameter.
m_vector = tf.Variable(tf.zeros([n_dimensions, n_classes]), trainable=False, name="momentum")

with tf.variable_scope("step"):
    # Two separate ops, to be run in this order (momentum first, then parameters):
    #   m <- beta * m - lr * grad
    #   w <- w + m
    update_m_vector = m_vector.assign(beta * m_vector - (lr * grad))
    update_params = w.assign(w + m_vector)

init = tf.global_variables_initializer()

In [None]:
# Train with the hand-rolled momentum optimizer, then plot the smoothed loss curve.
# NOTE: each iteration below is ONE minibatch step, not a full pass over the data.
n_epochs = 1000  # number of minibatch steps
batch_size = 32

ls_loss = []

with tf.Session() as sess:
    sess.run(init)

    for step in range(n_epochs):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feed = {X_: batch_xs, y_: batch_ys}

        # 1) Update the momentum vector from this batch's gradient. The loss is
        #    fetched in the same run because that op only reads `w`. The
        #    parameter-update run below writes `w`, so a loss fetched alongside it
        #    could see either the old or the new weights.
        _, loss_ = sess.run([update_m_vector, loss], feed_dict=feed)
        ls_loss.append(loss_)

        # 2) Apply the freshly updated momentum vector to the parameters.
        sess.run(update_params, feed_dict=feed)

# pretty plot: loss smoothed over a 50-step rolling window (first 49 points are NaN)
rolling_plot = pd.Series(ls_loss).rolling(window=50).mean()
fig, ax = plt.subplots()
ax.plot(rolling_plot)
ax.set(xlabel="training step", ylabel="loss", title="momentum optimizer");