# In[1]:
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""PyTorch implementation of the lookahead wrapper.

    Lookahead Optimizer: https://arxiv.org/abs/1907.08610

    The wrapper keeps a cached ("slow") copy of every parameter of the inner
    optimizer. Every ``la_steps`` calls to :meth:`step` the live ("fast")
    parameters are replaced by the interpolation
    ``la_alpha * fast + (1 - la_alpha) * slow`` and the cache is refreshed
    with the result.

    ``Optimizer.__init__`` is not called here: ``param_groups`` is a
    read-only property that delegates to the inner optimizer, and ``state``
    is a plain ``defaultdict`` owned by this wrapper.
    """

    # Accepted (lower-cased) values of ``pullback_momentum``.
    _PULLBACK_MODES = ("reset", "pullback", "none")

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="none"):
        """optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        pullback_momentum (str): change to inner optimizer momentum on interpolation update
            ("none", "reset" or "pullback"; case-insensitive)

        Raises:
            ValueError: if ``pullback_momentum`` is not one of the accepted modes.
        """
        self.optimizer = optimizer
        self._la_step = 0  # counter for inner optimizer
        self.la_alpha = la_alpha
        self._total_la_steps = la_steps
        pullback_momentum = pullback_momentum.lower()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if pullback_momentum not in self._PULLBACK_MODES:
            raise ValueError(
                "pullback_momentum must be one of %s, got %r"
                % (list(self._PULLBACK_MODES), pullback_momentum)
            )
        self.pullback_momentum = pullback_momentum

        self.state = defaultdict(dict)

        # Cache the current optimizer parameters
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = p.data.clone()
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        """Return the attributes that are pickled for this wrapper."""
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.la_alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self, *args, **kwargs):
        """Clear the gradients via the inner optimizer.

        Any arguments (e.g. ``set_to_none``) are forwarded unchanged.
        """
        self.optimizer.zero_grad(*args, **kwargs)

    def get_la_step(self):
        """Return the number of inner steps taken since the last interpolation."""
        return self._la_step

    def state_dict(self):
        """Return the inner optimizer's state dict.

        Only the inner optimizer is serialized; the cached parameters and
        the lookahead step counter held by this wrapper are not included.
        """
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the inner optimizer (wrapper state is untouched)."""
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)

        Saves the live parameters under ``'backup_params'`` and overwrites
        them in place with the cached parameters.
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = p.data.clone()
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        """Restore the parameters saved by ``_backup_and_load_cache`` and drop the backup."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state.pop('backup_params'))

    @property
    def param_groups(self):
        """Parameter groups of the inner optimizer."""
        return self.optimizer.param_groups

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            Whatever the inner optimizer's ``step`` returns.
        """
        loss = self.optimizer.step(closure)
        self._la_step += 1

        if self._la_step < self._total_la_steps:
            return loss

        self._la_step = 0
        alpha = self.la_alpha
        # Lookahead and cache the current optimizer parameters
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                slow = param_state['cached_params']
                # crucial line: fast <- alpha * fast + (1 - alpha) * slow
                p.data.mul_(alpha).add_(slow, alpha=1.0 - alpha)
                slow.copy_(p.data)
                if self.pullback_momentum == "pullback":
                    # ``.get`` avoids creating state entries in the inner
                    # optimizer; parameters without a momentum buffer
                    # (e.g. ones that never received a gradient) are skipped.
                    buf = self.optimizer.state.get(p, {}).get("momentum_buffer")
                    if buf is None:
                        continue
                    buf.mul_(alpha).add_(param_state["cached_mom"], alpha=1.0 - alpha)
                    # Copy instead of rebinding: the inner optimizer may update
                    # its buffer in place, and an aliased cache would make the
                    # next interpolation mix the buffer with itself.
                    param_state["cached_mom"].copy_(buf)
                elif self.pullback_momentum == "reset":
                    self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss


# In[60]:
import numpy as np
#for alpha in np.arange(0.01,1.01,0.01): # 100 evenly spaced values of alpha in range (0,1]
#    pass
#    #Lookahead(la_steps=5, la_alpha=alpha)

# Problem size and the matrix A used by the quadratic below:
# the dim x dim identity, kept as an integer array.
dim = 2
A = np.identity(dim, dtype=int)

def noisyquadloss(x,A):
    """Return one noisy sample of the quadratic loss 1/2 (x-c)^T A (x-c).

    x: point at which the loss is evaluated (sequence or array of length len(A)).
    A: square, invertible matrix of the quadratic.

    A fresh offset ``c = z . inv(A)`` with ``z`` standard normal is drawn on
    every call, so the result is random. The noise dimension follows
    ``len(A)`` instead of being fixed to 2.

    NOTE(review): ``z . Sigma`` has covariance ``Sigma^2``, not ``Sigma``;
    if ``c ~ N(0, Sigma)`` is intended a square-root factor is needed --
    confirm. (Both coincide for the identity matrix.)
    """
    Sigma = np.linalg.inv(A)
    c = np.dot(np.random.randn(len(A)), Sigma)
    offset = np.asarray(x) - c
    return 1/2 * np.dot(np.dot(offset, A), offset)

def best_var_sgd(lr,A):
    """Return lr^2 A^2 Sigma^2 / (I - (I - lr A)^2) with Sigma = inv(A).

    lr: learning rate.
    A: square, invertible matrix.

    All powers, products and the division are elementwise, so the result
    matches the matrix expression only when A is diagonal. Entries where
    both numerator and denominator are zero (the off-diagonal of a diagonal
    A) are returned as 0 instead of NaN; a zero denominator with a non-zero
    numerator still yields inf.
    """
    I = np.identity(len(A))
    Sigma = np.linalg.inv(A)
    numerator = lr**2 * A**2 * Sigma**2
    denominator = I - (I - lr*A)**2
    # Silence the 0/0 and x/0 warnings; the 0/0 entries are patched below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = numerator / denominator
    return np.where((numerator == 0) & (denominator == 0), 0.0, ratio)

def best_var_la(alpha,lr,A,k=5):
    """Return the Lookahead counterpart of ``best_var_sgd``.

    alpha: interpolation factor.
    lr: learning rate.
    A: square matrix.
    k: number of inner steps between interpolations.

    Computes ``num / (num + 2 alpha (1 - alpha) (I - (I - lr A)^k))`` with
    ``num = alpha^2 (I - (I - lr A)^(2k))`` and multiplies it by
    ``best_var_sgd(lr, A)``. Every operation is elementwise, so this
    matches the matrix expression only when A is diagonal. Entries where
    numerator and denominator are both zero (the off-diagonal of a diagonal
    A) contribute a factor of 0 instead of NaN.
    """
    I = np.identity(len(A))
    numerator = alpha**2 * (I - (I - lr*A)**(2*k))
    denominator = numerator + 2*alpha*(1-alpha)*(I-(I-lr*A)**k)
    # Silence the 0/0 and x/0 warnings; the 0/0 entries are patched below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = numerator / denominator
    ratio = np.where((numerator == 0) & (denominator == 0), 0.0, ratio)
    return ratio * best_var_sgd(lr,A)

def SGDexpweight(alpha,lr,A,x,t=1000): # Appendix A
    """Return (I - lr*A) raised elementwise to the power t.

    ``alpha`` and ``x`` are accepted but not used. The power is elementwise,
    not a matrix power.
    """
    decay = np.identity(len(A)) - lr*A
    return np.power(decay, t)

def SGDvarweight(alpha,lr,A,x,t=1000): # Appendix A
    """Return (I - lr*A)^(2t) * x + t * lr^2 * A^2 * inv(A), all elementwise.

    ``alpha`` is accepted but not used. ``x`` is broadcast against the
    decayed matrix (presumably the initial variance -- confirm with caller).
    """
    decay = np.identity(len(A)) - lr*A
    transient = decay**(2*t) * x
    noise = t * lr**2 * A**2 * np.linalg.inv(A)
    return transient + noise

def lookaheadupdateweightexp(alpha, lr, A, k=5): # Lemma 1, eqn 10
    """Return 1 - alpha + alpha * (I - lr*A)^k, computed elementwise.

    The scalar ``1 - alpha`` is broadcast to every entry, so off-diagonal
    entries of the result equal ``1 - alpha`` when A is diagonal.
    """
    inner_decay = (np.identity(len(A)) - lr*A)**k
    return (1 - alpha) + alpha*inner_decay

def lookaheadupdateweightvar(alpha, lr, A, k=5): # Lemma 1 eqn 11
    """Return the pair (p, q) of elementwise terms of the variance update.

    p: (1 - alpha + alpha * (I - lr*A)^k)^2
    q: sum over i in [0, k) of (I - lr*A)^(2i) * lr^2 * A^2 * inv(A)

    ``alpha`` is not applied to q here. With k == 0 the sum is empty and q
    is the integer 0.
    """
    decay = np.identity(len(A)) - lr*A
    p = ((1 - alpha) + alpha*decay**k)**2
    Sigma = np.linalg.inv(A)
    q = sum(decay**(2*i) * lr**2 * A**2 * Sigma for i in range(k))
    return p, q
    

# Evaluate one noisy sample of the quadratic loss at a fixed point.
theta = [4, 2]
noisyquadloss(x=theta, A=A)




# In[ ]:
lr = 0.05

# Accumulate, per coordinate i,
#   1/2 * A[i][i] * (exptheta^2 + vartheta + Sigma[i][i])
# into exploss. The stray bare name `exp` that preceded this code was
# removed: it is undefined and raised NameError.
exploss = 0 # Eqn 5
Sigma = np.linalg.inv(A)
for i in range(dim):
    # Both moments are fixed at 0 here -- presumably placeholders to be
    # filled in later (TODO confirm).
    exptheta = 0
    vartheta = 0
    exploss += 1/2 * A[i][i] * (exptheta**2+vartheta+Sigma[i][i])