In [None]:
import os
import shutil

# We do not need a real dataset, so we load the no-op (fake) input.
from meta_learning.backend.tensorflow.dataset import noop as _dataset_noop

# We load the optimizers we care about.
from meta_learning.backend.tensorflow import optimizer as _optimizer

# The memory types we want to use.
from meta_learning.backend.tensorflow import memory as _memory

# Our memory requires some gradient implementations
from meta_learning.backend.tensorflow import gradient as _gradient


# The model we care about.
from meta_learning.backend.tensorflow.model import rosenbrock as _model_rosenbrock

# We want to save our model as easily as possible.
from meta_learning.backend.tensorflow import saver as _saver


from meta_learning.plot import rosenbrock as _plot_rosenbrock
from meta_learning.plot import utils as _plot_utils

from tf_utils import summaries as _summaries

In [None]:
# Notebook-wide configuration: every value a reader may want to tune lives here.

# NOTE(review): GLOBALS is not referenced anywhere in the visible notebook;
# it is kept so that any other cell relying on it keeps working.
GLOBALS = dict()

# Output locations, relative to the notebook's working directory.
LOGS_DIR = '../logs'
PLOTS_DIR = '../plots'

# One log directory per optimizer variant (wiped by `run` before training).
MODEL_DIR_META = os.path.join(LOGS_DIR, 'example_rosenbrock/meta_sgd')
MODEL_DIR_SGD = os.path.join(LOGS_DIR, 'example_rosenbrock/sgd')

# Optimizer hyper-parameters.
LEARNING_RATE = 1e-3      # gradient-descent rate; also first arg of AdamStatic.init_fn
MEM_LEARNING_RATE = 1e-3  # passed as memory_learning_rate
CLIP_GRAD = 10            # passed as clip_by_value and memory_clip_grad
LM_SCALE = 0.5            # passed as memory_lm_scale
NUM_CENTERS = 100         # passed as memory_num_centers

# Number of estimator steps performed by a single `train` call.
STEPS_PER_TRAIN_CALL = 100

def model_dir_iteration(model_dir, iteration):
    return os.path.join(model_dir, str(iteration))

def clean_model_dir(model_dir):
    """Recursively delete `model_dir` so the next run starts from scratch.

    A path that does not exist is silently ignored. Any error raised by
    `shutil.rmtree` (e.g. permission problems) propagates to the caller.
    """
    if not os.path.exists(model_dir):
        return
    shutil.rmtree(model_dir)


In [None]:
# An example training loop.
def train(model_dir, optimizer_fn, saver_fn):
    """Train a Rosenbrock model once, writing checkpoints/summaries to `model_dir`.

    Args:
        model_dir: directory handed to the Rosenbrock model / estimator.
        optimizer_fn: optimizer factory passed to `create_model_fn`.
        saver_fn: saver factory passed to `create_model_fn`.

    Runs `STEPS_PER_TRAIN_CALL` estimator steps on the no-op input pipeline,
    since the Rosenbrock problem needs no real data.
    """
    rosenbrock = _model_rosenbrock.Rosenbrock(
        model_dir=model_dir,
        joined=False,
        # `trainable` only matters when optimizing a single dimension.
        trainable=None)

    estimator = rosenbrock.create_estimator(
        model_fn=rosenbrock.create_model_fn(optimizer_fn=optimizer_fn,
                                            saver_fn=saver_fn))

    noop_data = _dataset_noop.Noop(num_epochs=1, batch_size=1, shard_name='train')
    train_input_fn = noop_data.create_input_fn()

    estimator.train(input_fn=train_input_fn, steps=STEPS_PER_TRAIN_CALL)
    
    
def run(base_model_dir, optimizer_fn, saver_fn):
    """Wipe `base_model_dir`, then train a single fresh iteration into `<base>/1`.

    Args:
        base_model_dir: root log directory for this experiment; deleted first.
        optimizer_fn: optimizer factory forwarded to `train`.
        saver_fn: saver factory forwarded to `train`.
    """
    clean_model_dir(base_model_dir)
    train(model_dir_iteration(base_model_dir, 1), optimizer_fn, saver_fn)
    

# Just an easy way to plot the saved data.
def plot(base_model_dir):
    """Plot every experiment whose 'rosenbrock' summaries are stored under `base_model_dir`.

    For each experiment (in sorted name order) one `_plot_rosenbrock.plot_func`
    figure is produced from the saved 'x1'/'x2' values, then a single
    log-scale loss-vs-step figure combining all experiments is written.
    All figures go to PLOTS_DIR; names are derived from the path of
    `base_model_dir` relative to LOGS_DIR.
    """
    # Re-select the inline backend on every call; this is the programmatic
    # form of the `%matplotlib inline` line magic.
    get_ipython().run_line_magic('matplotlib', 'inline')

    experiment_results = _summaries.get_summary_save_tensor_dict(
        base_model_dir, 'rosenbrock')
    model_name = os.path.relpath(base_model_dir, LOGS_DIR)
    ordered_experiments = sorted(experiment_results.items())

    for exp_name, exp_data in ordered_experiments:
        print('plot ', exp_name)
        # NOTE(review): -1.0, 2.0, 100 presumably are the plotting range and
        # grid resolution of plot_func — confirm against its signature.
        _plot_rosenbrock.plot_func(PLOTS_DIR,
                                   model_name + exp_name,
                                   -1.0, 2.0, 100,
                                   exp_data['x1'], exp_data['x2'])

    loss_curves = {
        exp_name: {'step': exp_data['step'], 'loss': exp_data['loss']}
        for exp_name, exp_data in ordered_experiments
    }
    _plot_utils.plot_values_by_step(PLOTS_DIR, model_name, 'loss',
                                    loss_curves, logscale=True)

In [None]:

# The memory optimizer must be told which gradient-descent variant it wraps.
# Two optimizers are built below:
#   * meta_sgd_fn: a Memory optimizer combining AdamStatic memory with plain
#     gradient descent (logged under MODEL_DIR_META);
#   * sgd_fn: the Reference optimizer wrapping plain gradient descent
#     (logged under MODEL_DIR_SGD) for comparison.
meta_memory_init_fn = _memory.AdamStatic.init_fn(
    LEARNING_RATE,
    memory_clip_grad=CLIP_GRAD,
    memory_lm_scale=LM_SCALE,
    memory_num_centers=NUM_CENTERS,
    memory_learning_rate=MEM_LEARNING_RATE)
meta_gradient_init_fn = _gradient.GradientDescent.init_fn(
    LEARNING_RATE, clip_by_value=CLIP_GRAD)
meta_sgd_fn = _optimizer.Memory.init_fn(
    memory_init_fn=meta_memory_init_fn,
    gradient_init_fn=meta_gradient_init_fn)

sgd_fn = _optimizer.Reference.init_fn(
    gradient_init_fn=_gradient.GradientDescent.init_fn(
        LEARNING_RATE, clip_by_value=CLIP_GRAD))

# Train both variants; `run` wipes each experiment's log directory first.
for experiment_dir, experiment_optimizer_fn in ((MODEL_DIR_META, meta_sgd_fn),
                                                (MODEL_DIR_SGD, sgd_fn)):
    run(experiment_dir, experiment_optimizer_fn, _saver.Standard.init_fn())

In [None]:
# Plot the plain-SGD baseline first, then the memory-based optimizer.
for experiment_dir in (MODEL_DIR_SGD, MODEL_DIR_META):
    plot(experiment_dir)