[autoparallel] added solver option dataclass #1588

Merged · 1 commit · Sep 13, 2022
6 changes: 5 additions & 1 deletion colossalai/auto_parallel/solver/__init__.py
@@ -4,5 +4,9 @@
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
from .options import SolverOptions

__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
__all__ = [
'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
'SolverOptions'
]
11 changes: 11 additions & 0 deletions colossalai/auto_parallel/solver/options.py
@@ -0,0 +1,11 @@
from dataclasses import dataclass

__all__ = ['SolverOptions']


@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
fast: bool = False
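
A minimal usage sketch of the new options object (illustrative only: fast is the sole field this PR adds, and the import path follows the new file above):

    from colossalai.auto_parallel.solver.options import SolverOptions

    # Default search: enumerate all candidate sharding specs.
    default_options = SolverOptions()          # fast=False

    # Fast search: placeholder, get_attr and output nodes keep a single
    # fully replicated strategy instead of being enumerated.
    fast_options = SolverOptions(fast=True)

    # Attribute access replaces the old dict lookup solver_options['fast_mode'].
    assert fast_options.fast
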
28 changes: 21 additions & 7 deletions colossalai/auto_parallel/solver/strategies_constructor.py
@@ -1,5 +1,8 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
@@ -11,9 +14,20 @@


class StrategiesConstructor:

def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.

Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""

def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -77,13 +91,13 @@ def build_strategies_and_cost(self):
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options['fast_mode'] is True, we just leave them in
# For placeholder nodes, if solver_options.fast is True, we just leave them in
# fully replicated status, so the strategies of the following nodes are treated equally,
# because the replicated status has no resharding cost to any other status. At the same time,
# the search space is smaller than enumerating all possible sharding specs for the placeholder node.
# Otherwise, all possible sharding specs for the placeholder node will be enumerated.

if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
@@ -97,12 +111,12 @@

# get_attr node
if node.op == 'get_attr':
# Same as for placeholder nodes, if solver_options['fast_mode'] is True, we just leave them in
# Same as for placeholder nodes, if solver_options.fast is True, we just leave them in
# fully replicated status, so the strategies of the following nodes are treated equally,
# because the replicated status has no resharding cost to any other status. At the same time,
# the search space is smaller than enumerating all possible sharding specs for the get_attr node.
# Otherwise, all possible sharding specs for the get_attr node will be enumerated.
if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
@@ -382,7 +396,7 @@ def build_strategies_and_cost(self):

# output node
if node.op == 'output':
if self.solver_options['fast_mode']:
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
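
The docstring above describes the new constructor contract: a SolverOptions instance replaces the old plain dict. A hedged end-to-end sketch mirroring the updated tests below; the ColoTracer import path and the toy model are assumptions, not part of this diff:

    import torch
    import torch.nn as nn
    from torch.fx import GraphModule
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.shape_consistency import ShapeConsistencyManager
    from colossalai.fx.tracer.tracer import ColoTracer  # assumed import path
    from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
    from colossalai.auto_parallel.solver.options import SolverOptions

    # Toy model; any traceable module with supported ops would do.
    model = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.ReLU())

    # 2x2 logical device mesh over 4 physical devices.
    device_mesh = DeviceMesh(torch.arange(4), (2, 2))
    shape_consistency_manager = ShapeConsistencyManager()

    # Trace the model into a torch.fx graph using meta tensors.
    tracer = ColoTracer()
    input_sample = {'input': torch.rand(4, 4, 64, 64).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # New calling convention: pass a SolverOptions dataclass, not {'fast_mode': True}.
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh,
                                                   shape_consistency_manager, solver_options)
    strategies_constructor.build_strategies_and_cost()
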
4 changes: 3 additions & 1 deletion tests/test_auto_parallel/test_cost_graph.py
@@ -1,3 +1,4 @@
from pickletools import optimize
import torch
from torch.fx import GraphModule
import torch.nn as nn
@@ -10,6 +11,7 @@
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.options import SolverOptions
from copy import deepcopy


@@ -52,7 +54,7 @@ def test_cost_graph():
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

solver_options = {'fast_mode': True}
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()

5 changes: 2 additions & 3 deletions tests/test_auto_parallel/test_solver.py
@@ -11,6 +11,7 @@
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
from colossalai.auto_parallel.solver.options import SolverOptions


class ConvModel(nn.Module):
@@ -39,7 +40,6 @@ def test_solver():
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()

tracer = ColoTracer()
@@ -57,9 +57,8 @@
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

solver_options = {'fast_mode': True}
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()

3 changes: 2 additions & 1 deletion tests/test_auto_parallel/test_strategies_constructor.py
@@ -11,6 +11,7 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.options import SolverOptions
from copy import deepcopy


@@ -47,7 +48,7 @@ def test_strategies_constructor():
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

solver_options = {'fast_mode': True}
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)

assert strategies_constructor.leaf_strategies == []