[fx]get communication size between partitions #1224

Merged 27 commits on Jul 7, 2022
Changes from 26 commits
27 commits:
df7e650  [CLI] add CLI launcher (YuliangLiu0306, Apr 13, 2022)
73753aa  Merge branch 'feature/cli' into main (YuliangLiu0306, Apr 13, 2022)
80da77a  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 15, 2022)
551359c  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 18, 2022)
a25697a  Revert "[CLI] add CLI launcher" (YuliangLiu0306, Apr 19, 2022)
77b5704  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 19, 2022)
e23d33e  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 20, 2022)
997c625  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 23, 2022)
961d950  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 24, 2022)
2deaa40  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 24, 2022)
9ff217f  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Apr 28, 2022)
501dc1a  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, May 12, 2022)
21e43fd  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, May 21, 2022)
cbd4579  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, May 23, 2022)
1443291  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, May 30, 2022)
e627cf5  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 10, 2022)
289316e  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 15, 2022)
689e047  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 15, 2022)
0a83919  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 17, 2022)
98c1ef9  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 20, 2022)
9a3af67  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 21, 2022)
7700793  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 28, 2022)
3c77d1f  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jun 30, 2022)
7c10323  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jul 4, 2022)
11711d1  Merge branch 'hpcaitech:main' into main (YuliangLiu0306, Jul 6, 2022)
92dc0b0  [fx]get communication size between partitions. (YuliangLiu0306, Jul 7, 2022)
6a41732  polish (YuliangLiu0306, Jul 7, 2022)
101 changes: 101 additions & 0 deletions colossalai/fx/passes/meta_info_prop.py
@@ -0,0 +1,101 @@
import torch
import torch.fx
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility


@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool
    stride: Tuple[int]
    numel: int
    # TODO: we can add a list of sharding spec here, and record the sharding
    # behaviour by appending sharding spec into list.


def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
shape = result.shape
dtype = result.dtype
requires_grad = result.requires_grad
stride = result.stride()
numel = result.numel()

return TensorMetadata(shape, dtype, requires_grad, stride, numel)


@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
record the shape and type of the result
into the corresponding node.

Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16
model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
orig_output = model(input_sample)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)

for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)

# output of above code is
# input_1 torch.float32 torch.Size([2, 4]) 8
# weight torch.float32 torch.Size([16, 4]) 64
# bias torch.float32 torch.Size([16]) 16
# linear torch.float32 torch.Size([2, 16]) 32
# output torch.float32 torch.Size([2, 16]) 32
Args:
module (GraphModule): The module to be executed

"""

def run_node(self, n: Node) -> Any:
result = super().run_node(n)

found_tensor = False

def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
return _extract_tensor_metadata(obj)
else:
return obj

meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)

n.meta['type'] = type(result)
return result

    def propagate(self, *args):
        """
        Run `module` via interpretation, record the shape and type
        of each node, and return the result.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
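
Because run_node records metadata through map_aggregate, nested outputs (tuples of tensors, lists, and so on) are annotated element-wise as well, not just single-tensor results. A minimal sketch of that behavior, assuming colossalai is importable; the TwoHead module below is a hypothetical example, not part of this PR:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

class TwoHead(torch.nn.Module):
    # Hypothetical module: its output node yields a tuple of tensors,
    # so map_aggregate in run_node records a tuple of TensorMetadata.
    def forward(self, x):
        return x.relu(), x.sigmoid()

gm = symbolic_trace(TwoHead())
MetaInfoProp(gm).run(torch.rand(2, 3))
output_node = list(gm.graph.nodes)[-1]
# Expected: a tuple of two TensorMetadata entries, one per head.
print(output_node.meta['tensor_meta'])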
27 changes: 27 additions & 0 deletions colossalai/fx/passes/utils.py
@@ -0,0 +1,27 @@
import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg


def get_comm_size(parent_partition, child_partition):
"""Given two partitions (parent and child),
calculate the communication size between the two.
"""
# Keep tracking the communication size between parent and child
comm_size = 0
# Keep tracking all the counted node
visited_nodes = set()
# Go through all nodes in the child partition
# If a node has input nodes from the parent partition,
# the output size of those input nodes will be counted
# and added to comm_size
parent_node_names = [n.name for n in parent_partition.graph.nodes]
for node in child_partition.graph.nodes:
input_nodes: Dict[Node, None] = {}
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel
visited_nodes.add(n)
return comm_size
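
Note that get_comm_size returns an element count (numel), not bytes. A minimal sketch of converting that count to bytes for a known dtype; comm_size_in_bytes is a hypothetical helper, not part of this PR:

import torch

def comm_size_in_bytes(comm_size_numel: int, dtype: torch.dtype = torch.float32) -> int:
    # Hypothetical helper: scale the element count returned by get_comm_size
    # by the per-element width of the given floating or integer dtype.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return comm_size_numel * info.bits // 8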
46 changes: 46 additions & 0 deletions tests/test_fx/test_comm_size_compute.py
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size

MODEL_DIM = 16
BATCH_SIZE = 8
PIPELINE_SIZE = 2


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        return x


def test_comm_size_compute():
    model = MLP(MODEL_DIM)
    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
    submodule_list = list(split_model.children())
    comm_size = get_comm_size(submodule_list[0], submodule_list[1])
    # the tensor sent from partition 0 to partition 1 has shape (8, 16),
    # i.e. 8 * 16 = 128 elements
    assert comm_size == 128


if __name__ == '__main__':
    test_comm_size_compute()
35 changes: 35 additions & 0 deletions tests/test_fx/test_meta_info_prop.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata

BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16


def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
    assert meta_info_spec.shape == orig_tensor.shape
    assert meta_info_spec.dtype == orig_tensor.dtype
    assert meta_info_spec.requires_grad == orig_tensor.requires_grad
    assert meta_info_spec.stride == orig_tensor.stride()
    assert meta_info_spec.numel == orig_tensor.numel()


def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    input_sample = torch.rand(BATCH_SIZE, DIM_IN)
    orig_output = model(input_sample)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            meta_check(node.meta['tensor_meta'], input_sample)
        if node.op == 'output':
            meta_check(node.meta['tensor_meta'], orig_output)


if __name__ == '__main__':
    test_meta_info_prop()