In [0]:
from fastai.vision import *

In [0]:
# Local root of the Imagewoof dataset (fastai `untar_data` downloads/extracts it as needed);
# split below into train/'val' folders by `split_by_folder(valid='val')`.
path = untar_data(URLs.IMAGEWOOF)

In [0]:
# Imagewoof DataBunch:
#   - train/validation split by folder ('val' is the validation folder)
#   - labels taken from the class folder names
#   - transforms: random horizontal flip for train, none for valid; 128px images
#   - batch size 64 with 2 loader workers, presized to 128 with scale (0.35, 1)
#   - normalised with ImageNet statistics
data = (
    ImageList.from_folder(path)
    .split_by_folder(valid='val')
    .label_from_folder()
    .transform(([flip_lr(p=0.5)], []), size=128)
    .databunch(bs=64, num_workers=2)
    .presize(128, scale=(0.35, 1))
    .normalize(imagenet_stats)
)

In [0]:
import torch, math
from torch.optim.optimizer import Optimizer

# RAdam + LARS
class Ralamb(Optimizer):
    """RAdam with a LARS-style per-parameter trust ratio ("Ralamb").

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate.
        betas: (beta1, beta2) coefficients of the running averages of the gradient
            (m_t) and of its square (v_t).
        eps: term added to sqrt(v_t) for numerical stability.
        weight_decay: if non-zero, weights are decayed in place by
            ``weight_decay * lr * p`` before the update.

    Raises:
        RuntimeError: from ``step`` if a gradient is sparse.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Ring buffer (indexed by step % 10) caching the step-dependent RAdam
        # rectification terms. Each entry is [step, (beta1, beta2), N_sma, step_size].
        # step_size is stored WITHOUT the learning rate, so param groups with
        # different lr (or betas) never reuse another group's value.
        self.buffer = [[None, None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def _rectification(self, step, beta1, beta2):
        """Return ``(N_sma, step_size)`` for ``step``; ``step_size`` excludes lr."""
        buffered = self.buffer[int(step % 10)]
        if buffered[0] == step and buffered[1] == (beta1, beta2):
            return buffered[2], buffered[3]
        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_size = math.sqrt(
                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma
                * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** step)
        else:
            step_size = 1.0 / (1 - beta1 ** step)
        buffered[:] = [step, (beta1, beta2), N_sma, step_size]
        return N_sma, step_size

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable re-evaluating the model; its result is returned.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                # For fp32 params `.float()` returns p.data itself, so the in-place ops
                # below write straight into the parameter; for lower precision it is a
                # copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                N_sma, step_size = self._rectification(state['step'], beta1, beta2)
                # Apply this group's own learning rate.
                radam_step = group['lr'] * step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                # NOTE(review): radam_norm is the norm of the (decayed) weights, not of the
                # RAdam update, so trust_ratio is ~1 unless ||w|| > 10 (the clamp above).
                # LAMB/LARS normally divide by the update norm; kept as-is to preserve
                # behaviour — confirm before relying on the layer-wise scaling.
                radam_norm = p_data_fp32.pow(2).sum().sqrt()
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # Scalar multiplier for the update (alpha/value must be a Python number).
                scale = -radam_step * float(trust_ratio)
                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=scale)
                else:
                    p_data_fp32.add_(exp_avg, alpha=scale)

                p.data.copy_(p_data_fp32)

        return loss

In [0]:
def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    """Build a Ralamb optimizer over `params` and wrap it in Lookahead.

    Args:
        params: parameters / param groups handed to Ralamb.
        alpha: Lookahead slow-weight interpolation factor.
        k: Lookahead synchronisation period (inner steps).
        *args, **kwargs: forwarded unchanged to Ralamb (e.g. lr, betas, eps).

    Returns:
        The Lookahead wrapper around the new Ralamb instance.
    """
    inner = Ralamb(params, *args, **kwargs)
    return Lookahead(inner, alpha, k)

In [0]:
# Optimizer factory: Over9000 (Ralamb + Lookahead) with betas=(0.9, 0.99), eps=1e-6.
# NOTE(review): not passed to any Learner in the training cells below.
opt_func = partial(Over9000, eps=1e-6, betas=(0.9, 0.99))

In [0]:
from fastai.script import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *

In [0]:
# NOTE(review): not referenced by the training cells below (they use max_lr=3e-3).
lr = 0.01

In [6]:
import torch.nn as nn
import torch,math,sys
import torch.utils.model_zoo as model_zoo
from functools import partial
#from ...torch_core import Module
from fastai.torch_core import Module

import torch.nn.functional as F  #(uncomment if needed,but you likely already have it)


class Mish(nn.Module):
    """Mish activation: mish(x) = x * tanh(softplus(x)).

    Prints a notice each time an instance is constructed.
    """

    def __init__(self):
        super().__init__()
        print("Mish activation loaded...")

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate

__all__ = ['MXResNet', 'mxresnet18', 'mxresnet34', 'mxresnet50', 'mxresnet101', 'mxresnet152']

# Single shared activation module, appended by conv_layer and applied in ResBlock.forward.
# Created once: it was previously instantiated twice, printing the load message twice.
# Alternatives noted by the author: nn.ReLU(inplace=True), or ELU+init (a=0.54; gain=1.55).
act_fn = Mish()

class Flatten(Module):
    """Collapse every dimension after the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

def init_cnn(m):
    """Recursively initialise module `m` and all its descendants.

    Any module with a non-None `bias` gets it zeroed; Conv2d and Linear weights get
    Kaiming-normal initialisation.
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)

def conv(ni, nf, ks=3, stride=1, bias=False):
    """Square `ks` x `ks` Conv2d from `ni` to `nf` channels.

    Padding is `ks // 2`, which keeps the spatial size for odd `ks` at stride 1.
    Bias is off by default (the conv is followed by BatchNorm in conv_layer).
    """
    pad = ks // 2
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=pad, bias=bias)

def noop(x):
    """Identity function, used as a placeholder shortcut/pool in ResBlock."""
    return x

def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    """conv -> BatchNorm2d [-> act_fn], returned as an nn.Sequential.

    Args:
        ni, nf: input / output channels.
        ks, stride: convolution kernel size and stride.
        zero_bn: initialise the BN scale to 0 instead of 1 (used for the last conv of
            each residual branch).
        act: append the shared module-level `act_fn` activation.
    """
    norm = nn.BatchNorm2d(nf)
    nn.init.constant_(norm.weight, 0. if zero_bn else 1.)
    modules = [conv(ni, nf, ks, stride=stride), norm]
    if act:
        modules += [act_fn]
    return nn.Sequential(*modules)

class ResBlock(Module):
    """Residual block computing act_fn(convs(x) + idconv(pool(x))).

    `ni` and `nh` are given before expansion: the block maps ni*expansion channels to
    nh*expansion channels. expansion == 1 builds two 3x3 convs; any other value builds
    a 1x1 -> 3x3 -> 1x1 bottleneck. The shortcut average-pools when stride != 1 and
    uses a 1x1 conv_layer when input and output channel counts differ.
    """

    def __init__(self, expansion, ni, nh, stride=1):
        nf, ni = nh * expansion, ni * expansion
        if expansion == 1:
            layers = [conv_layer(ni, nh, 3, stride=stride),
                      conv_layer(nh, nf, 3, zero_bn=True, act=False)]
        else:
            layers = [conv_layer(ni, nh, 1),
                      conv_layer(nh, nh, 3, stride=stride),
                      conv_layer(nh, nf, 1, zero_bn=True, act=False)]
        self.convs = nn.Sequential(*layers)
        # TODO: check whether act=True works better
        self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
        self.pool = noop if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        residual = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return act_fn(residual + shortcut)

def filt_sz(recep):
    """Largest power of two not exceeding 0.75 * `recep`, capped at 64."""
    exponent = math.floor(math.log2(recep * 0.75))
    return min(64, 2 ** exponent)

class MXResNet(nn.Sequential):
    """ResNet variant: 3-conv stem, max-pool, residual stages, avg-pool, flatten, linear.

    Args:
        expansion: 1 selects basic ResBlocks, any other value bottleneck ResBlocks
            (the registered variants use 1 and 4).
        layers: number of ResBlocks in each stage (at most 4 stages).
        c_in: input channels.
        c_out: size of the final Linear layer (number of classes).
    """

    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        stem_widths = [c_in, 32, 64, 64]  # modified per Grankin
        # Only the first stem conv downsamples.
        stem = [conv_layer(stem_widths[j], stem_widths[j + 1], stride=1 if j else 2)
                for j in range(3)]

        widths = [64 // expansion, 64, 128, 256, 512]
        stages = []
        for i, n_blocks in enumerate(layers):
            # The first stage keeps resolution; later stages halve it.
            stage_stride = 2 if i else 1
            stages.append(self._make_layer(expansion, widths[i], widths[i + 1], n_blocks, stage_stride))

        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *stages,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(widths[-1] * expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        """Stack `blocks` ResBlocks; only the first one maps ni -> nf and applies `stride`."""
        return nn.Sequential(
            *(ResBlock(expansion, nf if i else ni, nf, 1 if i else stride)
              for i in range(blocks)))

def mxresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    """Construct an MXResNet(expansion, n_layers, **kwargs).

    `name` is currently unused. `pretrained=True` only prints a notice; no weights
    are loaded.
    """
    net = MXResNet(expansion, n_layers, **kwargs)
    if not pretrained:
        return net
    print("No pretrained yet for MXResNet")
    return net

me = sys.modules[__name__]
# Register mxresnet18 ... mxresnet152 as module-level factory functions.
# Each row: (depth, block expansion, ResBlocks per stage).
for n, e, l in (
    (18,  1, [2, 2, 2,  2]),
    (34,  1, [3, 4, 6,  3]),
    (50,  4, [3, 4, 6,  3]),
    (101, 4, [3, 4, 23, 3]),
    (152, 4, [3, 8, 36, 3]),
):
    name = f'mxresnet{n}'
    setattr(me, name, partial(mxresnet, expansion=e, n_layers=l, name=name))

Mish activation loaded...
Mish activation loaded...


In [0]:
  
import torch.nn as nn
import torch.nn.functional as F  #(uncomment if needed,but you likely already have it)

#Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
#https://arxiv.org/abs/1908.08681v1
#implemented for PyTorch / FastAI by lessw2020 
#github: https://github.com/lessw2020/mish

class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)).

    NOTE(review): this redefines the `Mish` class from an earlier cell (that version
    also prints on construction). `act_fn` and anything built before this cell keep
    using the earlier class.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mul(F.softplus(x).tanh())

In [0]:
import itertools as it
from torch.optim import Optimizer, Adam

class Lookahead(Optimizer):
    """Lookahead wrapper: every `k` inner steps, move slow weights toward the fast ones.

    Args:
        base_optimizer: inner optimizer performing the fast updates; its
            `param_groups` list is shared with this wrapper.
        alpha: slow-weight interpolation factor, must be in [0, 1].
        k: number of inner steps between slow-weight synchronisations, must be >= 1.

    Raises:
        ValueError: if `alpha` or `k` is out of range.
    """

    def __init__(self, base_optimizer,alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = base_optimizer
        self.param_groups = self.optimizer.param_groups
        # Optimizer.__init__ is not called (it would rebuild param_groups), so expose the
        # base optimizer's defaults/state for code that reads these attributes.
        self.defaults = self.optimizer.defaults
        self.state = self.optimizer.state
        self.alpha = alpha
        self.k = k
        for group in self.param_groups:
            group["step_counter"] = 0
        # Slow weights: detached copies of the parameters at construction time.
        self.slow_weights = [[p.clone().detach() for p in group['params']]
                             for group in self.param_groups]

        for w in it.chain(*self.slow_weights):
            w.requires_grad = False

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the base optimizer (same param groups)."""
        return self.optimizer.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        """Run one base-optimizer step, then sync slow weights every `k` steps.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        # The closure was already evaluated above; step the base optimizer without it
        # and keep the closure's loss (previously overwritten by the base return value).
        self.optimizer.step()
        for group, slow_weights in zip(self.param_groups, self.slow_weights):
            group['step_counter'] += 1
            if group['step_counter'] % self.k != 0:
                continue
            for p, q in zip(group['params'], slow_weights):
                if p.grad is None:
                    continue
                # slow <- slow + alpha * (fast - slow); fast <- slow
                q.data.add_(p.data - q.data, alpha=self.alpha)
                p.data.copy_(q.data)
        return loss

In [11]:
res = []        # validation accuracy of each repeated run
num_epoch = 5   # NOTE(review): unused — every fit below hard-codes 20 epochs
# Run 1: fresh MXResNet-50 with 10 outputs, label-smoothed CE, one-cycle at max_lr=3e-3.
# NOTE(review): `opt_func` / `lr` defined above are not passed, so the Learner's default
# optimizer is used — confirm this is intended. The cells below repeat this verbatim;
# a helper + loop would remove the duplication.
learn = Learner(data, mxresnet50(c_out=10),
                loss_func=LabelSmoothingCrossEntropy(),
                metrics=[accuracy, top_k_accuracy],
                wd=1e-2, true_wd=True, bn_wd=False)
learn.fit_one_cycle(20, max_lr=3e-3)


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.084373,2.096703,0.278,0.778,02:07
1,1.879687,2.158034,0.342,0.818,02:07
2,1.739452,1.858424,0.42,0.864,02:07
3,1.653582,1.740463,0.494,0.896,02:07
4,1.534283,1.861979,0.452,0.884,02:07
5,1.416016,1.489798,0.59,0.946,02:07
6,1.323618,1.396875,0.642,0.936,02:07
7,1.26439,1.396464,0.636,0.94,02:07
8,1.194979,1.218947,0.698,0.948,02:07
9,1.102362,1.139388,0.742,0.972,02:09


In [0]:

# Score the just-trained learner on the validation set and record its accuracy.
loss, acc, topk = learn.validate()
res += [acc.numpy()]

In [13]:
# Independent repeat: fresh MXResNet-50 (10 classes), same recipe as the first run.
learn = Learner(data, mxresnet50(c_out=10),
                loss_func=LabelSmoothingCrossEntropy(),
                metrics=[accuracy, top_k_accuracy],
                wd=1e-2, true_wd=True, bn_wd=False)
learn.fit_one_cycle(20, max_lr=3e-3)

epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.064888,1.982225,0.316,0.83,02:07
1,1.866454,2.302265,0.34,0.842,02:07
2,1.736936,2.342365,0.34,0.842,02:07
3,1.615955,1.690154,0.472,0.904,02:07
4,1.522823,1.616593,0.546,0.926,02:07
5,1.424253,1.55073,0.57,0.942,02:07
6,1.335092,1.368072,0.644,0.932,02:07
7,1.252069,1.342005,0.664,0.958,02:07
8,1.181628,1.234345,0.694,0.958,02:07
9,1.110938,1.143214,0.724,0.972,02:07


In [0]:

# Score the just-trained learner on the validation set and record its accuracy.
loss, acc, topk = learn.validate()
res += [acc.numpy()]

In [15]:
# Independent repeat: fresh MXResNet-50 (10 classes), same recipe as the first run.
learn = Learner(data, mxresnet50(c_out=10),
                loss_func=LabelSmoothingCrossEntropy(),
                metrics=[accuracy, top_k_accuracy],
                wd=1e-2, true_wd=True, bn_wd=False)
learn.fit_one_cycle(20, max_lr=3e-3)

epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.077032,2.050649,0.304,0.768,02:07
1,1.861813,1.931837,0.348,0.846,02:07
2,1.7508,2.063384,0.364,0.826,02:07
3,1.635669,1.880313,0.432,0.9,02:07
4,1.558453,1.814712,0.468,0.908,02:07
5,1.442231,1.510067,0.556,0.942,02:07
6,1.351647,1.50793,0.616,0.926,02:07
7,1.263998,1.507027,0.578,0.932,02:07
8,1.197337,1.239119,0.702,0.954,02:07
9,1.133025,1.191332,0.72,0.976,02:07


In [0]:
# Score the just-trained learner on the validation set and record its accuracy.
loss, acc, topk = learn.validate()
res += [acc.numpy()]

In [17]:
# Independent repeat: fresh MXResNet-50 (10 classes), same recipe as the first run.
learn = Learner(data, mxresnet50(c_out=10),
                loss_func=LabelSmoothingCrossEntropy(),
                metrics=[accuracy, top_k_accuracy],
                wd=1e-2, true_wd=True, bn_wd=False)
learn.fit_one_cycle(20, max_lr=3e-3)

epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.050308,2.015106,0.316,0.8,02:07
1,1.856873,2.004164,0.368,0.846,02:07
2,1.737028,3.149825,0.266,0.762,02:07
3,1.653702,2.115754,0.342,0.886,02:07
4,1.538797,1.528147,0.548,0.926,02:07
5,1.433369,1.517626,0.61,0.932,02:07
6,1.335768,1.557893,0.584,0.932,02:07
7,1.24449,1.243512,0.68,0.96,02:07
8,1.184286,1.143195,0.726,0.956,02:07
9,1.122705,1.13227,0.728,0.968,02:07


In [0]:
# Score the just-trained learner on the validation set and record its accuracy.
loss, acc, topk = learn.validate()
res += [acc.numpy()]

In [19]:
# Independent repeat: fresh MXResNet-50 (10 classes), same recipe as the first run.
learn = Learner(data, mxresnet50(c_out=10),
                loss_func=LabelSmoothingCrossEntropy(),
                metrics=[accuracy, top_k_accuracy],
                wd=1e-2, true_wd=True, bn_wd=False)
learn.fit_one_cycle(20, max_lr=3e-3)

epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.086531,2.026362,0.298,0.784,02:07
1,1.908724,1.948479,0.362,0.844,02:07
2,1.760613,1.968992,0.412,0.878,02:07
3,1.653947,1.732773,0.482,0.888,02:07
4,1.537726,1.741035,0.496,0.914,02:07
5,1.456567,1.665006,0.524,0.906,02:07
6,1.322566,1.589853,0.558,0.94,02:07
7,1.262224,1.857657,0.498,0.918,02:07
8,1.20285,1.354937,0.656,0.966,02:07
9,1.11739,1.099138,0.72,0.984,02:07


In [0]:
# Score the just-trained learner on the validation set and record its accuracy.
loss, acc, topk = learn.validate()
res += [acc.numpy()]

In [21]:
# Mean validation accuracy over the repeated runs collected in `res`.
np.mean(res)

0.83760005

In [22]:
# Spread of the per-run accuracies. np.std defaults to ddof=0 (population std);
# NOTE(review): pass ddof=1 if a sample standard deviation over runs is intended.
np.std(res)

0.0070880107

In [23]:
# Per-run validation accuracies.
res

[array(0.838, dtype=float32),
 array(0.832, dtype=float32),
 array(0.842, dtype=float32),
 array(0.828, dtype=float32),
 array(0.848, dtype=float32)]