In [2]:
from sklearn.metrics import mean_squared_error as mse
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
%matplotlib inline

def generate_points(k):
    """sample x,y examples from a randomly initialized sine curve.
    """
    n_points = 100
    
    phase = np.random.uniform(low=0, high=2*np.pi)
    ampl = np.random.uniform(low=0.1, high=5)
    X = np.linspace(-5, 5, n_points)
    y = np.sin(X + phase) * ampl
    

    keys = np.random.choice(np.arange(n_points), size=k)
    return (X[keys], y[keys], X, y)

In [3]:
# define a simple model

tf.reset_default_graph()

n_hidden = 64
n_classes = 1
n_features = 1

X_ = tf.placeholder(tf.float32, shape=[None, n_features])
y_ = tf.placeholder(tf.float32, shape=[None, n_classes])

with tf.variable_scope("parameters"):
    w1 = tf.Variable(tf.random_uniform([n_features, n_hidden]))
    b1 = tf.Variable(tf.random_uniform([n_hidden]))
    w2 = tf.Variable(tf.random_uniform([n_hidden, n_hidden]))
    b2 = tf.Variable(tf.random_uniform([n_hidden]))
    w3 = tf.Variable(tf.random_uniform([n_hidden, n_classes]))
    b3 = tf.Variable(tf.random_uniform([n_classes]))

with tf.variable_scope("model"):
    z1 = tf.matmul(X_, w1) + b1
    fc1 = tf.nn.tanh(z1)
    z2 = tf.matmul(fc1, w2) + b2
    fc2 = tf.nn.tanh(z2)
    z3 = tf.matmul(fc2, w3) + b3

loss = tf.reduce_mean(tf.square(z3 - y_))
op = tf.train.AdamOptimizer(1e-2).minimize(loss)

init = tf.global_variables_initializer()

saver = tf.train.Saver()

In [21]:
# lets pretrain to find a good initialization parameters

n_tasks = 1000  # number of tasks (eg sine curves), should be at least 10000
n_examples = 50  # number of examples per task; reptile uses 50 examples per task
mb_size = 10  # reptile uses mb size of 10
n_epochs = 5  # k > 1  # we can train as much as we'd like
save_path = "model/reptile.ckpt"

with tf.Session() as sess:
    sess.run(init)
    
    # randomly sample a task from p(Tasks)...
    for task in range(n_tasks):
        
        # collect old parameters
        w1_a, b1_a, w2_a, b2_a, w3_a, b3_a = sess.run([w1, b1, w2, b2, w3, b3])
        
        # fetch x,y examples from that task
        xs, ys, _, _ = generate_points(n_examples)
    
        # parameter update
        for _ in range(n_epochs):
            
            # minibatch update
            for start in range(0, n_examples, mb_size):
                xs_b = xs[start:start+mb_size]
                ys_b = ys[start:start+mb_size]
                _ = sess.run(op, feed_dict={X_: xs_b.reshape(mb_size,1), 
                                            y_: ys_b.reshape(mb_size,1)})
        
        # collect new parameters
        w1_b, b1_b, w2_b, b2_b, w3_b, b3_b = sess.run([w1, b1, w2, b2, w3, b3])
        
        # calculate "meta gradient"
        outerstepsize = .1 * (1 - task / n_tasks)
        w1_c = w1_a + (w1_b - w1_a) * outerstepsize
        b1_c = b1_a + (b1_b - b1_a) * outerstepsize
        
        w2_c = w2_a + (w2_b - w2_a) * outerstepsize
        b2_c = b2_a + (b2_b - b2_a) * outerstepsize
        
        w3_c = w3_a + (w3_b - w3_a) * outerstepsize
        b3_c = b3_a + (b3_b - b3_a) * outerstepsize
        
        # update model with new parameters
        w1.load(w1_c, sess)
        b1.load(b1_c, sess)
        w2.load(w2_c, sess)
        b2.load(b2_c, sess)
        w3.load(w3_c, sess)
        b3.load(b3_c, sess)
        
        # calculate loss
        if task % (n_tasks / 10) == 0:
            loss_ = sess.run(loss, feed_dict={X_: xs.reshape(n_examples,1), y_: ys.reshape(n_examples,1)})
            print("epoch {}, loss {:.3f}".format(task, loss_))
    
    saver.save(sess, save_path, global_step=task)

print("done!")

epoch 0, loss 721.898
epoch 100, loss 1.986
epoch 200, loss 1.599
epoch 300, loss 0.694
epoch 400, loss 0.048
epoch 500, loss 8.748
epoch 600, loss 2.105
epoch 700, loss 9.819
epoch 800, loss 0.674
epoch 900, loss 0.806
done!


In [41]:
# compare performance between (a) randomly initialized network versus (b) pretrained network

n_examples = 20  # we only get a few examples (1-5) for each task
n_epochs = 50  # we also want to be be quick, only a few steps
xs, ys, xf, yf = generate_points(n_examples)  # generate a single set of xy so we can compare

In [42]:
# first, lets see how well a randomly initalized model would do on new tasks

with tf.Session() as sess3:
    sess3.run(init)
        
    for _ in range(n_epochs):
        _ = sess3.run(op, feed_dict={X_: xs.reshape(n_examples,1), 
                                     y_: ys.reshape(n_examples,1)})

    # calculate predictions on the entire dataset
    y_pred_random = sess3.run(z3, feed_dict={X_: xs.reshape(n_examples,1)})

In [43]:
# then, lets use a well initialized/conditioned network - it should be able to learn from a few examples
with tf.Session() as sess2:
    sess2.run(init)
    
    # restore model
    ckpt_folder = tf.train.get_checkpoint_state("model")
    load_path = ckpt_folder.model_checkpoint_path
    saver.restore(sess2, load_path)

    for _ in range(n_epochs):
        _ = sess2.run(op, feed_dict={X_: xs.reshape(n_examples,1), y_: ys.reshape(n_examples,1)})
    y_pred_initialized = sess2.run(z3, feed_dict={X_: xs.reshape(n_examples,1)})

INFO:tensorflow:Restoring parameters from model/reptile.ckpt-49999


In [None]:
# a well-initialized network can learn from a few examples
plt.scatter(xs, y_pred_random, color ='g')  # random
plt.scatter(xs, y_pred_initialized, color ='b', marker="*", s=202)  # initialized
plt.plot(xf, yf, linewidth=2, color='r', linestyle='--')  # true