In [1]:
#default_exp accelerate_fastai_integration.callback

# Accelerate Callback
> The main Callback implementation for Accelerate

In [9]:
#export
import torch
from fastcore.basics import store_attr, patch
from fastai.callback.core import Callback, CancelBackwardException
from fastai.distributed import DistributedDL
from fastai.learner import Learner
from fastai.optimizer import OptimWrapper
from fastai.torch_core import to_device, default_device

from accelerate import Accelerator

In [None]:
#export
@patch
def _set_device(self:Learner, b):
    """Move batch `b` onto the device that training runs on.

    If an `accelerator` attribute has been attached to the learner, the batch goes to
    `self.accelerator.device`. Otherwise it follows the device of the model's first
    parameter. A model without parameters falls back to the `dls` device, or to
    fastai's default device when `dls` has no `device` attribute.
    """
    if hasattr(self, "accelerator"): return to_device(b, self.accelerator.device)
    # Read the device off the parameter itself: `torch.cuda.current_device()` is the
    # process-wide current GPU, which need not be the one the model lives on.
    first_param = next(self.model.parameters(), None)
    if first_param is None: return to_device(b, getattr(self.dls, 'device', default_device()))
    # Comparing against the `dls` device is unnecessary: whether or not the two match,
    # the batch has to end up on the model's device.
    return to_device(b, first_param.device)

In [10]:
#export
class AcceleratorCallback(Callback):
    "A Callback that handles model, dataloader, and optimizer compatibility with Accelerate"
    def __init__(self, accelerator:Accelerator):
        # `store_attr` saves the `accelerator` argument as `self.accelerator`
        store_attr()

    def before_fit(self):
        "Tie `self.accelerator` to the learner and prepare the model and optimizer"
        # Expose the accelerator on the learner; the patched `Learner._set_device`
        # checks for this attribute to decide where batches are moved.
        self.learn.accelerator = self.accelerator
        # NOTE(review): this hook runs at the start of every fit -- confirm that calling
        # `prepare` on an already-prepared model is safe when fitting more than once.
        self.learn.model = self.accelerator.prepare(self.learn.model)
        opt = self.accelerator.prepare_optimizer(self.learn.opt)
        # TODO: does re-wrapping in `OptimWrapper` maintain the layer groups?
        self.learn.opt = OptimWrapper(self.learn.model.parameters(), opt)
        # Register the re-wrapped optimizer with the accelerator as well
        self.learn.accelerator._optimizers.append(self.learn.opt)

    @staticmethod
    def _prepare_dataloader(dataloader, accelerator):
        "Wrap `dataloader` in a `DistributedDL` for this process, unless it already is one"
        # Wrapping a `DistributedDL` a second time would stack one distributed wrapper on another
        if isinstance(dataloader, DistributedDL): return dataloader
        return DistributedDL(
            dataloader,
            rank=accelerator.process_index,
            world_size=accelerator.num_processes
        )

    def _distribute_dl(self):
        "Swap the learner's current `dl` for a distributed one when running on several processes"
        if self.accelerator.num_processes > 1:
            self.learn.dl = self._prepare_dataloader(self.learn.dl, self.accelerator)

    def before_train(self):
        "Distribute the training dataloader across processes"
        self._distribute_dl()

    def before_validate(self):
        "Distribute the validation dataloader across processes"
        self._distribute_dl()

    def before_backward(self):
        "Call accelerator.backward"
        self.accelerator.backward(self.learn.loss_grad)
        # The backward pass has already run through the accelerator, so cancel
        # the learner's own backward step.
        raise CancelBackwardException()

In [1]:
# Run nbdev's notebook-to-script conversion on this notebook
from nbdev.export import notebook2script
notebook2script("01_callback.ipynb")

Converted 01_callback.ipynb.
