In [1]:
#default_exp accelerate_fastai_integration.callback

# Accelerate Callback
> Support for using Hugging Face's Accelerate in `Learner`

In [3]:
#export
import torch

from fastcore.basics import store_attr, patch
from fastcore.meta import delegates
from fastai.callback.core import Callback, CancelBackwardException
from fastai.distributed import DistributedDL
from fastai.learner import Learner
from fastai.optimizer import OptimWrapper
from fastai.torch_core import to_device, default_device

from accelerate import Accelerator

## Accelerate

[accelerate](https://huggingface.co/docs/accelerate/index) is a lightweight framework designed to handle device placement, dataloader configuration, and optimizers/schedulers so that the same code can run on a single GPU, multiple GPUs, and even TPUs.

To use this integration, make sure accelerate is installed with:
```bash
pip install accelerate
```

In [4]:
#exporti
@patch
def _set_device(self:Learner, b):
    "Move batch `b` onto the device that training/inference should run on."
    # Once `AcceleratorCallback.before_fit` has attached an accelerator, it owns device placement.
    if hasattr(self, "accelerator"): return to_device(b, self.accelerator.device)
    # Otherwise follow the model: use the device its parameters actually live on (any CUDA
    # index, MPS, CPU, ...) instead of assuming the *current* CUDA device.
    param = next(self.model.parameters(), None)
    if param is not None: return to_device(b, param.device)
    # Parameter-free model: fall back to the DataLoaders' device, or fastai's default device.
    return to_device(b, getattr(self.dls, 'device', default_device()))

In [10]:
#export
class AcceleratorCallback(Callback):
    "A Callback that handles model, dataloader, and optimizer capabilities in distributed systems with Accelerate. Accepts Accelerator configuration parameters."
    @delegates(Accelerator)
    def __init__(self, **kwargs):
        "Create the underlying `Accelerator`; every keyword argument is forwarded to it unchanged"
        self.accelerator = Accelerator(**kwargs)

    def before_fit(self):
        "Tie `self.accelerator` to the learner and prepare the model and optimizer"
        # `Learner._set_device` looks for `learn.accelerator` to decide where batches are placed.
        self.learn.accelerator = self.accelerator
        self.learn.model = self.accelerator.prepare(self.learn.model)
        self.prepare_optimizer()

    def after_fit(self):
        "Unwrap the prepared model so the learner holds the bare module between fits"
        # `before_fit` runs again on every `fit` call (e.g. both phases of `fine_tune`); without
        # unwrapping here, an already-prepared model would be prepared (wrapped) a second time.
        self.learn.model = self.accelerator.unwrap_model(self.learn.model)

    def prepare_optimizer(self):
        "Prepares the optimizer for distributed training"
        opt = self.accelerator.prepare_optimizer(self.learn.opt)
        # Re-wrap so the learner keeps talking to a fastai-style optimizer interface.
        self.learn.opt = OptimWrapper(opt)
        # NOTE(review): registers the wrapper in the accelerator's private `_optimizers` list by
        # hand; depending on the accelerate version, `prepare_optimizer` may already register
        # `opt` itself — confirm this does not double-register.
        self.learn.accelerator._optimizers.append(self.learn.opt)

    def prepare_dataloader(self):
        "Prepares the active dl for distributed training"
        # Sharding is only needed with more than one process, and a loader that is already a
        # `DistributedDL` (e.g. one passed straight to `validate`) must not be wrapped twice.
        if self.accelerator.num_processes > 1 and not isinstance(self.learn.dl, DistributedDL):
            self.learn.dl = DistributedDL(
                self.learn.dl,
                rank=self.accelerator.process_index,
                world_size=self.accelerator.num_processes
            )

    def before_train(self):
        "Shard the training dataloader across processes"
        self.prepare_dataloader()

    def before_validate(self):
        "Shard the validation dataloader across processes"
        self.prepare_dataloader()

    def before_backward(self):
        "Call accelerator.backward"
        # Let accelerate run the backward pass on the loss; raising `CancelBackwardException`
        # then tells fastai to skip its own backward call for this batch.
        self.accelerator.backward(self.learn.loss_grad)
        raise CancelBackwardException()