In [None]:
#default_exp accelerate_fastai_integration.core

# Core
> Notebook containing patches to the main fastai library to make it compatible with Accelerate

In [29]:
#export
from collections import defaultdict

import torch
from fastcore.basics import patch
from fastai.callback.core import CancelStepException
from fastai.learner import AvgSmoothLoss, Learner
from fastai.optimizer import Optimizer, OptimWrapper, _convert_params, pytorch_hp_map, _update, detuplify_pg
from fastai.torch_core import to_device, default_device, to_detach

## Metrics

`fastai` already gathers tensors across devices in most situations. The one exception is `AvgSmoothLoss`, so for now we patch its `accumulate` method to call `to_detach` with `gather=True`. That way the smoothed loss we record reflects the loss from every device rather than only the current process.

In [30]:
#export
@patch
def accumulate(self:AvgSmoothLoss, learn):
    "Fold this batch's mean loss into the exponential moving average, gathering it across devices first"
    self.count += 1
    # `gather=True` makes the recorded value reflect the loss on every device, not just this process
    batch_loss = to_detach(learn.loss.mean(), gather=True)
    # lerp(batch, val, beta) == (1 - beta) * batch + beta * val
    self.val = torch.lerp(batch_loss, self.val, self.beta)

## Optimizer

This section adds two small patches to existing optimizer functions, `Optimizer.step` and `OptimWrapper.__init__`.

The first is needed because Accelerate expects to be able to pass a `closure` argument to the `step` method of some optimizers. @ilovescience is working on bringing closure support to fastai itself, so this patch will become unnecessary once that lands.

The second is a small change to `OptimWrapper` so that it can accept an already-constructed PyTorch optimizer. This lets us run `accelerator.prepare()` on an existing optimizer and then wrap the result back into the form fastai expects.

In [20]:
#export
@patch
def step(self:Optimizer, closure=None):
    """Apply one optimizer update to every parameter that has a gradient.

    `closure` follows the `torch.optim.Optimizer.step` contract: when given, it is called
    (with gradients enabled) before the update so it can recompute the loss, and its
    return value is returned from `step`. Without a closure this returns `None`.
    """
    loss = None
    if closure is not None:
        # Previously the closure was accepted but silently ignored; evaluate it like torch's
        # built-in optimizers do, even if `step` itself is called under `torch.no_grad()`.
        with torch.enable_grad(): loss = closure()
    for p,pg,state,hyper in self.all_params(with_grad=True):
        # Thread `state` through every optimizer callback, each seeing current state + hyper-params
        for cb in self.cbs: state = _update(state, cb(p, **{**state, **hyper}))
        self.state[p] = state
    return loss

In [22]:
#export
@patch
def __init__(self:OptimWrapper, params, opt, hp_map=None, convert_groups=True, **kwargs):
    """Wrap a PyTorch optimizer so fastai can drive it.

    `opt` is either a callable (e.g. an optimizer class) that is built here from `params`
    and `kwargs`, or an already-constructed optimizer (such as the one returned by
    `accelerator.prepare()`) that is stored as-is. In the latter case `params`,
    `convert_groups` and `kwargs` are ignored.
    """
    if not callable(opt):
        self.opt = opt
    elif convert_groups:
        self.opt = opt(_convert_params(params), **kwargs)
    else:
        self.opt = opt(params, **kwargs)
    hp_map = pytorch_hp_map if hp_map is None else hp_map
    # Hyper-parameter names of the first param group, renamed through `hp_map` when it has an entry
    group_keys = detuplify_pg(self.opt.param_groups[0]).keys()
    self.fwd_map = {key: (hp_map[key] if key in hp_map else key) for key in group_keys}
    self.bwd_map = {mapped: key for key, mapped in self.fwd_map.items()}
    self.state = defaultdict(dict, {})
    self.frozen_idx = 0

## Learner

Finally, we have adjustments to `Learner`. There is a change in the inner training loop (`_do_one_batch`) to use `accelerator.backward()` for propagation, a new `gather` function that can be helpful when trying to gather tensors across devices, and a new `_set_device` implementation that uses `accelerator.device` if present.

In [23]:
#export
@patch
def gather(self:Learner, *items):
    """Gather a tensor (or several tensors) across all devices via `self.accelerator`.

    With a single argument, the gathered result for that argument is returned directly.
    With several arguments, they are passed to `accelerator.gather` together as a tuple
    and its result is returned.
    """
    # Unwrap the lone item so `learn.gather(t)` doesn't hand back a 1-tuple
    if len(items) == 1: return self.accelerator.gather(items[0])
    return self.accelerator.gather(items)

In [26]:
#export
@patch
def _set_device(self:Learner, b):
    """Move batch `b` to the device it should be consumed on.

    Uses `accelerator.device` when an `accelerator` is attached. Otherwise it uses the device
    of the model's parameters, falling back to the `DataLoaders` device (or fastai's
    `default_device()`) for a model with no parameters.
    """
    if hasattr(self, "accelerator"): return to_device(b, self.accelerator.device)
    # Read the device straight off a parameter. `torch.cuda.current_device()` need not be the GPU
    # the model actually lives on, and non-CUDA accelerators must not be collapsed to CPU.
    first_param = next(self.model.parameters(), None)
    if first_param is None: return to_device(b, getattr(self.dls, 'device', default_device()))
    return to_device(b, first_param.device)

In [27]:
#export
@patch
def _do_one_batch(self:Learner):
    """Run one batch: forward pass and loss, then (only when training with targets) backward,
    optimizer step and gradient reset.

    Same event sequence as fastai's own loop, except backpropagation goes through
    `accelerator.backward` when an `accelerator` is attached to the `Learner`.
    """
    self.pred = self.model(*self.xb)
    self('after_pred')
    if len(self.yb):
        # `loss_grad` is what gets backpropagated; `loss` is a separate cloned tensor
        self.loss_grad = self.loss_func(self.pred, *self.yb)
        self.loss = self.loss_grad.clone()
    self('after_loss')
    # Nothing to backpropagate in validation or when the batch carries no targets
    if not self.training or not len(self.yb): return
    self('before_backward')
    if hasattr(self, 'accelerator'):
        # Accelerate's `backward` replaces `loss.backward()` so it can wrap the call as needed
        # (e.g. mixed-precision loss scaling, per the Accelerate docs)
        self.accelerator.backward(self.loss_grad)
    else:
        self.loss_grad.backward()
    # Fires before_step/after_step around the update; presumably callbacks raise
    # `CancelStepException` to skip the update
    self._with_events(self.opt.step, 'step', CancelStepException)
    self.opt.zero_grad()

In [2]:
from nbdev.export import notebook2script
# Export the `#export` cells of this notebook into the `accelerate_fastai_integration.core` module
notebook2script(fname="00_core.ipynb")

Converted 00_core.ipynb.
