[autoparallel] added utils for broadcast operation #1665

Merged (2 commits) on Sep 29, 2022
96 changes: 96 additions & 0 deletions colossalai/auto_parallel/solver/op_handler/broadcast.py
@@ -0,0 +1,96 @@
import torch
from enum import Enum, auto
from typing import List
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']


class BroadcastType(Enum):
EQUAL = auto()    # the logical and physical dimension sizes are the same
PADDING = auto()    # the dimension is absent from the physical shape and is added by broadcasting
MULTIPLE = auto()    # the physical dimension has size 1 and is expanded by broadcasting


def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
"""
Check if two shapes are broadcastable to each other.
"""
for s1, s2 in zip(shape1[::-1], shape2[::-1]):
if s1 == 1 or s2 == 1 or s1 == s2:
pass
else:
return False
return True


def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
dims = []
for s1, s2 in zip(shape1_reverse, shape2_reverse):
dims.append(max(s1, s2))

# append the remaining dims
dims.extend(shape1_reverse[min_common_dim:])
dims.extend(shape2_reverse[min_common_dim:])
return dims[::-1]


def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.

Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)

# track the dim and its broadcasting type
logical_dim_broadcast_info = {}

for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]

if physical_dim_idx >= 0:
physical_dim_size = physical_shape[physical_dim_idx]

if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING

# generate the sharding spec for the physical shape
physical_dim_partition = {}
logical_dim_partition = logical_sharding_spec.dim_partition_dict

for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
pass
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim

physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)

return physical_sharding_spec
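
A note for reviewers: the test below exercises the MULTIPLE case; the sketch here (not part of the diff, names and shapes are illustrative) shows the PADDING case, where a shard placed on a logical dimension that is absent from the physical shape is dropped, assuming padded dimensions are classified as BroadcastType.PADDING by the loop above.

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.op_handler.broadcast import (get_broadcast_shape,
                                                                  recover_sharding_spec_for_broadcast_shape)

# a 1D bias of shape (8,) broadcast against an activation of shape (4, 2, 8)
activation = torch.rand(4, 2, 8)
bias = torch.rand(8)
broadcast_shape = get_broadcast_shape(activation.shape, bias.shape)    # [4, 2, 8]

# 2x2 device mesh, same layout as in the test below
device_mesh = DeviceMesh(torch.arange(4), (2, 2))

# shard logical dim 0 on mesh dim 0 and logical dim 2 on mesh dim 1
logical_spec = ShardingSpec(device_mesh=device_mesh,
                            entire_shape=broadcast_shape,
                            dim_partition_dict={0: [0], 2: [1]})

physical_spec = recover_sharding_spec_for_broadcast_shape(logical_spec, broadcast_shape, bias.shape)
# logical dim 0 does not exist in the bias (PADDING), so its shard is dropped;
# logical dim 2 maps to physical dim 0 with equal size (EQUAL), so its shard is kept
assert physical_spec.dim_partition_dict == {0: [1]}
assert physical_spec.sharding_sequence == ['S1']
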
59 changes: 59 additions & 0 deletions tests/test_auto_parallel/test_broadcast.py
@@ -0,0 +1,59 @@
import torch
from colossalai.auto_parallel.solver.op_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh


def test_is_broadcastable():
x1 = torch.rand(4, 4, 8)
x2 = torch.rand(1, 8)
assert is_broadcastable(x1.shape, x2.shape)

x1 = torch.rand(4, 2, 8)
x2 = torch.rand(2, 8)
assert is_broadcastable(x1.shape, x2.shape)

x1 = torch.rand(4, 2, 8)
x2 = torch.rand(4, 8)
assert not is_broadcastable(x1.shape, x2.shape)


def test_get_broadcast_shape():
x1 = torch.rand(4, 4, 8)
x2 = torch.rand(1, 8)
assert get_broadcast_shape(x1.shape, x2.shape) == [4, 4, 8]

x1 = torch.rand(4, 2, 8)
x2 = torch.rand(2, 8)
assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]

x1 = torch.rand(4, 2, 8)
x2 = torch.rand(8)
assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]


def test_recover_sharding_spec_for_broadcast_shape():
x1 = torch.rand(4, 1, 8)
x2 = torch.rand(2, 8)

physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

broadcast_shape = get_broadcast_shape(x1.shape, x2.shape)
logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh,
dim_partition_dict={
0: [0],
1: [1]
},
entire_shape=broadcast_shape)
physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
broadcast_shape, x1.shape)
print(physical_sharding_spec_for_x1)

assert physical_sharding_spec_for_x1.entire_shape == x1.shape
# dim 1 of x1 has size 1 and is expanded during broadcasting (BroadcastType.MULTIPLE), so its shard is dropped
assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]}
assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R']
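
As an optional sanity check (not part of this PR), get_broadcast_shape can be cross-checked against PyTorch's own broadcasting rule; torch.broadcast_shapes is assumed to be available in the torch versions this repo targets.

import torch
from colossalai.auto_parallel.solver.op_handler.broadcast import get_broadcast_shape

# the same shapes as in the test above
x1 = torch.rand(4, 1, 8)
x2 = torch.rand(2, 8)

# both should yield (4, 2, 8)
assert tuple(get_broadcast_shape(x1.shape, x2.shape)) == tuple(torch.broadcast_shapes(x1.shape, x2.shape))
assert (x1 * x2).shape == torch.broadcast_shapes(x1.shape, x2.shape)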