Skip to content

Commit

Permalink
[booster] init module structure and definition
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee committed Mar 8, 2023
1 parent b51bfec commit e0fd829
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
5 changes: 5 additions & 0 deletions colossalai/booster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin
from .precision import Precision
12 changes: 12 additions & 0 deletions colossalai/booster/accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
import torch.nn as nn

__all__ = ['Accelerator']

class Accelerator:
def __init__(self, device: torch.device):
self.device = device

def setup_model(self, model: nn.Module) -> nn.Module:
# TODO: implement this method
pass
66 changes: 66 additions & 0 deletions colossalai/booster/booster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from .plugin import Plugin

__all__ = ['Booster']


class Booster:

def __init__(self,
device: Union[str, torch.device] = 'cuda',
precision: str = 'fp32',
grad_clipping_type: str = 'norm',
grad_clipping_value: float = 0.0,
plugin: Optional[Plugin] = None) -> None:
# TODO: implement this method
pass

def boost(
self, *args: Union[nn.Module, Optimizer, LRScheduler, DataLoader]
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
# TODO: implement this method
pass

def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO: implement this method
pass

def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[torch.Tensor], torch.Tensor],
optimizer: Optimizer,
return_loss: bool = True,
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
# TODO: implement this method
# run pipeline forward backward pass
# return loss or outputs if needed
pass

def no_sync(self, model: nn.Module) -> contextmanager:
# TODO: implement this method
pass

def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass

def load(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
18 changes: 18 additions & 0 deletions colossalai/booster/environment_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List

__all__ = ['EnvironmentTable']


class EnvironmentTable:

def __init__(self, intra_op_world_sizes: List[int]):
# TODO: implement this method
pass

@property
def is_master(self) -> bool:
# TODO: implement this method
pass

# TODO: implement more utility methods as given in
# https://github.com/hpcaitech/ColossalAI/issues/3051
46 changes: 46 additions & 0 deletions colossalai/booster/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from colossalai.device.device_mesh import DeviceMesh

__all__ = ['Plugin']


class Plugin:

@property
def supported_devices(self) -> List[torch.device]:
pass

@property
def supported_precisions(self) -> List[str]:
pass

@property
def control_precision(self) -> bool:
pass

@property
def control_device(self) -> bool:
pass

@property
def support_no_sync(self) -> bool:
pass

def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
pass

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
pass

def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
pass

@property
def device_mesh_shape(self) -> List[Tuple[int, ...]]:
pass
25 changes: 25 additions & 0 deletions colossalai/booster/precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer

__all__ = ['Precision']


class Precision:

def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float):
self.precision_type = precision_type
self.grad_clipping_type = grad_clipping_type
self.grad_clipping_value = grad_clipping_value

def setup_model(self, model: nn.Module) -> nn.Module:
# TODO: implement this method
pass

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
# TODO: implement this method
# inject grad clipping and unscale loss
pass

def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
pass

0 comments on commit e0fd829

Please sign in to comment.