## AdamW benchmarking

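This notebook benchmarks an implementation of AdamW and SGDW, the decoupled-weight-decay optimizers from Loshchilov & Hutter (https://arxiv.org/abs/1711.05101), against SGD with momentum and plain Adam.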

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
# This file contains all the main external libs we'll use
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
import fastai.optim.adamw as adamwlib
import fastai.optim.sgdw as sgdwlib
import matplotlib.pyplot as plt
%matplotlib inline

### Shared helpers for all experiments

In [3]:
def Get_SGD_Momentum(momentum=0.9):
    """Return an optimizer factory for ``optim.SGD`` with a fixed momentum.

    The returned callable forwards all of its positional and keyword arguments
    to ``optim.SGD`` and additionally passes ``momentum=momentum``.
    """
    def build_sgd(*args, **kwargs):
        return optim.SGD(*args, momentum=momentum, **kwargs)

    return build_sgd

def Get_Adam():
    """Return an optimizer factory that forwards every argument to ``optim.Adam``."""
    def build_adam(*args, **kwargs):
        return optim.Adam(*args, **kwargs)

    return build_adam

def Get_AdamW():
    """Return an optimizer factory for ``adamwlib.AdamW``.

    This is the weight-decayed Adam variant under benchmark; the factory
    forwards every argument to ``adamwlib.AdamW`` unchanged.
    """
    def build_adamw(*args, **kwargs):
        return adamwlib.AdamW(*args, **kwargs)

    return build_adamw

def Get_SGDW(momentum=0.9):
    """Return an optimizer factory for ``sgdwlib.SGDW`` with a fixed momentum.

    SGDW is the SGD-with-momentum variant under benchmark whose weight decay is
    applied by the SGDW implementation. The returned callable forwards all
    arguments to ``sgdwlib.SGDW`` and additionally passes ``momentum=momentum``.
    """
    return lambda *args, **kwargs: sgdwlib.SGDW(*args, momentum=momentum, **kwargs)


In [4]:
import pickle

def save_list(fname, l):
    """Pickle ``l`` into the file ``fname``, overwriting it if it exists.

    The file is binary pickle data regardless of its extension.
    """
    with open(fname, "wb") as out_stream:
        pickle.dump(l, out_stream)
        
def read_list(fname):
    """Load and return the object pickled in ``fname``.

    Only use this on files you trust: unpickling can execute arbitrary code.
    """
    with open(fname, "rb") as in_stream:
        restored = pickle.load(in_stream)
    return restored

### Shared training function

Apart from the `find_lr` and `use_wd_schedule` switches, it takes only the optimizer and that optimizer's initial learning rate. Every other setting is fixed, so the optimizers are compared under identical conditions.

In [5]:
def experiment(optimizer, lr=1e-3, find_lr=False, use_wd_schedule=False,
               path="/home/as/datasets/fastai.dogscats", arch=None, sz=224, bs=24,
               cycle_len=2, cycle_mult=2, num_cycles=2, weight_decay=0.025):
    """Train a pretrained ConvLearner with ``optimizer`` and return its recorded losses.

    Every optimizer in this notebook is run through this function, so all runs
    share the same data, architecture, augmentation, restart schedule and weight
    decay; only the optimizer (and its learning rate) varies.

    Args:
        optimizer: optimizer factory passed to ``ConvLearner.pretrained`` as
            ``opt_fn`` (e.g. the result of ``Get_Adam()``).
        lr: learning rate passed to both ``learn.fit`` calls.
        find_lr: if True, only run ``learn.lr_find()``, plot it and return None.
        use_wd_schedule: forwarded unchanged to both ``learn.fit`` calls.
            NOTE(review): confirm the installed fastai's ``fit`` accepts this
            keyword name (some fastai 0.7 versions spell it ``use_wd_sched``).
        path: dataset root given to ``ImageClassifierData.from_paths``.
        arch: model constructor; ``None`` means ``resnet152``.
        sz: image size passed to ``tfms_from_model``.
        bs: batch size.
        cycle_len, cycle_mult, num_cycles: restart schedule of the second fit.
        weight_decay: passed to ``learn.fit`` as ``wds``; 0.025 is the value
            cited from https://arxiv.org/abs/1711.05101.

    Returns:
        ``learn.sched.losses`` after the second fit, or None when ``find_lr`` is True.
    """
    # Resolved at call time so the default does not require resnet152 at definition time.
    if arch is None:
        arch = resnet152

    tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
    data = ImageClassifierData.from_paths(path, tfms=tfms, bs=bs)
    learn = ConvLearner.pretrained(arch, data, precompute=True, opt_fn=optimizer)

    if find_lr:
        learn.lr_find()
        learn.sched.plot()
        return None

    # First train on precomputed activations, then on full images with restarts.
    learn.fit(lr, 1, wds=weight_decay, use_wd_schedule=use_wd_schedule)
    print('Now with precompute as False')
    learn.precompute = False
    learn.fit(lr, num_cycles, wds=weight_decay, use_wd_schedule=use_wd_schedule,
              cycle_len=cycle_len, cycle_mult=cycle_mult)

    # NOTE(review): learn.sched is presumably rebuilt by each fit, so these losses
    # likely cover only the second fit — confirm.
    losses = learn.sched.losses
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses)
    ax.set(title='Loss recorded by learn.sched', xlabel='recorded step', ylabel='loss')
    plt.show()
    learn.sched.plot_lr()

    return losses

### SGD with momentum and warm restarts (SGDR)

In [None]:
sgdm = Get_SGD_Momentum()
# LR-finder run only: experiment() returns None here, so don't bind it to loss_sgdm.
experiment(sgdm, find_lr=True)

In [None]:
# Train with SGD + momentum (0.9) at lr=1e-3 and persist the loss curve.
sgdm = Get_SGD_Momentum(momentum=0.9)
loss_sgdm = experiment(optimizer=sgdm, lr=1e-3)
save_list(fname='sgdm_loss.txt', l=loss_sgdm)

### Vanilla Adam with fixed weight decay and restarts

In [None]:
adam = Get_Adam()
# LR-finder run only: experiment() returns None here, so don't bind it to loss_adam.
experiment(adam, find_lr=True)

**Train**

In [None]:
# Train with plain Adam at lr=1e-4 and persist the loss curve.
adam = Get_Adam()
loss_adam = experiment(optimizer=adam, lr=1e-4)
save_list(fname='adam_loss.txt', l=loss_adam)

### AdamW with dynamic weight decay and restarts

In [None]:
adamw = Get_AdamW()
# LR-finder run only: experiment() returns None here, so don't bind it to loss_adamw.
experiment(adamw, find_lr=True, use_wd_schedule=True)

In [None]:
adamw = Get_AdamW()
loss_adamw = experiment(adamw, lr=1e-4, use_wd_schedule=True)
# File name follows the '<optimizer>_loss.txt' pattern of the other cells
# (the content is pickle data written by save_list, not text).
save_list('adamw_loss.txt', loss_adamw)

### SGDW with dynamic weight decay and restarts

In [None]:
sgdw = Get_SGDW(0.9)
# LR-finder run only: experiment() returns None here, so don't bind it to loss_sgdw.
experiment(sgdw, find_lr=True, use_wd_schedule=True)

In [None]:
sgdw = Get_SGDW(0.9)
loss_sgdw = experiment(sgdw, lr=1e-3, use_wd_schedule=True)
# File name follows the '<optimizer>_loss.txt' pattern of the other cells
# (the content is pickle data written by save_list, not text).
save_list('sgdw_loss.txt', loss_sgdw)

### Compare the loss curves of all optimizers

In [None]:
# Overlay the loss curves returned by experiment() for each optimizer.
fig, ax = plt.subplots(figsize=(15, 10))
for losses, color, label in [
    (loss_adam, 'red', 'Adam'),
    (loss_sgdm, 'blue', 'SGDM'),
    (loss_adamw, 'green', 'AdamW'),
    (loss_sgdw, 'black', 'SGDW'),
]:
    ax.plot(losses, color=color, label=label)
ax.set(title='Loss curves per optimizer', xlabel='recorded step', ylabel='loss')
ax.legend()
plt.show()

### Scratch

In [None]:
"""
adamw = Get_AdamW()

sz = 32
bs = 64
PATH = "/home/as/datasets/fastai.dogscats"
tfms = tfms_from_model(resnet34, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
arch=resnet34
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)

learn_adam = ConvLearner.pretrained(arch, data, precompute=True, opt_fn=adamw)
lrf=learn_adam.lr_find()
learn_adam.sched.plot()

lr = 1e-3
wds=0.025
cycle_len=1
cycle_mult=1
num_cycles = 1

data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)
learn_adam = ConvLearner.pretrained(arch, data, precompute=True, opt_fn=adamw
                                   )
learn_adam.fit(lr, 1, wds=wds)
print('Now with precompute as False')
learn_adam.precompute=False
learn_adam.fit(lr, num_cycles, wds=wds, cycle_len=cycle_len, cycle_mult=cycle_mult)
"""