<a href="https://colab.research.google.com/github/darshank528/Project-STORM/blob/master/Storm_Optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Creating STORM optimizer class as per algorithm in the paper https://arxiv.org/abs/1905.10018

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer

class StormOptimizer(Optimizer):
    """STORM optimizer, following the algorithm in https://arxiv.org/abs/1905.10018.

    Per-parameter bookkeeping lives in three dictionaries that are stored in
    every parameter group and keyed by the parameter tensor itself:

    * ``momentum``    -- momentum term of each parameter ('d' in the paper).
    * ``gradient``    -- one-element list holding a snapshot of the most
      recent gradient of each parameter ('∇f(x,ε)' in the paper). Only the
      previous gradient is ever read, so older entries are dropped instead of
      being accumulated forever.
    * ``sqrgradnorm`` -- running sum of squared gradient norms of each
      parameter ('∑G^2' in the paper).

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        lr: numerator of the adaptive step size ('k' in the paper).
        c: factor used in ``a = c * η^2``; the paper sweeps it over a
            logarithmically spaced grid.
        momentum: optional pre-populated momentum dictionary.
        gradient: optional pre-populated gradient dictionary
            (parameter -> list whose last entry is the previous gradient).
        sqrgradnorm: optional pre-populated squared-norm dictionary.

    Any of the three dictionaries left as ``None`` is replaced by a fresh
    empty dict owned by this instance, so separate optimizers never share
    state. The offset 'w' of the paper is fixed to 0.1.

    NOTE(review): the paper's update needs the gradient at the *previous*
    iterate evaluated on the *current* sample. This implementation substitutes
    the gradient stored from the previous step. When the state starts empty,
    the stored momentum and stored gradient are equal, so the correction term
    ``d - ∇f(x,ε')`` is zero and ``d`` equals the current gradient. A faithful
    implementation needs a second gradient evaluation per step.
    """

    def __init__(self, params, lr=0.1, c=100, momentum=None, gradient=None, sqrgradnorm=None):
        # Build the dicts here rather than in the signature: a ``{}`` default
        # is created once and would be shared by every optimizer instance.
        defaults = dict(
            lr=lr,
            c=c,
            momentum={} if momentum is None else momentum,
            sqrgradnorm={} if sqrgradnorm is None else sqrgradnorm,
            gradient={} if gradient is None else gradient,
        )
        super(StormOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore the optimizer from ``state`` by delegating to the base class."""
        super(StormOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            base_lr = group['lr']
            factor = group['c']
            momentum = group['momentum']
            gradient = group['gradient']
            sqrgradnorm = group['sqrgradnorm']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Snapshot the gradient ('∇f(x,ε)' in the paper). p.grad may
                # later be zeroed or accumulated into in place, which would
                # otherwise silently rewrite the stored optimizer state.
                dp = p.grad.data.clone()

                # Accumulate ∑G^2.
                sq_norm = torch.pow(torch.norm(dp), 2)
                if p in sqrgradnorm:
                    sqrgradnorm[p] = sqrgradnorm[p] + sq_norm
                else:
                    sqrgradnorm[p] = sq_norm

                # Adaptive learning rate ('η' in the paper): lr / (w + ∑G^2)^(1/3).
                scaling = torch.pow(0.1 + sqrgradnorm[p], 1.0 / 3.0)
                learn_rate = base_lr / float(scaling)

                # a = c * η^2, capped at 1 so that (1 - a) never goes negative.
                a = min(factor * learn_rate ** 2.0, 1.0)

                # Momentum term: d' = ∇f(x',ε') + (1 - a') * (d - ∇f(x,ε')),
                # with the previous step's stored gradient standing in for
                # ∇f(x,ε') (see the class docstring).
                if p in momentum:
                    previous = gradient[p][-1]
                    momentum[p] = dp + (1 - a) * (momentum[p] - previous)
                else:
                    momentum[p] = dp.clone()

                # Keep only the gradient needed by the next step.
                gradient[p] = [dp]

                # Parameter update: x' = x - η * d.
                p.data.sub_(learn_rate * momentum[p])

        return loss
