[tensor] hijack addmm for colo tensor #923

Merged · 4 commits · May 9, 2022
3 changes: 2 additions & 1 deletion colossalai/tensor/_ops/__init__.py
@@ -2,4 +2,5 @@
from .element_wise import *
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy
from .embedding import colo_embedding
from .addmm import colo_addmm
115 changes: 115 additions & 0 deletions colossalai/tensor/_ops/addmm.py
@@ -0,0 +1,115 @@
import torch
from typing import Union
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ParallelAction, ColoTensor, ShardPattern
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                     alpha: Union[int, float]) -> ColoTensor:
    parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_mm)
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    # mat1:S[1]
    if mat1.is_gathered():
        # Not split yet.
        assert divide(mat1.shape[-1], gpc.tensor_parallel_size) == mat2.size(0), \
            'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0) * gpc.tensor_parallel_size)
        input_per_partition = split_forward_gather_backward(mat1.torch_tensor(), parallel_action.parallel_mode, dim=-1)
    elif mat1.shard_pattern == ShardPattern.Col:
        # Already split along the last dim (1Dcol).
        assert mat1.shape[-1] == mat2.size(0), \
            'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0))
        input_per_partition = mat1.torch_tensor()
    else:
        raise NotImplementedError

    # Output:P
    partial_output = torch.mm(input_per_partition, mat2.torch_tensor())
    # Reduce(Output)
    output = reduce_input(partial_output, parallel_action.parallel_mode)
    # input
    assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor.torch_tensor() + alpha * output
    output = ColoTensor.init_from_torch_tensor(output)
    return output


def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                     alpha: Union[int, float]) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    # All-Gather(Output)
    # mat1:B
    parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_mm)
    if mat1.is_gathered():
        # Not split yet.
        assert mat1.shape[-1] == mat2.size(0), \
            'Invalid shapes in 1Dcol forward: mat1={}, mat2={}. Expected last dim of input {}.'.format(
                mat1.shape, mat2.shape, mat2.size(0))
        input_parallel = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
    else:
        # A pre-sharded mat1 is not handled here; fail loudly instead of using an undefined input_parallel.
        raise NotImplementedError

    # input:S[1]
    assert input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1 and \
        input_tensor.shard_pattern in [ShardPattern.Col, ShardPattern.Row], \
        'Invalid bias spec for 1Dcol addmm op'

    output_parallel = torch.addmm(input_tensor.torch_tensor(),
                                  input_parallel,
                                  mat2.torch_tensor(),
                                  beta=beta,
                                  alpha=alpha)

    output = ColoTensor.init_from_torch_tensor(output_parallel)
    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
    output_spec = TensorSpec(out_parallel_action_list)
    output.set_spec(output_spec, shard=False)
    output.set_shard_pattern(ShardPattern.Col)
    if parallel_action.gather_out:
        # All-Gather(Output)
        output.gather()
    return output


@colo_op_impl(torch.addmm)
def colo_addmm(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.addmm``.
    Computes ``beta * input + alpha * (mat1 @ mat2)`` on ColoTensors.
    """
    input_tensor, mat1, mat2 = tuple(
        map(lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t), args[:3]))
    beta = kwargs.get('beta', 1) if kwargs else 1
    alpha = kwargs.get('alpha', 1) if kwargs else 1

    # building the computing graph, inputs -> op
    # if GraphGlobalEnv().graph_building:
    #     cur_op_node = GraphOpNode('linear', [weight, bias])
    #     cur_op_node.add_prev_tensor(input_tensor)

    # Add communication logic before and after the addmm call.
    ret_tensor = None
    if not mat2.has_spec():    # No Model Parallel Applied
        assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
        ret_tensor = ColoTensor.init_from_torch_tensor(
            torch.addmm(input_tensor.torch_tensor(), mat1.torch_tensor(), mat2.torch_tensor(), beta=beta, alpha=alpha))
    elif mat2.shard_spec.num_action == 1:    # Single Model Parallel Applied
        compute_patterns = mat2.shard_spec.compute_patterns
        if ComputePattern.TP1DRow_mm in compute_patterns:
            ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
        elif ComputePattern.TP1DCol_mm in compute_patterns:
            ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    # building the computing graph, op -> output
    # if GraphGlobalEnv().graph_building:
    #     cur_op_node.add_post_tensor(ret_tensor)

    return ret_tensor
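A minimal sketch of how this handler is expected to be reached (assumptions: colossalai.launch has set up a 1D tensor-parallel group, and ColoTensor's __torch_function__ hook consults the handlers registered by colo_op_impl; shapes mirror the Conv1D test below):

import torch
from colossalai.tensor import ColoTensor

# bias: (4,), x: (2, 16), weight: (16, 4) -- same shapes as the Conv1D test below.
bias = ColoTensor.init_from_torch_tensor(torch.ones(4))
weight = ColoTensor.init_from_torch_tensor(torch.empty(16, 4).normal_(std=0.02))
x = torch.rand(2, 16)

# Because a ColoTensor participates in the call, dispatch is routed to colo_addmm above.
# Without a spec on weight it falls through to the native addmm branch; with a
# TP1DRow_mm / TP1DCol_mm spec it takes the sharded branches.
out = torch.addmm(bias, x, weight)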
13 changes: 8 additions & 5 deletions colossalai/tensor/colo_tensor.py
@@ -133,12 +133,15 @@ def shard(self):
        # Model Parameters
        if self._shard_spec.num_action == 1:
            parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
            if parallel_action.compute_pattern in [
                    ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
            ]:
                self._shard_1d(parallel_action=parallel_action, dim=-1)
                # We bind our ComputePattern on weight, which has to be transposed when linear().
                self._shard_pattern = ShardPattern.Col
            elif parallel_action.compute_pattern in [
                    ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
            ]:
                self._shard_1d(parallel_action=parallel_action, dim=0)
                self._shard_pattern = ShardPattern.Row
            else:
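A worked shape example for the two new branches (illustrative only: torch.chunk stands in for what _shard_1d does across the 1D parallel group, with an assumed tensor-parallel size of 2 and a Conv1D-style weight of shape (16, 4)):

import torch

w = torch.empty(16, 4)                             # Conv1D-style weight: (nx, nf)
world_size = 2                                     # assumed tensor-parallel size

col_shards = torch.chunk(w, world_size, dim=-1)    # TP1DCol_mm -> _shard_1d(dim=-1): each rank keeps (16, 2)
row_shards = torch.chunk(w, world_size, dim=0)     # TP1DRow_mm -> _shard_1d(dim=0):  each rank keeps (8, 4)
assert col_shards[0].shape == (16, 2) and row_shards[0].shape == (8, 4)

Column shards yield column-sharded outputs that colo_addmm_1Dcol gathers; row shards yield partial outputs that colo_addmm_1Drow all-reduces.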
6 changes: 4 additions & 2 deletions colossalai/tensor/spec.py
@@ -8,8 +8,10 @@ class ComputePattern(Enum):
    TP1DCol_Linear = 2
    TP1DRow_Embedding = 3
    TP1DCol_Embedding = 4
    TP1DRow_mm = 5
    TP1DCol_mm = 6
    ZeRO = 7
    DP = 8


class ShardPattern(Enum):
81 changes: 81 additions & 0 deletions tests/test_tensor/test_addmm_tp.py
@@ -0,0 +1,81 @@
import colossalai
import torch
import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.context import ParallelMode
from colossalai.utils.cuda import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial


class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


def init_1d_row(model):
    spec = TensorSpec(
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_mm, parallel_mode=ParallelMode.PARALLEL_1D)])
    for n, p in model.colo_named_parameters():
        if 'weight' in n:
            p.set_spec(spec)


def init_1d_col(model):
    spec = TensorSpec(
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_mm, parallel_mode=ParallelMode.PARALLEL_1D)])
    for n, p in model.colo_named_parameters():
        p.set_spec(spec)


def run_with_spec(spec_init_func):
    with ColoInitContext(device=get_current_device()):
        model = Conv1D(4, 16)
    weight = model.weight.torch_tensor().clone()
    bias = model.bias.torch_tensor().clone()
    spec_init_func(model)
    x = torch.rand(2, 16).cuda()
    out = model(x)
    assert torch.allclose(out.torch_tensor(), torch.addmm(bias, x, weight))


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_addmm_1d(2)
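As a standalone sanity check of the arithmetic the test relies on (plain PyTorch only, simulating a tensor-parallel size of 2 on one device; variable names are illustrative):

import torch

torch.manual_seed(0)
x = torch.rand(2, 16)
w = torch.randn(16, 4)
b = torch.ones(4)

# 1D-row style: split the reduction dim, then sum the partial products (the all-reduce).
x0, x1 = torch.chunk(x, 2, dim=-1)
w0, w1 = torch.chunk(w, 2, dim=0)
row = b + (x0 @ w0 + x1 @ w1)

# 1D-col style: split the output dim (bias included), then concatenate (the all-gather).
wa, wb = torch.chunk(w, 2, dim=-1)
ba, bb = torch.chunk(b, 2, dim=-1)
col = torch.cat([torch.addmm(ba, x, wa), torch.addmm(bb, x, wb)], dim=-1)

ref = torch.addmm(b, x, w)
assert torch.allclose(row, ref, atol=1e-6) and torch.allclose(col, ref, atol=1e-6)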