## Gradient Accumulation ON - FP32

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai import *
from fastai.vision import *
import os

In [None]:
# Pin this notebook to CUDA device 1: both fastai's `defaults.device` and
# PyTorch's current CUDA device are set to it.
# NOTE(review): the hard-coded index 1 requires at least two visible GPUs.
gpu_device = 1
defaults.device = torch.device(f'cuda:{gpu_device}')
torch.cuda.set_device(gpu_device)

In [None]:
# BS is the per-forward-pass batch size. With gradient accumulation, each real
# optimizer step covers N_STEP batches, i.e. roughly BS * N_STEP samples.
BS, N_STEP = 8, 4

In [None]:
# fastai dataset helper; presumably downloads (if needed) and extracts the
# Oxford-IIIT Pets dataset. The trailing `path` just echoes it in the notebook.
path = untar_data(URLs.PETS); path

In [None]:
path.ls()

In [None]:
# Sub-folders of the dataset root.
path_anno = path/'annotations'
path_img = path/'images'

In [None]:
# Image files under `path_img`; show the first few as a sanity check.
fnames = get_image_files(path_img)
fnames[:5]

In [None]:
def seed_everything(seed):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), exports PYTHONHASHSEED, and configures cuDNN for
    deterministic, non-autotuned kernel selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the current
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current device (this notebook uses cuda:1).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning picks kernels by timing them, which can differ between
    # runs and defeats `deterministic`; turn it off explicitly.
    torch.backends.cudnn.benchmark = False

In [None]:
class AccumulateOptimWrapper(OptimWrapper):
    """`OptimWrapper` whose public `step` and `zero_grad` do nothing.

    Whoever drives the optimizer through the usual `step()` / `zero_grad()`
    calls (presumably the fastai training loop) therefore neither applies nor
    clears gradients, so they accumulate across batches. The wrapped
    behaviour stays available through `real_step` and `real_zero_grad`.
    """

    def step(self):
        "No-op: the optimizer is only stepped through `real_step`."

    def zero_grad(self):
        "No-op: gradients are only cleared through `real_zero_grad`."

    def real_step(self):
        "Run the wrapped `OptimWrapper.step`."
        super().step()

    def real_zero_grad(self):
        "Run the wrapped `OptimWrapper.zero_grad`."
        super().zero_grad()
        
def acc_create_opt(self, lr:Floats, wd:Floats=0.):
    """Create optimizer with `lr` learning rate and `wd` weight decay.

    Replacement for `Learner.create_opt` (installed by `turn_on_accumulation`)
    that builds an `AccumulateOptimWrapper` instead of a plain `OptimWrapper`,
    so the learner's optimizer only steps when `real_step` is called.
    """
    wrapper_cls = AccumulateOptimWrapper
    self.opt = wrapper_cls.create(
        self.opt_func, lr, self.layer_groups,
        wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd,
    )
        
class AccumulateStep(LearnerCallback):
    """
    Does an accumulated optimizer step every `n_step` batches by accumulating gradients.

    Expects `learn.opt` to be an `AccumulateOptimWrapper` (set up by calling
    `turn_on_accumulation()` before the learner is created). Its `step`/`zero_grad`
    are no-ops, so gradients add up across batches, and this callback applies them
    via `real_step`/`real_zero_grad`. Before each real step, the summed gradients are
    divided by the number of training samples in the window. That yields a
    per-sample mean when the loss uses `reduction="sum"`.

    Args:
        learn: the `Learner` to attach to.
        n_step: number of batches to accumulate per optimizer step (must be >= 1).
        verbose: print a line for every counted batch and every real step.

    Raises:
        ValueError: if `n_step` < 1.
    """
    # Intentionally not a @dataclass. `__init__` is hand-written, so the decorator
    # only added a field-based `__eq__` (and removed `__hash__`), not useful fields.

    def __init__(self, learn:Learner, n_step:int = 1, verbose:bool = False):
        super().__init__(learn)
        if n_step < 1: raise ValueError(f"n_step must be >= 1, got {n_step}")
        self.n_step, self.verbose = n_step, verbose
        self.acc_samples, self.acc_batches = 0, 0

    def on_train_begin(self, **kwargs):
        "Warn when the loss is not summed, since gradients are divided by the sample count."
        reduction = getattr(self.learn.loss_func, 'reduction', None)
        if reduction is not None and reduction != "sum":
            print("For better gradients consider 'reduction=sum'")

    def on_epoch_begin(self, **kwargs):
        "Reset the per-epoch sample and batch counters."
        self.acc_samples, self.acc_batches = 0, 0

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Count the samples and batches of each training batch."
        # Validation batches never reach backward. Counting them would inflate the
        # divisor used by the epoch-end flush. Assumes the callback state carries a
        # `train` flag; if it is absent, every batch is counted (previous behaviour).
        if not kwargs.get('train', True): return
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1
        if self.verbose: print(f"At batch {self.acc_batches}")

    def on_backward_end(self, **kwargs):
        "Apply the accumulated gradients once `n_step` batches have been collected."
        if self.acc_samples and self.acc_batches % self.n_step == 0:
            if self.verbose: print(f"Stepping at batch: {self.acc_batches}")
            self._step_accumulated()

    def on_epoch_end(self, **kwargs):
        "Apply leftover gradients when the epoch length is not a multiple of `n_step`."
        # Skip when nothing is pending. An optimizer step on zeroed grads can still
        # move weights through momentum and weight decay.
        if self.acc_samples > 0: self._step_accumulated()

    def _step_accumulated(self):
        "Average accumulated gradients per sample, take a real optimizer step, then clear them."
        for p in self.learn.model.parameters():
            # `grad` is None for a trainable parameter that took no part in backward.
            if p.requires_grad and p.grad is not None: p.grad.div_(self.acc_samples)
        self.learn.opt.real_step()
        self.learn.opt.real_zero_grad()
        self.acc_samples = 0

In [None]:
# Capture fastai's own `Learner.create_opt` only once per kernel. Without this guard,
# re-running the cell after `turn_on_accumulation()` would record the accumulating
# version as the "original", and `turn_off_accumulation()` would silently do nothing.
if 'original_create_opt' not in globals():
    original_create_opt = Learner.create_opt

def turn_off_accumulation():
    "Restore the `Learner.create_opt` captured above."
    Learner.create_opt = original_create_opt

def turn_on_accumulation():
    "Make `Learner.create_opt` build an `AccumulateOptimWrapper` via `acc_create_opt`."
    Learner.create_opt = acc_create_opt

In [None]:
# Fix every RNG before building the resnet34 DataBunch and learner.
seed_everything(seed=2)

In [None]:
# Captures the label (breed name) from paths like ".../great_pyrenees_173.jpg".
# The dot is escaped so only a literal ".jpg" extension matches.
pat = re.compile(r'/([^/]+)_\d+\.jpg$')

In [None]:
# Build the image DataBunch. Labels presumably come from `pat`'s capture group
# (the breed name). Images use default `get_transforms()` augmentations at
# size 224, with BS images per batch, then ImageNet-statistics normalization.
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=BS
                                  ).normalize(imagenet_stats)

In [None]:
data.show_batch(rows=2, figsize=(7,6))

## Training: resnet34

In [None]:
def get_learner(arch=models.resnet34, n_step=None):
    """Build a fresh gradient-accumulating CNN learner on the current global `data`.

    Patches `Learner.create_opt` to the accumulating version, attaches an
    `AccumulateStep` callback and uses a summed cross-entropy loss. The callback
    divides accumulated gradients by the sample count, so summing gives a
    per-sample mean.

    Args:
        arch: model constructor passed to `create_cnn` (default: resnet34).
        n_step: batches to accumulate per optimizer step. Defaults to the global
            `N_STEP`, read at call time.

    Returns:
        The new `Learner`.
    """
    if n_step is None: n_step = N_STEP
    turn_on_accumulation()
    learn = create_cnn(data=data, arch=arch, metrics=error_rate,
                       callback_fns=[partial(AccumulateStep, n_step=n_step)])
    learn.loss_func = CrossEntropyFlat(reduction="sum")
    return learn

In [None]:
# Run the LR finder on a throwaway learner, then rebuild a fresh learner
# (accumulation is re-enabled inside `get_learner`) for the real training run.
learn = get_learner() 
learn.lr_find() # inspect the plot below to pick an lr
learn.recorder.plot()
learn = get_learner() 

In [None]:
# NOTE(review): no `max_lr` is passed, so the lr read off the finder plot is
# not actually used by this call.
learn.fit_one_cycle(4)

### Unfreezing, fine-tuning, and learning rates

In [None]:
# Make every layer group trainable for fine-tuning.
learn.unfreeze()

In [None]:
# slice(1e-6, 1e-4) spreads learning rates across layer groups (discriminative
# learning rates), lowest for the earliest layers.
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))

## Training: resnet50

In [None]:
# Re-seed with 2 so this run starts from the same RNG state as the resnet34 run.
seed_everything(2)

In [None]:
# Same dataset and labelling, but at size 299 for the resnet50 run.
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
                                   size=299, bs=BS).normalize(imagenet_stats)

In [None]:
def get_learner(arch=models.resnet50, n_step=None):
    """Build a fresh gradient-accumulating CNN learner on the current global `data`.

    Patches `Learner.create_opt` to the accumulating version, attaches an
    `AccumulateStep` callback and uses a summed cross-entropy loss. The callback
    divides accumulated gradients by the sample count, so summing gives a
    per-sample mean.

    Args:
        arch: model constructor passed to `create_cnn` (default: resnet50).
        n_step: batches to accumulate per optimizer step. Defaults to the global
            `N_STEP`, read at call time.

    Returns:
        The new `Learner`.
    """
    if n_step is None: n_step = N_STEP
    turn_on_accumulation()
    learn = create_cnn(data=data, arch=arch, metrics=error_rate,
                       callback_fns=[partial(AccumulateStep, n_step=n_step)])
    learn.loss_func = CrossEntropyFlat(reduction="sum")
    return learn

In [None]:
# Run the LR finder on a throwaway resnet50 learner, then rebuild a fresh
# learner for the real training run.
learn = get_learner() 
learn.lr_find() # inspect the plot below to pick an lr
learn.recorder.plot()
learn = get_learner() 

In [None]:
# NOTE(review): no `max_lr` is passed, so the lr read off the finder plot is
# not actually used by this call.
learn.fit_one_cycle(5) 

### Unfreeze

In [None]:
# Make every layer group trainable for fine-tuning.
learn.unfreeze()

In [None]:
# slice(1e-6, 1e-4) spreads learning rates across layer groups (discriminative
# learning rates), lowest for the earliest layers.
learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))