/
auto_shard.py
225 lines (188 loc) · 9.93 KB
/
auto_shard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List, Set
import torch
import torch.fx
from torch.fx.node import Node
def _get_count(param_count: Dict, node_name: str) -> int:
"""Identify different mutations of a given node name."""
# TODO(anj): This is not very stable since it is possible that the name
# may not be in the same format. Is there another way to identify nodes
# in a graph?
if node_name in param_count:
return param_count[node_name]
elif node_name.split("_")[0] in param_count:
return param_count[node_name.split("_")[0]]
else:
raise RuntimeError(f"Unable to find match between param {param_count} and node {node_name}")
def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict) -> Dict:
    """Utility to create a map from shard id to param count using existing state.

    Every node in ``node_name_to_shard_id`` contributes the param count that
    ``_get_count`` resolves for it to its shard's total. Nodes that
    ``_get_count`` cannot resolve (it raises RuntimeError) contribute nothing.

    Args:
        param_count: Map from module name ("." replaced by "_") to param count.
        node_name_to_shard_id: Map from graph node name to its shard id.

    Returns:
        Map from shard id to the summed param count of its nodes. A shard in
        which no node could be resolved has no entry.
    """
    totals: Dict[int, int] = {}
    for name, shard in node_name_to_shard_id.items():
        try:
            n_params = _get_count(param_count, name)
        except RuntimeError:
            # Node does not correspond to a known module: ignore it.
            continue
        totals[shard] = totals.get(shard, 0) + n_params
    return totals
def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Utility used to trace a graph and identify shard cutpoints.

    Walks the graph nodes in order and assigns each one a shard id. A new shard is
    opened once the param count attributed to the current shard exceeds
    ``total_params // shard_count``.

    When a node consumes a node that lives in an earlier shard and is not the node
    traversed immediately before it, all shards in between are collapsed into one.
    Such inputs are skip connections across shards.

    Args:
        traced_graph_module: The traced module whose graph is partitioned.
        shard_count: Number of shards we are aiming for. Fewer shards may be
            produced if skip connections force shards to be merged.

    Returns:
        Map from node name to shard id. It covers every placeholder, get_attr,
        call_function, call_method and call_module node seen before the output
        node.

    Raises:
        ValueError: If ``shard_count`` is smaller than 1.
    """
    if shard_count < 1:
        raise ValueError(f"shard_count must be a positive integer, got {shard_count}")

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far: List[str] = []

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    param_count: Dict[str, int] = {}
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum(x.numel() for x in module.parameters())
    logging.info("Total number of params are %s", param_count[""])
    per_shard_param = param_count[""] // shard_count
    logging.info("Per shard param count %s", per_shard_param)

    # Running param total per shard. It is updated incrementally as nodes are
    # assigned and rebuilt from scratch only when shards are collapsed.
    # This avoids recomputing the whole map for every node.
    shard_to_param_count: Dict[int, int] = {}

    def _assign(node_name: str, shard: int) -> None:
        # Record node_name as part of `shard` and add its param count (if it
        # resolves to a module) to that shard's running total.
        node_name_to_shard_id[node_name] = shard
        nodes_so_far.append(node_name)
        try:
            count = _get_count(param_count, node_name)
        except RuntimeError:
            return
        shard_to_param_count[shard] = shard_to_param_count.get(shard, 0) + count

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            _assign(node.name, shard_id)
        elif node.op in ("get_attr", "call_function", "call_method", "call_module"):
            min_shard_id = shard_id
            min_node_name = ""
            # For each input of this node, find the inputs that are not the last
            # node we traversed. This is to help us find skip connections across
            # shards. all_input_nodes also includes nodes passed through kwargs or
            # nested in containers (e.g. the list given to torch.cat), which plain
            # iteration over node.args would miss.
            for arg in node.all_input_nodes:
                if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[-1]:
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name
            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between to be part of 1 shard
            # and rebuild the param count per shard accordingly.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
                shard_to_param_count.clear()
                shard_to_param_count.update(_create_shard_to_param_count(param_count, node_name_to_shard_id))
            # Update state that is tracking node -> shard id and shard id -> param count.
            _assign(node.name, shard_id)
            # If we have gone over the number of params per shard count that we want to
            # achieve, we should add a new shard.
            # The shard_id may not be in the map if no node in it has params.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id
class _ExtendedLeafTracer(torch.fx.Tracer):
"""Tracer with an extended set of leaf nn.Modules."""
def __init__(self, leaf_modules: Set[torch.nn.Module]):
"""Initializes a new _ExtendedLeafTracer object.
Args:
leaf_modules: The set of extra nn.Modules instances which will not be traced
through but instead considered to be leaves.
"""
super().__init__()
self.leaf_modules = leaf_modules
def is_leaf_module(self, m: torch.nn.Module, model_qualified_name: str) -> bool:
return super().is_leaf_module(m, model_qualified_name) or m in self.leaf_modules
# TODO(ehotaj): Extend this method to wrap at the least granular level. One way to do
# would be to wrap the Module tree bottom up, first wrapping untracable children and
# only wrapping parents if they are also untracable.
def _trace(model: torch.nn.Module) -> torch.fx.GraphModule:
    """Traces the given model and automatically wraps untracable modules into leaves.

    Every submodule that is not already a leaf is traced on its own. Any
    submodule whose trace fails is added to the set of leaf modules. The full
    model is then traced with that set of leaves.

    Args:
        model: The module to trace.

    Returns:
        A GraphModule built from ``model`` and the resulting graph.

    Raises:
        Any exception raised by tracing the full model with the collected leaves.
    """
    leaf_modules: Set[torch.nn.Module] = set()
    # The tracer keeps a reference to leaf_modules, so modules added below are
    # immediately seen by subsequent is_leaf_module checks.
    tracer = _ExtendedLeafTracer(leaf_modules)
    for _, module in model.named_modules():
        # TODO(ehotaj): The default is_leaf_module includes everything in torch.nn.
        # This means that some coarse modules like nn.TransformerEncoder are treated
        # as leaves, not traced, and are unable to be sharded. We may want to extend our
        # sharding code to trace through these modules as well.
        if tracer.is_leaf_module(module, ""):
            continue
        try:
            tracer.trace(module)
        except (TypeError, RuntimeError, torch.fx.proxy.TraceError):
            # torch.fx reports untraceable code through several exception types.
            # For example, it raises RuntimeError for len() on a Proxy, in addition
            # to TraceError and TypeError. Any of them means the module must be
            # kept as an opaque leaf.
            leaf_modules.add(module)
    tracer = _ExtendedLeafTracer(leaf_modules)
    graph = tracer.trace(model)
    return torch.fx.GraphModule(model, graph)
def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    """Utility used to shard a model using torch.fx.

    This function traces the model once and then makes two passes over the
    resulting graph.

    1. First pass (``_split_nodes``): we count the number of parameters while
       walking the graph and assign each node to a shard.
    2. Second pass: we copy the nodes into one new graph per shard, inserting an
       output node at the end of each shard and a placeholder at the start of
       the next one.

    We don't support skip connections between shards. This means that all
    input and output is self contained within a given shard. A node from
    shard 1 cannot be an input to a node from shard 3. We expect all inputs
    to a given shard to be coming from the last node in the previous shard.
    This means that we may not be able to shard models by the specified
    `shard_count` mentioned by the user.

    Args:
        model (nn.Module): Model to be sharded as specified by the device count.
        shard_count (int): Number of shards that we want to split the model into.

    Returns:
        The shard GraphModules in execution order. Every shard after the first
        starts with one new placeholder that stands for the last node of the
        previous shard.
    """
    module_list: List[torch.fx.GraphModule] = []
    num_graphs = 0
    new_graph = torch.fx.Graph()  # type: ignore
    # Maps original node names to their copies in the graph being built. After a
    # cut, the previous shard's last node is remapped to the new placeholder.
    env: Dict[str, Node] = {}
    traced_graph_module = _trace(model)
    # This is the first pass where we attempt to get a map of where
    # we need to insert placeholder and output nodes.
    node_name_to_shard_id = _split_nodes(traced_graph_module, shard_count=shard_count)

    prev_name = None  # original name of the most recently copied node
    prev_shard_id = None  # shard id of that node; None before the first node
    for node in traced_graph_module.graph.nodes:
        if node.op == "output":
            # If this is the last node, we should add an output
            # node and add the last graph to the list.
            assert prev_name is not None, "prev_node cannot be None"
            if all(env[n.name].graph is new_graph for n in node.all_input_nodes):
                # Preserve the model's real return value (e.g. a tuple, or a node
                # other than the last one computed).
                new_graph.node_copy(node, lambda n: env[n.name])
            else:
                # The output depends on an earlier shard, which is unsupported;
                # fall back to returning the last node of this shard.
                new_graph.output(env[prev_name])
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break

        shard_id = node_name_to_shard_id[node.name]
        # If the current node is in the next shard, we close the current graph
        # with an output node, then start a new graph with a placeholder.
        if prev_shard_id is not None and prev_shard_id < shard_id:
            assert prev_name is not None, "prev_node cannot be None"
            new_graph.output(env[prev_name])
            num_graphs += 1
            module_list.append(torch.fx.GraphModule(model, new_graph))
            new_graph = torch.fx.Graph()
            pl_node = new_graph.create_node("placeholder", "placeholder" + str(num_graphs))
            # Redirect every reference to the previous shard's output (positional,
            # keyword or nested) to the placeholder. All other args/kwargs of the
            # copied nodes are preserved, and the traced graph is not mutated.
            env[prev_name] = pl_node

        # Copy the node from the existing graph to the new graph.
        env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
        prev_name = node.name
        prev_shard_id = shard_id
    return module_list