- A first draft at porting @lessw2020's RangerQH optimizer from fastai v1 to fastai v2
- MNIST code taken from the excellent torch.nn tutorial, [What is torach.NN Reall?](https://pytorch.org/tutorials/beginner/nn_tutorial.html)
- @lessw2020's RangerQH code here: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/rangerqh.py

In [1]:
# !pip install torch torchvision feather-format kornia pyarrow --upgrade   > /dev/null
# !pip install git+https://github.com/fastai/fastai_dev                    > /dev/null

In [2]:
%reload_ext autoreload
%autoreload 2

from fastai2.basics           import *
from fastai2.vision.all       import *
from fastai2.medical.imaging  import *
from fastai2.callback.tracker import *
from fastai2.callback.all     import *

from pathlib import Path
import requests
import pickle
import gzip
from torch import nn, optim
import pdb

# NN Simple

In [3]:
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

In [4]:
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

In [5]:
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)


In [6]:
class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)
    
# class Mnist_CNN(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
#         self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
#         self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

#     def forward(self, xb):
#         xb = xb.view(-1, 1, 28, 28)
#         xb = F.relu(self.conv1(xb))
#         xb = F.relu(self.conv2(xb))
#         xb = F.relu(self.conv3(xb))
#         xb = F.avg_pool2d(xb, 4)
#         return xb.view(-1, xb.size(1))

### RangerQH v2 attempt

In [7]:
#from fastai2.basics import Optimizer as Optimizerv2

def RangerQH_v2(params, lr=1e-3, mom=0.9, sqr_mom=0.99, 
                 eps=1e-8, wd=0., k=6, alpha=.5, decouple_wd=False, betas=(0.9, 0.999)): # betas, nus
    from functools  import partial
    steppers = [weight_decay] if decouple_wd else [l2_reg]
    steppers.append(partial(rangerqh_step, nus=(.7, 1.0)))
    stats = [average_grad, step_stat, partial(betas_stat,betas=betas), d_p_grad, exp_avg_stat, slow_buffer_stat]
    return Optimizer(params, steppers, stats=stats, lr=lr, k=k, alpha=alpha,  # betas=betas, nus=nus
                     mom=mom, sqr_mom=sqr_mom, 
                     #grad_avg=grad_avg, #sqr_avg=sqr_avg, 
                     eps=eps, wd=wd)

In [8]:
def betas_stat(state, p, betas, **kwargs):
    beta1, beta2 = betas[0], betas[1]
    if 'beta1_weight' not in state: state['beta1_weight'] = 0.0
    state['beta1_weight'] = 1.0 + (beta1 * state['beta1_weight'])
    
    if 'beta2_weight' not in state: state['beta2_weight'] = 0.0
    state['beta2_weight'] = 1.0 + (beta2 * state['beta2_weight'])  
    return state

def exp_avg_stat(state, p, **kwargs):
    # EXP AVG THINGS
    beta1_adj = 1.0 - (1.0 / state['beta1_weight'])
    beta2_adj = 1.0 - (1.0 / state['beta2_weight'])

    if 'exp_avg' not in state: state['exp_avg'] = torch.zeros_like(p.data)
    if 'exp_avg_sq' not in state: state['exp_avg_sq'] = torch.zeros_like(p.data)
    state['exp_avg'].mul_(beta1_adj).add_(1.0 - beta1_adj, state['d_p'])
    state['exp_avg_sq'].mul_(beta2_adj).add_(1.0 - beta2_adj, state['d_p_sq'])
    return state

def d_p_grad(state, p,  **kwargs):
    state['d_p'] = p.grad.data
    state['d_p_sq'] = state['d_p'].mul(state['d_p'])
    return state

def slow_buffer_stat(state, p, **kwargs):
    if 'slow_buffer' not in state: 
        state['slow_buffer'] =  torch.empty_like(p.data)
        state['slow_buffer'].copy_(p.data)
    return state

In [9]:
def rangerqh_step(p, lr, mom, sqr_mom, step, 
                  beta1_weight, beta2_weight, exp_avg, exp_avg_sq, slow_buffer, grad_avg, d_p, d_p_sq, nus,
                  eps, wd, k, alpha, **kwargs): # betas, nus, grad_avg, sqr_avg, 
        """
            Performs a single optimization step.
        """
        #nus=(.7, 1.0)
        nu_1 = nus[0]
        nu_2 = nus[1]
        
        if d_p.is_sparse:
            raise RuntimeError("QHAdam does not support sparse gradients")
            
        avg_grad = exp_avg.mul(nu_1)
        if nu_1 != 1.0:
            avg_grad.add_(1.0 - nu_1, d_p)

        avg_grad_rms = exp_avg_sq.mul(nu_2)
        if nu_2 != 1.0:
            avg_grad_rms.add_(1.0 - nu_2, d_p_sq)
            
        avg_grad_rms.sqrt_()
        if eps != 0.0:
            avg_grad_rms.add_(eps)
        
        p.data.addcdiv_(-lr, avg_grad, avg_grad_rms)
        
        # LOOKAHEAD STEPPER
        #integrated look ahead...
        #if param_state['step'] % self.k ==0: #group['k'] == 0:
        if step % k ==0: 
            #slow_p = param_state['slow_buffer'] #get access to slow param tensor
            slow_p = slow_buffer #get access to slow param tensor
            # CACLC
            slow_p.add_(alpha, p.data - slow_p)  #(fast weights - slow weights) * alpha
            # RETURN P
            p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor
                
        #return loss
        return p

### Train

In [10]:
def nll(input, target):
    return -input[range(target.shape[0]), target].mean()

def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()

def get_sgd_model():
    model = Mnist_Logistic()
    return model, optim.SGD(model.parameters(), lr=lr)
    
def get_ramgerqhv2_model():
    model = Mnist_Logistic()
    return model, RangerQH_v2(model.parameters(), lr=lr)

In [11]:
loss_func = nll
lr = 0.5   # learning rate
epochs = 50  # how many epochs to train for
bs = 128

opt_nms = ['SGD', 'RangerQH_v2']

# Initialise models
sgd_model, sgd_opt = get_sgd_model()
rangerqh_model, rangerqh_opt = get_ramgerqhv2_model()
    
for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = sgd_model(xb)
        loss = loss_func(pred, yb)
        loss.backward()
        sgd_opt.step()
        sgd_opt.zero_grad()
    
print(f'{opt_nms[0]} Accuracy : {accuracy(sgd_model(x_valid), y_valid)}')

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = rangerqh_model(xb)
        loss = loss_func(pred, yb)
        loss.backward()
        rangerqh_opt.step()
        rangerqh_opt.zero_grad()
    
print(f'{opt_nms[1]} Accuracy : {accuracy(rangerqh_model(x_valid), y_valid)}')

SGD Accuracy : 0.6990000009536743
RangerQH_v2 Accuracy : 0.7013000249862671
