[autoparallel] added new linear module handler #1616

Merged · 1 commit · Sep 21, 2022
139 changes: 139 additions & 0 deletions colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LinearModuleHandler']


class DotProductStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass


class MatVecStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass


class LinearProjectionStrategyGenerator(StrategyGenerator_V2):

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass

def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""TODO: to be implemented"""
pass

def validate(self, *args, **kwargs) -> bool:
"""TODO: to be implemented"""
pass


class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for the nn.Linear module.
"""

def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for the strategies;
# the sharding specs will be converted back to the weight's original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=self.named_parameters['weight'].shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
return mapping

def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and last dim of the linear module weight
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy


@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
"""
A LinearFunctionHandler which deals with the sharding strategies for the F.linear function.
"""

def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for the strategies;
# the sharding specs will be converted back to the weight's original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

if self.node.args[2] is not None:
physical_bias_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
data=self.node.args[2]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping

def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight tensor (the second argument) back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == str(self.node.args[1]):
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch the first and last dim of the linear weight
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
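
Both handlers rely on the same trick: the weight is presented to the strategy generator with a transposed logical shape, and post_process then swaps the first and last entries of the dim_partition_dict so the resulting spec matches the weight's physical layout again. The following is a minimal standalone sketch of that swap on a toy nn.Linear layer; the layer sizes and the dim_partition_dict contents are invented for illustration, and a plain dict stands in for the full ShardingSpec bookkeeping.

import torch.nn as nn

# A toy linear layer: the weight is stored physically as (out_features, in_features) = (8, 4).
toy_linear = nn.Linear(in_features=4, out_features=8, bias=False)
physical_shape = tuple(toy_linear.weight.shape)    # (8, 4)
logical_shape = physical_shape[::-1]               # (4, 8), the transposed view given to the generator

# Suppose a generator sharded the logical (in, out) view:
# logical dim 0 over mesh axis 0 and logical dim -1 over mesh axis 1.
logical_dim_partition = {0: [0], -1: [1]}

# post_process swaps the first and last entries so the partition
# refers to the physical (out, in) layout again.
physical_dim_partition = dict(logical_dim_partition)
physical_dim_partition[0], physical_dim_partition[-1] = (physical_dim_partition[-1],
                                                         physical_dim_partition[0])

print(physical_shape, logical_shape)    # (8, 4) (4, 8)
print(physical_dim_partition)           # {0: [1], -1: [0]}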
44 changes: 35 additions & 9 deletions colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,7 @@
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2


class NodeHandler(ABC):
@@ -36,8 +36,15 @@ def register_strategy(self) -> StrategiesVector:
for generator in self.strategy_generator:
strategies = generator.generate(operand_mapping)
self.strategies_vector.extend(strategies)

self.strategies_vector = list(map(self.post_process, self.strategies_vector))
return self.strategies_vector

def post_process(self, strategy: ShardingStrategy_V2):
# transform the generated strategy,
# e.g. to process the sharding strategy for the transposed weight
return strategy

@abstractmethod
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
"""
@@ -46,21 +53,40 @@ def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
pass

@abstractmethod
def get_operand_mapping(self) -> Dict[str, Operand]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"""
Returns the mapping between the logical operand name to its physical operands.
A logical operand is defined by the strategy generator, for example, a matrix multiplication
operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
the module input and the physical operand for "other" is the module weight.
Returns the mapping from the logical operation data to its physical data.
A logical operation data is data associated with an operation, which can be an input or an output. It is
defined by the strategy generator; for example, a matrix multiplication operation has two operands, "input"
and "other", and one result, "output". For an nn.Linear module, the physical operand for "input" is
the module input, the physical operand for "other" is the module weight, and the physical result for "output"
is the module output.
Note that the operand name is specified by the StrategyGenerator object.

For example:

# for a linear layer
mapping = {
"input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
"other": Operand(name="weight", type=OperandType.PARAM),
"bias": Operand(name="bias", type=OperandType.PARAM)
"input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
"other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
"bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
pass


class ModuleHandler(NodeHandler):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
self.module = module
self.named_parameters = named_parameters
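
The new ModuleHandler.__init__ resolves the submodule behind an FX call_module node and caches its parameters as a plain dict. Below is a minimal sketch of that lookup on a toy traced model; the model and layer sizes are invented for illustration.

import torch.nn as nn
from torch.fx import symbolic_trace

# A toy model with a single linear layer, traced into an FX GraphModule.
toy_model = nn.Sequential(nn.Linear(4, 8))
gm = symbolic_trace(toy_model)

for node in gm.graph.nodes:
    if node.op == "call_module":
        # Mirrors ModuleHandler.__init__: resolve the submodule behind the node
        # and build a name -> parameter dict for convenient access.
        submodule = gm.get_submodule(node.target)
        named_parameters = dict(submodule.named_parameters(recurse=False))
        print(node.target, sorted(named_parameters.keys()))    # e.g. 0 ['bias', 'weight']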
122 changes: 114 additions & 8 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,10 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple, Any
@@ -40,18 +44,35 @@ class ShardingStrategy:
input_shardings: List[ShardingSpec] = None


class OperandType(Enum):
class OperationDataType(Enum):
"""
An operand can come from the argument list of an operator or the parameter list of a module.
An operation data can come from the argument list of an operator or the parameter list of a module, or be the output of the operator.
"""
ARG = 0
PARAM = 1
OUTPUT = 2


@dataclass
class Operand:
class OperationData:
"""
OperationData is the data related to an operator; it can be an operand or the output.

Args:
name (str): the name of the operation-related data
type (OperationDataType): the type of the operation data
data (torch.Tensor): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, which can be different from its actual shape in memory.
"""
name: str
type: OperandType
type: OperationDataType
data: torch.Tensor
logical_shape: Tuple[int] = None

def __post_init__(self):
# if no logical shape is specified, use the data shape as the logical shape
if self.logical_shape is None:
self.logical_shape = self.data.shape


@dataclass
@@ -69,6 +90,20 @@ class TrainCycleItem:
total: Any


class CommunicationType(Enum):
[Review comment from a Contributor]: I think we could unify all the communication-related operations into CommSpec, which can provide communication cost information via the get_comm_cost API and supports runtime application with convert_spec_to_action. The resharding cost is estimated by ShapeConsistencyManager, which uses CommSpec internally.

FWD_ALL_REDUCE = 0
BWD_ALL_REDUCE = 1


@dataclass
class CommunicationAction:
"""
The actions
"""
type: CommunicationType
mesh_dim: int


@dataclass
class ShardingStrategy_V2:
"""
@@ -86,12 +121,35 @@ class ShardingStrategy_V2:
strategy.(default to None)
"""
name: str
output_sharding_spec: ShardingSpec
sharding_specs: Dict[OperationData, ShardingSpec] = None
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
input_sharding_specs: Dict[Operand, ShardingSpec] = None
input_resharding_costs: Dict[Operand, List[float]] = None
input_resharding_costs: Dict[OperationData, List[float]] = None
communication_actions: Dict[OperationData, List[CommunicationAction]] = None

@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
specs = {}
specs.update(self._get_sharding_spec(OperationDataType.ARG))
specs.update(self._get_sharding_spec(OperationDataType.PARAM))
return specs

@property
def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.ARG)

@property
def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.PARAM)

@property
def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.OUTPUT)

def _get_sharding_spec(self, operation_data_type: OperationDataType):
specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
return specs


class StrategyGenerator_V2(ABC):
@@ -104,9 +162,57 @@ def __init__(self, device_mesh: DeviceMesh):
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh

def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Compute the communication cost involved in the forward and backward iteration.
"""

comm_cost = TrainCycleItem(fwd=0, bwd=0)

def _compute_and_add(data: OperationData, action: CommunicationAction):
sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)

# accumulate the cost into the forward or backward phase
if action.type == CommunicationType.FWD_ALL_REDUCE:
comm_cost.fwd += cost
elif action.type == CommunicationType.BWD_ALL_REDUCE:
comm_cost.bwd += cost
else:
raise ValueError(f"Found unknown CommunicationType {action.type}")

# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, actions in strategy.communication_actions.items():
for action in actions:
_compute_and_add(operand, action)

# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy

@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the computation flops.
"""
pass

@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the memory cost in bytes.
"""
pass

@abstractmethod
def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""
Generate all possible sharding strategies for this operation.
"""
pass
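
To make the byte counting in update_communication_cost concrete, the following is a minimal standalone sketch of the same arithmetic on an invented sharded shape; the actual per-collective pricing lives in DeviceMesh.all_reduce_cost and is not reproduced here.

import operator
from functools import reduce
import torch

# Suppose a strategy leaves an fp32 tensor of shape (1024, 512) on each device
# and schedules a forward all-reduce over one mesh dimension.
sharded_shape = (1024, 512)
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()    # 4 bytes
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
print(num_bytes)    # 4 * 1024 * 512 = 2097152 bytes per device

# update_communication_cost would pass num_bytes to
# device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
# and accumulate the returned cost into comm_cost.fwd or comm_cost.bwd.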
