**Note that the name of the callback `AccumulateStepper` has been changed into `AccumulateScheduler`**

https://forums.fast.ai/t/accumulating-gradients/33219/90?u=hwasiti

https://github.com/fastai/fastai/blob/fbbc6f91e8e8e91ba0e3cc98ac148f6b26b9e041/fastai/train.py#L99-L134

In [1]:
import fastai
from fastai.vision import *

In [2]:
gpu_device = 0
defaults.device = torch.device(f'cuda:{gpu_device}')
torch.cuda.set_device(gpu_device)

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(42)

In [4]:
path = untar_data(URLs.PETS)

In [5]:
path_anno = path/'annotations'
path_img = path/'images'
fnames = get_image_files(path_img)
pat = re.compile(r'/([^/]+)_\d+.jpg$')

In [43]:
# Simplified RunningBatchNorm
# 07_batchnorm.ipynb (fastai course v3 part2 2019)
class RunningBatchNorm2d(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        # I have added self.nf so that it can be represented when
        # printing the model in the extra_repr method below
        self.nf = nf
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset
    
    def extra_repr(self):
        return '{nf}, mom={mom}, eps={eps}'.format(**self.__dict__)

In [44]:
class RunningBatchNorm1d(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        # I have added self.nf so that it can be represented when
        # printing the model in the extra_repr method below
        self.nf = nf
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset
    
    def extra_repr(self):
        return '{nf}, mom={mom}, eps={eps}'.format(**self.__dict__)

### No Grad Acc (BS 64), No running BN

In [7]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=64
                                  ).normalize(imagenet_stats)
learn = cnn_learner(data,  models.resnet18, metrics=accuracy)
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.727655,0.312436,0.901894,00:38


In [8]:
data.batch_size

64

### No Grad Acc (BS 2), No running BN

In [11]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = cnn_learner(data,  models.resnet18, metrics=accuracy)
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,2.394068,1.312781,0.644114,01:13


### Naive Grad Acc (BS 2) x 32 steps, No running BN

In [10]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = cnn_learner(data,  models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateScheduler, n_step=32)])
learn.loss_func = CrossEntropyFlat(reduction='sum')
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,5.188273,1.928466,0.734777,00:54


### No Grad Acc (BS 2), Running BN

In [52]:
def bn2rbn(bn):
    if isinstance(bn, nn.BatchNorm1d): rbn = RunningBatchNorm1d(bn.num_features, eps=bn.eps, mom=bn.momentum)
    elif isinstance(bn, nn.BatchNorm2d): rbn = RunningBatchNorm2d(bn.num_features, eps=bn.eps, mom=bn.momentum)
    
    rbn.weight = bn.weight
    rbn.bias = bn.bias
    return (rbn).to(bn.weight.device)

In [53]:
def convert_bn(list_mods, func=bn2rbn):
    for i in range(len(list_mods)):
        if isinstance(list_mods[i], bn_types):
            list_mods[i] = func(list_mods[i])
        elif list_mods[i].__class__.__name__ in ("Sequential", "BasicBlock"):
            list_mods[i] = nn.Sequential(*convert_bn(list(list_mods[i].children()), func))
    return list_mods

In [54]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = cnn_learner(data,  models.resnet18, metrics=accuracy)
# learn.loss_func = CrossEntropyFlat(reduction='sum')

In [56]:
learn.model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (rel

In [49]:
learn.summary()

Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 112, 112]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 112, 112]       128        True      
______________________________________________________________________
ReLU                 [64, 112, 112]       0          False     
______________________________________________________________________
MaxPool2d            [64, 56, 56]         0          False     
______________________________________________________________________
Conv2d               [64, 56, 56]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 56, 56]         128        True      
______________________________________________________________________
ReLU                 [64, 56, 56]         0          False     
______________________________________________________________

In [57]:
learn.model = nn.Sequential(*convert_bn(list(learn.model.children()), bn2rbn))

In [40]:
learn.model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): RunningBatchNorm2d(64, mom=0.1, eps=1e-05)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): RunningBatchNorm2d(64, mom=0.1, eps=1e-05)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): RunningBatchNorm2d(64, mom=0.1, eps=1e-05)
      )
      (1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): RunningBatchNorm2d(64, mom=0.1, eps=1e-05)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): RunningBatchNorm2d(64, mom=0.1, eps=1e-05)
      )

In [58]:
learn.summary()

RuntimeError: Given groups=1, weight of size [128, 64, 1, 1], expected input[1, 128, 28, 28] to have 64 channels, but got 128 channels instead

In [33]:
%debug

> [0;32m/home/haider/anaconda3/envs/fastai-v1/lib/python3.7/site-packages/torch/nn/modules/conv.py[0m(320)[0;36mforward[0;34m()[0m
[0;32m    318 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    319 [0;31m        return F.conv2d(input, self.weight, self.bias, self.stride,
[0m[0;32m--> 320 [0;31m                        self.padding, self.dilation, self.groups)
[0m[0;32m    321 [0;31m[0;34m[0m[0m
[0m[0;32m    322 [0;31m[0;34m[0m[0m
[0m


ipdb>  input


tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

ipdb>  input.shape


torch.Size([1, 128, 28, 28])


ipdb>  exit


In [55]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time


KeyboardInterrupt: 

### GroupNorm

In [16]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = create_cnn(data,  models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateScheduler, n_step=32)])
# learn.loss_func = CrossEntropyFlat(reduction='sum')

In [6]:
groups = 64

def bn2group(bn):
    groupnorm = nn.GroupNorm(groups, bn.num_features, affine=True)
    groupnorm.weight = bn.weight
    groupnorm.bias = bn.bias
    groupnorm.eps = bn.eps
    return (groupnorm).to(bn.weight.device)

In [7]:
def convert_bn(list_mods, func=bn2group):
    for i in range(len(list_mods)):
        if isinstance(list_mods[i], bn_types):
            list_mods[i] = func(list_mods[i])
        elif list_mods[i].__class__.__name__ in ("Sequential", "BasicBlock"):
            list_mods[i] = nn.Sequential(*convert_bn(list(list_mods[i].children()), func))
    return list_mods

In [8]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = create_cnn(data,  models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [9]:
learn.model = nn.Sequential(*convert_bn(list(learn.model.children()), bn2group))

In [10]:
learn.freeze()

In [11]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
1,5.360059,4.734238,0.336942,01:14


### Resnet + GroupNorm

In [None]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [None]:
def change_all_BN(module):
    for i in range(5):
        atr = 'bn'+str(i)
        if hasattr(module, atr):
            setattr(module, atr, bn2group(getattr(module,atr)))

def wrap_BN(model):
    for i in range(len(model)):
        for j in range(len(model[i])):
            if isinstance(model[i][j], bn_types):
                model[i][j] = bn2group(model[i][j])
            elif model[i][j].__class__.__name__ == "Sequential":
                for k in range(len(model[i][j])):
                    if isinstance(model[i][j][k], bn_types):
                        model[i][j][k] = bn2group(model[i][j][k])
                    elif model[i][j][k].__class__.__name__ == "BasicBlock":
                        change_all_BN(model[i][j][k])
                        if hasattr(model[i][j][k],'downsample'):
                            if model[i][j][k].downsample is not None:
                                for l in range(len(model[i][j][k].downsample)):
                                     if isinstance(model[i][j][k].downsample[l], bn_types):
                                        model[i][j][k].downsample[l] = bn2group(model[i][j][k].downsample[l])
    

In [None]:
wrap_BN(learn.model)

In [None]:
learn.freeze()

In [None]:
learn.fit(1)

### Resnet + GroupNorm (No Acc)

In [None]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=2
                                  ).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet18, metrics=accuracy)

In [None]:
wrap_BN(learn.model)
learn.freeze()
learn.fit(1)

### Resnet + GroupNorm (No Acc) bs = 1

In [None]:
seed_everything(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=1
                                  ).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet18, metrics=accuracy)

In [None]:
wrap_BN(learn.model)
learn.freeze()
learn.fit(1)