In [None]:
#default_exp prepare

# PyTorch Preprocessors
> Module for preprocessing torch classes to prepare for various distributed environments

This module is essentially a barebones version of [Accelerate](https://github.com/huggingface/accelerate): it only wraps the outermost layer of each model, optimizer, and scheduler, which is all these tests need.

For example, it does not provide dispatched dataloaders, and it never modifies the underlying dataset.

In [2]:
#export
from pytorch_benchmark.imports import is_tpu_available, is_multigpu_available
from pytorch_benchmark.utils import get_device, get_rank

In [25]:
#export
import os, torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

if is_tpu_available(check_device=False):
    import torch_xla.distributed.xla_multiprocessing as xmp

## Preprocessors

In [7]:
#export
def prepare_model(
    model:nn.Module, # A PyTorch model to wrap
    **kwargs # Extra keyword arguments forwarded to `DistributedDataParallel` (multi-GPU only)
):
    "Prepares a model for distributed training. kwargs are sent to DDP"
    if is_tpu_available():
        return xmp.MpModelWrapper(model)
    if is_multigpu_available():
        rank = get_rank()
        # Defaults pin the model to this process's rank; caller-supplied kwargs
        # (e.g. `find_unused_parameters`, or explicit `device_ids`) take precedence
        kwargs.setdefault("device_ids", [rank])
        kwargs.setdefault("output_device", rank)
        return DDP(model, **kwargs)
    # Single device: nothing to wrap
    return model

In [9]:
#export
class OptimizerInterface(torch.optim.Optimizer):
    "Basic optimizer wrapper that performs the right step call for TPU"
    def __init__(self, optimizer):
        # `Optimizer.__init__` is deliberately not called: every piece of state
        # (`state`, `defaults`, `param_groups`) is delegated to the wrapped optimizer
        self.opt = optimizer

    @property
    def state(self): return self.opt.state

    @state.setter
    def state(self, state): self.opt.state = state

    @property
    def defaults(self): return self.opt.defaults

    @defaults.setter
    def defaults(self, defaults): self.opt.defaults = defaults

    # Schedulers and `torch.optim.Optimizer` helpers read `param_groups` directly
    @property
    def param_groups(self): return self.opt.param_groups

    @param_groups.setter
    def param_groups(self, param_groups): self.opt.param_groups = param_groups

    def state_dict(self):
        "Passthrough to state dict"
        return self.opt.state_dict()

    def load_state_dict(self, state_dict):
        "Passthrough to load_state_dict"
        return self.opt.load_state_dict(state_dict)

    def zero_grad(self, *args, **kwargs):
        "Passthrough to zero_grad; extra arguments (e.g. `set_to_none`) are forwarded"
        return self.opt.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        "Passthrough unless on TPU then calls the right stepper; returns the closure's loss if any"
        if is_tpu_available():
            # torch_xla is optional, so import it lazily (the module only imports it behind `is_tpu_available`)
            import torch_xla.core.xla_model as xm
            # `xm.optimizer_step` performs the optimizer step itself, so `self.opt.step`
            # must not be called a second time on this path
            optimizer_args = {} if closure is None else {"closure": closure}
            return xm.optimizer_step(self.opt, optimizer_args=optimizer_args)
        return self.opt.step(closure)

In [15]:
#export
def prepare_optimizer(
    opt:torch.optim.Optimizer # The optimizer to wrap
):
    "Wraps `opt` in an `OptimizerInterface` so that `step` dispatches correctly on TPU"
    wrapped = OptimizerInterface(opt)
    return wrapped

In [14]:
#export
class SchedulerInterface:
    "Wrapper to step the scheduler the right number of times"
    def __init__(self, scheduler, num_processes):
        self.scheduler = scheduler # The wrapped learning-rate scheduler
        self.num_processes = num_processes # How many times `scheduler.step` is called per `step`

    def step(self, *args, **kwargs):
        "Passthrough to `scheduler.step` but will also step the right number of times"
        # Schedulers exposing `total_steps` (e.g. `OneCycleLR`) have a fixed step budget;
        # stop once it is used up instead of stepping past it
        total_steps = getattr(self.scheduler, "total_steps", None)
        for _ in range(self.num_processes):
            if total_steps is not None and self.scheduler.last_epoch >= total_steps:
                break
            self.scheduler.step(*args, **kwargs)

    def get_last_lr(self):
        "Passthrough to `scheduler.get_last_lr`"
        return self.scheduler.get_last_lr()

    def state_dict(self):
        "Passthrough to `scheduler.state_dict`"
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        "Passthrough to `scheduler.load_state_dict`"
        return self.scheduler.load_state_dict(state_dict)

In [22]:
#export
def prepare_scheduler(
    sched:torch.optim.lr_scheduler._LRScheduler, # The learning-rate scheduler to wrap
    num_processes:int=None # Scheduler steps per `step` call; autodetected from the environment when `None`
):
    "Wraps `sched` in a `SchedulerInterface` that steps it once per process"
    if num_processes is None:
        if is_tpu_available():
            num_processes = 8 # assumes an 8-core TPU, as in the original tests; pass `num_processes` to override
        elif is_multigpu_available():
            num_processes = torch.cuda.device_count()
        else:
            num_processes = 1
    if num_processes < 1:
        # A non-positive count would turn every `step` call into a silent no-op
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")
    return SchedulerInterface(sched, num_processes)

In [None]:
#export
def _prepare_one(obj, first_pass=False):
    # first pass on preperation: DataLoader, model, optimizer
    if first_pass:
        if isinstance(obj, torch.nn.Module):
            return prepare_model(obj)
        elif isinstance(obj, torch.optim.Optimizer):
            return prepare_optimizer(obj)
    elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
        return prepare_scheduler(obj)
    return obj

In [None]:
#export
def prepare_modules(*modules):
    "Prepares a set of modules, supports only PyTorch models, optimizers, and schedulers; returns a tuple in the input order"
    # Models and optimizers are wrapped in a first sweep, schedulers in a second one
    first_sweep = [_prepare_one(item, first_pass=True) for item in modules]
    return tuple(_prepare_one(item) for item in first_sweep)

## Interfaces

The interface classes that `prepare_modules` may wrap optimizers and schedulers in:

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Render the API documentation for the optimizer wrapper and its methods
show_doc(OptimizerInterface)

In [None]:
show_doc(OptimizerInterface.state_dict)

In [None]:
show_doc(OptimizerInterface.step)

In [None]:
show_doc(OptimizerInterface.zero_grad)

In [None]:
# Render the API documentation for the scheduler wrapper and its `step` method
show_doc(SchedulerInterface)

In [None]:
show_doc(SchedulerInterface.step)

In [30]:
#hide
from nbdev.export import notebook2script
# Export the `#export` cells of the project's notebooks to library modules (see the "Converted ..." output)
notebook2script()

Converted 00_imports.ipynb.
Converted 01_prepare.ipynb.
Converted index.ipynb.
Converted utils.ipynb.
