Skip to content

Commit

Permalink
[autoparallel] refactored the data structure for sharding strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee committed Sep 20, 2022
1 parent eac1b79 commit 8831782
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion colossalai/auto_parallel/solver/sharding_strategy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from pickle import NONE
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *

Expand Down Expand Up @@ -37,6 +38,45 @@ class ShardingStrategy:
input_shardings: List[ShardingSpec] = None


@dataclass
class TrainCycleItem:
    """
    Pair of values for a quantity that differs between the forward pass and
    the backward pass of a single training iteration.

    Args:
        fwd (Any): the value associated with the forward pass.
        bwd (Any): the value associated with the backward pass.
    """
    fwd: Any
    bwd: Any


@dataclass
class ShardingStrategy_V2:
    """
    ShardingStrategy_V2 is a dataclass to store the meta information on tensor sharding for a node.

    Only ``name`` and ``output_sharding_spec`` are required; every other field defaults to None.

    Args:
        name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
        compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
        communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
        memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
        input_sharding_specs (List[ShardingSpec]): The ShardingSpecs of the input nodes. (default to None)
        input_resharding_costs (Dict[Node, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the
            output node argument list with j-th strategy in its strategies_vector transforms to sharding spec wanted
            in this strategy. (default to None)
            NOTE(review): the field is annotated with ``Node`` keys, while the description above indexes by
            argument position -- confirm which key type callers actually use.
    """
    name: str
    output_sharding_spec: ShardingSpec
    compute_cost: TrainCycleItem = None
    communication_cost: TrainCycleItem = None
    memory_cost: TrainCycleItem = None
    input_sharding_specs: List[ShardingSpec] = None
    input_resharding_costs: Dict[Node, List[float]] = None


class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
Expand Down

0 comments on commit 8831782

Please sign in to comment.