[tensor] hijack addmm for colo tensor (#923)
* hijack addmm for colo tensor

* fix bugs

* polish unit test

* polish comments
ver217 committed May 9, 2022
1 parent 534afb0 commit 45b9124
Showing 5 changed files with 210 additions and 8 deletions.
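In short, the commit registers a handler so that torch.addmm called on ColoTensor operands is routed to the new colo_addmm op below instead of the stock kernel. A minimal usage sketch, modelled on the test added in this commit and assuming a 1D tensor-parallel group has already been launched via colossalai.launch (the tensor shapes and values here are illustrative, not part of the commit):

import torch
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, ComputePattern, ParallelAction, TensorSpec

# Wrap plain tensors; torch.addmm on ColoTensor arguments now dispatches to colo_addmm.
bias = ColoTensor.init_from_torch_tensor(torch.ones(4).cuda())
x = torch.rand(2, 16).cuda()
weight = ColoTensor.init_from_torch_tensor(torch.randn(16, 4).cuda())

# Tag the weight for 1D row-wise tensor parallelism (the new TP1DRow_mm pattern);
# set_spec() shards it, and colo_addmm_1Drow all-reduces the partial products.
weight.set_spec(
    TensorSpec([
        ParallelAction(priority=1,
                       compute_pattern=ComputePattern.TP1DRow_mm,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]))

out = torch.addmm(bias, x, weight)    # returns a ColoTensor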
3 changes: 2 additions & 1 deletion colossalai/tensor/_ops/__init__.py
@@ -2,4 +2,5 @@
 from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
-from .embedding import colo_embedding
+from .embedding import colo_embedding
+from .addmm import colo_addmm
115 changes: 115 additions & 0 deletions colossalai/tensor/_ops/addmm.py
@@ -0,0 +1,115 @@
import torch
from typing import Union
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ParallelAction, ColoTensor, ShardPattern
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                     alpha: Union[int, float]) -> ColoTensor:
    parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_mm)
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    # mat1:S[1]
    if mat1.is_gathered():
        # Not split yet.
        assert divide(mat1.shape[-1], gpc.tensor_parallel_size) == mat2.size(0), \
            'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0) * gpc.tensor_parallel_size)
        input_per_partition = split_forward_gather_backward(mat1.torch_tensor(), parallel_action.parallel_mode, dim=-1)
    elif mat1.shard_pattern == ShardPattern.Col:
        # Split by 1Dcol.
        assert mat1.shape[-1] == mat2.size(0), \
            'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0))
        input_per_partition = mat1.torch_tensor()
    else:
        raise NotImplementedError

    # Output:P
    partial_output = torch.mm(input_per_partition, mat2.torch_tensor())
    # Reduce(Output)
    output = reduce_input(partial_output, parallel_action.parallel_mode)
    # input
    assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor.torch_tensor() + alpha * output
    output = ColoTensor.init_from_torch_tensor(output)
    return output


def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                     alpha: Union[int, float]) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    # All-Gather(Output)
    # mat1:B
    parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_mm)
    if mat1.is_gathered():
        # Not split yet.
        assert mat1.shape[-1] == mat2.size(0), \
            'Invalid shapes in 1Dcol forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0))
        # reduce_grad is an identity in the forward pass and all-reduces gradients in the backward pass.
        input_parallel = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)

    # input:S[1]
    assert input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1 and \
        input_tensor.shard_pattern in [ShardPattern.Col, ShardPattern.Row], \
        'Invalid input spec for 1Dcol addmm op'

    output_parallel = torch.addmm(input_tensor.torch_tensor(),
                                  input_parallel,
                                  mat2.torch_tensor(),
                                  beta=beta,
                                  alpha=alpha)

    output = ColoTensor.init_from_torch_tensor(output_parallel)
    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
    output_spec = TensorSpec(out_parallel_action_list)
    output.set_spec(output_spec, shard=False)
    output.set_shard_pattern(ShardPattern.Col)
    if parallel_action.gather_out:
        # All-Gather(Output)
        output.gather()
    return output


@colo_op_impl(torch.addmm)
def colo_addmm(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.addmm``.

    Computes ``beta * input + alpha * (mat1 @ mat2)`` on ColoTensor operands.
    """
    input_tensor, mat1, mat2 = tuple(
        map(lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t), args[:3]))
    beta = kwargs.get('beta', 1) if kwargs else 1
    alpha = kwargs.get('alpha', 1) if kwargs else 1

    # building the computing graph, inputs -> op
    # if GraphGlobalEnv().graph_building:
    #     cur_op_node = GraphOpNode('linear', [weight, bias])
    #     cur_op_node.add_prev_tensor(input_tensor)

    # Add communication logic before and after the addmm call.
    ret_tensor = None
    if not mat2.has_spec():    # No Model Parallel Applied
        assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
        ret_tensor = ColoTensor.init_from_torch_tensor(
            torch.addmm(input_tensor.torch_tensor(), mat1.torch_tensor(), mat2.torch_tensor(), beta=beta, alpha=alpha))
    elif mat2.shard_spec.num_action == 1:    # Single Model Parallel Applied
        compute_patterns = mat2.shard_spec.compute_patterns
        if ComputePattern.TP1DRow_mm in compute_patterns:
            ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
        elif ComputePattern.TP1DCol_mm in compute_patterns:
            ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    # building the computing graph, op -> output
    # if GraphGlobalEnv().graph_building:
    #     cur_op_node.add_post_tensor(ret_tensor)

    return ret_tensor
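The two sharding strategies above boil down to simple matrix identities. A standalone sanity check in plain PyTorch (single process; the names, shapes, and chunking below are illustrative only and merely stand in for the all-reduce / all-gather that reduce_input and gather() perform across tensor-parallel ranks):

import torch

mat1 = torch.rand(2, 16)      # activations, replicated on every rank
mat2 = torch.rand(16, 4)      # weight to be sharded
bias = torch.rand(4)
world_size = 2                # illustrative tensor-parallel size

# 1Drow: mat1 split along its last dim, mat2 split along dim 0.
# The sum of the per-rank partial products equals the full mm (the all-reduce).
partials = [torch.mm(a, b) for a, b in zip(mat1.chunk(world_size, dim=-1), mat2.chunk(world_size, dim=0))]
row_out = bias + sum(partials)
assert torch.allclose(row_out, torch.addmm(bias, mat1, mat2), atol=1e-5)

# 1Dcol: mat2 and the bias split along the last dim; each rank computes a slice
# of the output, and concatenation stands in for the all-gather.
cols = [torch.addmm(b, mat1, w) for b, w in zip(bias.chunk(world_size, dim=-1), mat2.chunk(world_size, dim=-1))]
assert torch.allclose(torch.cat(cols, dim=-1), torch.addmm(bias, mat1, mat2), atol=1e-5)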
13 changes: 8 additions & 5 deletions colossalai/tensor/colo_tensor.py
@@ -142,12 +142,15 @@ def shard(self):
         # Model Parameters
         if self._shard_spec.num_action == 1:
             parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
-            if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
-                ComputePattern.TP1DCol_Embedding]:
+            if parallel_action.compute_pattern in [
+                    ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
+            ]:
                 self._shard_1d(parallel_action=parallel_action, dim=-1)
-                self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
-            elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
-                ComputePattern.TP1DRow_Embedding]:
+                # We bind our ComputePattern on weight, which has to be transposed when linear().
+                self._shard_pattern = ShardPattern.Col
+            elif parallel_action.compute_pattern in [
+                    ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
+            ]:
                 self._shard_1d(parallel_action=parallel_action, dim=0)
                 self._shard_pattern = ShardPattern.Row
             else:
6 changes: 4 additions & 2 deletions colossalai/tensor/spec.py
@@ -8,8 +8,10 @@ class ComputePattern(Enum):
     TP1DCol_Linear = 2
     TP1DRow_Embedding = 3
     TP1DCol_Embedding = 4
-    ZeRO = 5
-    DP = 6
+    TP1DRow_mm = 5
+    TP1DCol_mm = 6
+    ZeRO = 7
+    DP = 8


 class ShardPattern(Enum):
81 changes: 81 additions & 0 deletions tests/test_tensor/test_addmm_tp.py
@@ -0,0 +1,81 @@
import colossalai
import torch
import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.context import ParallelMode
from colossalai.utils.cuda import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial


class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


def init_1d_row(model):
    spec = TensorSpec(
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_mm, parallel_mode=ParallelMode.PARALLEL_1D)])
    # Only the weight is sharded; the bias stays replicated, as colo_addmm_1Drow expects an un-sharded input.
    for n, p in model.colo_named_parameters():
        if 'weight' in n:
            p.set_spec(spec)


def init_1d_col(model):
    spec = TensorSpec(
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_mm, parallel_mode=ParallelMode.PARALLEL_1D)])
    # Both weight and bias are sharded column-wise, matching colo_addmm_1Dcol's expectation for the input tensor.
    for n, p in model.colo_named_parameters():
        p.set_spec(spec)


def run_with_spec(spec_init_func):
    with ColoInitContext(device=get_current_device()):
        model = Conv1D(4, 16)
    weight = model.weight.torch_tensor().clone()
    bias = model.bias.torch_tensor().clone()
    spec_init_func(model)
    x = torch.rand(2, 16).cuda()
    out = model(x)
    assert torch.allclose(out.torch_tensor(), torch.addmm(bias, x, weight))


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_addmm_1d(2)
