[autoparallel] implemented linear projection strategy generator
FrankLeeeee committed Sep 26, 2022
1 parent 702dbc5 commit 708904c
Showing 7 changed files with 563 additions and 133 deletions.
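This change moves the strategy generators into the new colossalai/auto_parallel/solver/strategy package, passes the operation data mapping to LinearProjectionStrategyGenerator at construction time, and renames register_strategy_generator to get_strategy_generator on the handlers. A minimal sketch of the resulting flow, assuming handler is a LinearModuleHandler already built for an nn.Linear node and device_mesh is the DeviceMesh it was constructed with (handler construction is outside this diff):

# minimal sketch, not part of the commit; `handler` and `device_mesh` are assumed to exist
op_data_mapping = handler.get_operation_data_mapping()  # keys: 'input', 'other', 'output' (plus 'bias' when present)
generators = handler.get_strategy_generator()           # [LinearProjectionStrategyGenerator(op_data_mapping, device_mesh)]
handler.register_strategy()                             # each generator.generate(op_data_mapping) result extends handler.strategies_vector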
66 changes: 23 additions & 43 deletions colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
@@ -1,46 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LinearModuleHandler']


class DotProductStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass


class MatVecStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass


class LinearProjectionStrategyGenerator(StrategyGenerator_V2):

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass

def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""TODO: to be implemented"""
pass

def validate(self, *args, **kwargs) -> bool:
"""TODO: to be implemented"""
pass


class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']


@operator_registry.register(torch.nn.Linear)
@@ -49,9 +15,10 @@ class LinearModuleHandler(ModuleHandler):
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
"""

def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -97,9 +64,10 @@ class LinearFunctionHandler(NodeHandler):
A LinearFunctionHandler which deals with the sharding strategies for the F.linear function.
"""

def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -108,17 +76,29 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)

# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG

physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

if self.node.args[2] is not None:
# check if the other operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
type=data_type,
data=self.node.args[2]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
13 changes: 8 additions & 5 deletions colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,8 @@
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
from ..strategy import StrategyGenerator_V2


class NodeHandler(ABC):
@@ -26,14 +27,14 @@ def __init__(
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.strategy_generator = self.register_strategy_generator()

def register_strategy(self) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
"""
operand_mapping = self.get_operand_mapping()
for generator in self.strategy_generator:
strategy_generators = self.get_strategy_generator()
operand_mapping = self.get_operation_data_mapping()
for generator in strategy_generators:
strategies = generator.generate(operand_mapping)
self.strategies_vector.extend(strategies)

@@ -46,7 +47,7 @@ def post_process(self, strategy: ShardingStrategy_V2):
return strategy

@abstractmethod
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
"""
Define which generators should be used by this NodeHandler object.
"""
@@ -81,6 +82,8 @@ class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

print("created")

# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
86 changes: 4 additions & 82 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -7,6 +7,7 @@

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
@@ -90,18 +91,12 @@ class TrainCycleItem:
total: Any


class CommunicationType(Enum):
FWD_ALL_REDUCE = 0
BWD_ALL_REDUCE = 1


@dataclass
class CommunicationAction:
class MemoryCost:
"""
The actions
"""
type: CommunicationType
mesh_dim: int
activation: int = 0
parameter: int = 0


@dataclass
@@ -152,79 +147,6 @@ def _get_sharding_spec(self, operation_data_type: OperationDataType):
return specs


class StrategyGenerator_V2(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""

def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh

def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Compute the communication cost involved in the forward and backward iteration.
"""

comm_cost = TrainCycleItem(fwd=0, bwd=0)

def _compute_and_add(data: OperationData, action: CommunicationAction):
sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
dtype = operand.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)

# compute the fwd
if action.type == CommunicationType.FWD_ALL_REDUCE:
comm_cost.fwd += cost
elif action.type == CommunicationType.BWD_ALL_REDUCE:
comm_cost.fwd += cost
else:
raise ValueError(f"Found unknown CommunicationType {action.type}")

# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, actions in strategy.communication_actions:
for action in actions:
_compute_and_add(operand, action)

# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy

@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the computation flops.
"""
pass

@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the memory cost in bytes.
"""
pass

@abstractmethod
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""
Generate all possible sharding strategies for this operation.
"""
pass

@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass


class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
7 changes: 7 additions & 0 deletions colossalai/auto_parallel/solver/strategy/__init__.py
@@ -0,0 +1,7 @@
from .strategy_generator import StrategyGenerator_V2
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator

__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator'
]
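A hedged usage sketch of the new package's public API; op_data_mapping and device_mesh are assumed to be produced by a node handler, as in dot_handler_v2.py above:

# sketch only: op_data_mapping and device_mesh are assumed inputs from a handler
from colossalai.auto_parallel.solver.strategy import LinearProjectionStrategyGenerator

generator = LinearProjectionStrategyGenerator(op_data_mapping, device_mesh)
strategies = generator.generate(op_data_mapping)  # List[ShardingStrategy_V2]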
