In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import random
from tqdm import tqdm_notebook as tqdm
import multiprocessing
import os.path
import csv
import copy
import joblib
from torchvision import datasets
import torchvision
import seaborn as sns; sns.set(color_codes=True)
sns.set_style("white")

In [3]:
# Change USE_CUDA to True when running on GPU
USE_CUDA  = False

def w(v):
    if USE_CUDA:
        return v.cuda()
    return v


In [4]:
!mkdir -p _cache
cache = joblib.Memory(location='_cache', verbose=0)


In [5]:
# This function is responsible for disabling the propagation of certain variable gradients
def detach_var(v):
    var = w(Variable(v.data, requires_grad=True))
    var.retain_grad()
    return var

In [19]:
def fit(optimizer_network, meta_optimizer, optimizee_obj_function, optimizee_network,
        iterations_to_optimize, iterations_to_unroll, out_mul,
        should_train = True):
    """
    Arguments: 
    - optimizer_network (the optimizer network we use, here the LSTM)
    - meta_optimizer (the optimizer of the optimizer network, e.g. Adam, SGD + nesterov, RMSprop, etc.)
    - optimizee_obj_function (the optimizee's objective function)
    - optimizee_network (the optimizee network)
    - epochs (total epochs for training)
    - iterations_to_optimize (iterations in every epoch)
    - should_train (if should_train is True, then we just train the optimizer, else we evaluate)
    """
    
    if should_train:
        optimizer_network.train()
    else:
        optimizer_network.eval()
        unroll = 1

    optimizee_obj_function = optimizee_obj_function(training=should_train)
    optimizee = w(optimizee_network())
    
    # Counting the parameters of the optimizee
    n_params = 0
    for param in optimizee.parameters():
        n_params += int(np.prod(param.size()))
        
    hidden_states = [w(Variable(torch.zeros(n_params, optimizer_network.hidden_size))) for _ in range(2)]
    cell_states = [w(Variable(torch.zeros(n_params, optimizer_network.hidden_size))) for _ in range(2)]
    
    losses_list = []
    
    if should_train:
        meta_optimizer.zero_grad()
        
    total_losses = None
    
    for iteration in range(iterations_to_optimize):
        
        # The loss of the current iteration
        current_loss = optimizee(optimizee_obj_function)
        
        # Since the objective function of the optimizer is equal to the sum of the optimizee's losses
        # we want to measure the loss of every iteration and add it to the total sum of losses
        if total_losses is None:
            total_losses = current_loss
        else:
            total_losses += current_loss
            
        losses_list.append(current_loss.data.cpu().numpy())
        # Here dloss/dx is computed for every parameter x that has requires_grad = True
        # These are accumulated into x.grad for every parameter x
        # This is equal to x.grad += dloss/dx
        
        # We get the optimizee's gradients but we also retain the graph because
        # we need to run backpropagation again when we optimize the optimizer
        current_loss.backward(retain_graph = True)
        
        offset = 0
        result_params = {}
        
        # These will be the new parameters. We will update all the parameters, cell and hidden states
        # by iterating through the optimizee's "all_named parameters"
        hidden_states2 = [w(Variable(torch.zeros(n_params, optimizer_network.hidden_size))) for _ in range(2)]
        cell_states2 = [w(Variable(torch.zeros(n_params, optimizer_network.hidden_size))) for _ in range(2)]
        
        for name, param in optimizee.all_named_parameters():
            current_size = int(np.prod(param.size()))
            # We want to disconnect the gradients of some variables but not all, each time.
            
            # We do this in order to disconnect the gradients of the offset:offset+current_size
            # parameters but still get the gradients of the rest.
            gradients = detach_var(param.grad.view(current_size, 1))
            
            # Call the optimizer and compute the new parameters
            updates, new_hidden, new_cell = optimizer_network(
                gradients,
                [h[offset:offset+current_size] for h in hidden_states],
                [c[offset:offset+current_size] for c in cell_states]
            )
            
            # Here we replace the old parameters with the new values
            for i in range(len(new_hidden)):
                hidden_states2[i][offset:offset+current_size] = new_hidden[i]
                cell_states2[i][offset:offset+current_size] = new_cell[i]
            
            result_params[name] = param + updates.view(*param.size()) * out_mul
            result_params[name].retain_grad()
        
        # If we have reached the number of iterations needed to update the optimizer
        # we run backprop on the optimizer network
        if iteration % iterations_to_unroll == 0:
            if should_train:
                # zero_grad() clears the gradients of all optimized tensors
                meta_optimizer.zero_grad()
                # we compute the gradient of the total losses  (i.e. the optimizer's loss function)
                # with respect to the optimizer's parameters
                total_losses.backward()
                # we finally perform the optimization step, i.e. the updates
                meta_optimizer.step()
                
            # Since we did the update on the optimizer network
            # we overwrite the total_losses
            total_losses = None
            
            # Here we detach the state variables because they are not propagated
            # to the graph (see Figure 2 of the paper for details)
            optimizee = w(optimizee_network(**{k: detach_var(v) for k, v in result_params.items()}))
            hidden_states = [detach_var(v) for v in hidden_states2]
            cell_states = [detach_var(v) for v in cell_states2]
        else:
            # Otherwise, we just create the next optimizee objective funtion
            optimizee = w(optimizee_network(**result_params))
            hidden_states = hidden_states2
            cell_states = cell_states2
            
    return losses_list

In [21]:
@cache.cache
def main_loop(optimizee_obj_function, optimizee_network, preprocessing = False,
        epochs = 20, iterations_to_optimize = 100, iterations_to_unroll = 20, 
        n_tests = 100, lr = 0.001, out_mul = 1.0):
    
    optimizer_network = w(OptimizerNetwork(preprocessing = preprocessing))
    
    # To construct an Optimizer you need to give it an iterable containing the parameters to optimize
    meta_optimizer = optim.Adam(optimizer_network.parameters(), lr = lr)
    
    # Initialize dummy variables for the best_net object and the best_loss
    best_net = None
    best_loss = 10000000000000000
    
    for _ in tqdm(range(epochs), 'epochs'):
        for _ in tqdm(range(iterations_to_optimize), 'iterations'):
            fit(optimizer_network, meta_optimizer, optimizee_obj_function, optimizee_network,
                iterations_to_optimize, iterations_to_unroll, out_mul,
                should_train = True)
            
        current_loss = (np.mean([
                np.sum(fit(optimizer_network, meta_optimizer, optimizee_obj_function, optimizee_network,
                iterations_to_optimize, iterations_to_unroll, out_mul,
                should_train = False))
                for _ in tqdm(range(n_tests), 'tests')
               ]))
        print(current_loss)
        
        if current_loss < best_loss:
            best_loss = current_loss
            best_net = copy.deepcopy(optimizer_network.state_dict())
    
    return best_loss, best_net

For our first experiment we are going to use randomly generated $W$ and $y$. The matrix $W$ will be of dimensions 10x10 and will correspond to the weights that we want to learn, and the vector $y$ will be a 10-element vector that will represent the labels. Our optimizer will try to find a 10-element vector $\theta$ that, when multiplied by $W$ will be as close as possible to y. Hence, our objective function that we want to minimize will be the squared error, i.e.:

$$
\sum^{n}_{k = 1} (w_{i}^T \cdot \theta_{i} - y_{i})^{2}
$$

where $w_{i}$ is the $i$-th column vector of the $W$ matrix.

In [22]:
class RandomQuadraticLoss:
    def __init__(self, **kwargs):
        self.W = w(Variable(torch.randn(10, 10)))
        self.y = w(Variable(torch.randn(10)))
    
    def get_loss(self, theta):
        return torch.sum((self.W.matmul(theta) - self.y) ** 2)
    

class QuadraticOptimizee(nn.Module):
    def __init__(self, theta = None):
        super().__init__()
        
        if theta is None:
            self.theta = nn.Parameter(torch.zeros(10))
        else:
            self.theta = theta
            
    def forward(self, target):
        return target.get_loss(self.theta)
    
    def all_named_parameters(self):
        return [('theta', self.theta)]
    
# Here we declare the optimizer. Here we have a parameter called preproc which is used in order to implement
# what is described as "Gradient preprocessing" in Appendix A. Generally the input to the optimizer network is
# a gradient which can get very high or very small values, especially when we're dealing with complex architectures.
# Due to this, the optimizer training is susceptible to high variance in gradient values, as neural nets prefer to
# deal with a relatively small range of values. Therefore a means of normalizing the gradients needs to be implemented.

# The preprocessing factor is the p > 0 parameter in that paper section, which controls how small gradients are disregarded. It has
# a default value of 10.0

class OptimizerNetwork(nn.Module):
    def __init__(self, preprocessing = False, hidden_size = 20, preprocessing_factor = 10.0):
        super(OptimizerNetwork, self).__init__()
        self.hidden_size = hidden_size
        
        if preprocessing:
            # Since we have the preprocessing flag enabled, we want the neural network
            # to have two arguments and not just the gradient (see the forward function)
            self.recurs = nn.LSTMCell(2, hidden_size)
        else:
            self.recurs = nn.LSTMCell(1, hidden_size)
            
        self.recurs2 = nn.LSTMCell(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)
        self.preprocessing = preprocessing
        self.preprocessing_factor = preprocessing_factor
        self.preprocessing_threshold = np.exp(-preprocessing_factor)
        
    def forward(self, inp, hidden, cell):
        if self.preprocessing:
            inp = inp.data
            inp2 = w(torch.zeros(inp.size()[0], 2))
            
            # If the absolute value is greater or equal than the preprocessing threshold
            # (see the condition in the first part of the gradient winged formula) we pass
            # the log of the absolute value of the gradient divided by the preprocessing factor
            # as the first parameter, and we pass the sign of the gradient as the second parameter.
            keep_grads = torch.abs(inp) >= self.preprocessing_threshold
            inp2[:, 0][keep_grads] = torch.log(torch.abs(inp[keep_grads]) + 1e-8) / self.preprocessing_factor
            inp2[:, 1][keep_grads] = torch.sign(inp[keep_grads])
            
            # Else we pass -1 as the first parameter and then a scaled value of the gradient.
            inp2[:, 0][~keep_grads] = -1
            inp2[:, 1][~keep_grads] = float(np.exp(self.preprocessing_factor)) * inp[~keep_grads]
            inp = w(Variable(inp2))
            
        hidden0, cell0 = self.recurs(inp, (hidden[0], cell[0]))
        hidden1, cell1 = self.recurs2(hidden0, (hidden[1], cell[1]))
        
        return self.output(hidden1), (hidden0, hidden1), (cell0, cell1)

Now, we are going to try several values for the learning rate, in order to find the most promising one for our first experiment which is going to be to minimize the random quadratic loss we declared previously. We are going to use 20 epochs in order to find the best value for our learning rate.

In [26]:
for lr in tqdm([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001], 'all'):
    print('Testing lr:', lr)
    print(main_loop(RandomQuadraticLoss, QuadraticOptimizee, lr = lr)[0])

A Jupyter Widget

Testing lr: 1.0


A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

2.48049e+07


A Jupyter Widget

A Jupyter Widget

8249.89


A Jupyter Widget

A Jupyter Widget

54952.0


A Jupyter Widget

A Jupyter Widget

23963.0


A Jupyter Widget

A Jupyter Widget

30830.6


A Jupyter Widget

A Jupyter Widget

4034.28


A Jupyter Widget

A Jupyter Widget

3838.52


A Jupyter Widget

A Jupyter Widget

3356.86


A Jupyter Widget

A Jupyter Widget

1097.95


A Jupyter Widget

A Jupyter Widget

KeyboardInterrupt: 

Now we are going to see what the best loss is for lr = 0.003 if we train for 100 epochs.

In [None]:
best_loss, best_quadratic_optimizer = main_loop(RandomQuadraticLoss, QuadraticOptimizee, lr = 0.003, n_epochs = 100)
print(best_loss)

Now we will find the best learning rate values for a selection of standard optimization algorithms (namely Adam, RMSprop, SGD and SGD with nesterov momentum).

In [25]:
@cache.cache
def fit_normal(optimizee_objective_function, optimizee_network, optimizer,
               n_tests = 100, epochs = 100, **kwargs):
    results = []
    for i in tqdm(range(n_tests), 'tests'):
        objective_function = optimizee_objective_function(training = False)
        optimizee = w(optimizee_network())
        optimizer = optimizer(optimizee.parameters(), **kwargs)
        total_loss = []
        for _ in range(epochs):
            current_loss = optimizee(objective_function)
            
            total_loss.append(current_loss.data.cpu().numpy())
            
            optimizer.zero_grad()
            current_loss.backward()
            optimizer.step()
        results.append(total_loss)
    return results

def find_best_lr_normal(optimizee_objective_function, optimizee_network, optimizer, **extra_kwargs):
    best_loss = 10000000000000000.0
    best_lr = 0.0
    
    for lr in tqdm([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001], 'Learning Rates'):
        try:
            loss = best_loss + 1.0
            loss = np.mean([np.sum(s) for s in fit_normal(optimizee_objective_function, optimizee_network, optimizer, 
                                                          lr = lr, **extra_kwargs)])
        except RunTimeError:
            pass

        if loss < best_loss:
            best_loss = loss
            best_lr = lr
            
    return best_loss, best_lr
    

In [None]:
OPTIMIZER_SELECTION = [(optim.Adam, {}), (optim.RMSprop, {}), (optim.SGD, {}),
                       (optim.SGD, {'nesterov': True, 'momentum':0.9})]
OPTIMIZER_NAMES = ['Adam', 'RMSprop', 'SGD', 'SGD + NestMom']

for opt, kwargs in NORMAL_OPTS:
    print(find_best_lr_normal(RandomQuadraticLoss, QuadraticOptimizee, opt, **kwargs))

In [None]:
QUAD_LEARNING_RATES = [0.1, 0.03, 0.01, 0.01]
fit_data = np.zeros((100, 100, len(OPTIMIZER_NAMES) + 1))

for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LEARNING_RATES)):
    np.random.seed(1234)
    fit_data[:, :, i] = np.array(fit_normal(RandomQuadraticLoss, QuadraticOptimizee, opt, lr = lr, **extra_kwargs))
    
opt = w(OptimizerNetwork())
opt.load_state_dict(best_quadratic_optimizer)
np.random.seed(1234)
fit_data[:, :, len(OPTIMIZER_NAMES)] = np.array([fit(opt, None, RandomQuadraticLoss, QuadraticOptimizee, 
                                                1, 100, 100, outmul = 1.0, should_train = False) for _ in range(100)])


Finally, we will plot the learning curves for the meta-optimized network (LSTM) vs the selection of the alternative standard algorithms.

In [None]:
ax = sns.tsplot(data = fit_data[:, :, :], condition = OPTIMIZER_NAMES + ['LSTM'], linestyle = '--', 
                color = ['r', 'b', 'g', 'k', 'y'])
ax.lines[-1].set_linestyle('-')
ax.legend()
plt.yscale('log')
plt.xlabel('steps')
plt.ylabel('loss')
plt.title("Comparison of learning curves between a selection of standard optimizers (dashed lines) and the meta-learned optimizer (straight line) in learning random quadratic functions")
plt.show()