Commit: resolve super init

flybird11111 committed Apr 3, 2024
1 parent a0ca44e commit 84cbb6a

Showing 3 changed files with 16 additions and 53 deletions.

colossalai/shardformer/layer/embedding.py (5 changes: 0 additions & 5 deletions)

@@ -220,7 +220,6 @@ def from_native_module(
         embedding_dim = module.embedding_dim
         padding_idx = module.padding_idx
         device = module.weight.device
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
 
         # create the parallel module
         padding_embedding = PaddingEmbedding(
@@ -229,7 +228,6 @@
             padding_idx=padding_idx,
             device=device,
             weight=module.weight,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             *args,
             **kwargs,
         )
@@ -343,8 +341,6 @@ def from_native_module(
         assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
         process_group = process_group[0]
 
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
-
         # create the parallel module
         vocab_embedding_1d = VocabParallelEmbedding1D(
             num_embeddings=num_embeddings,
@@ -353,7 +349,6 @@
             device=device,
             process_group=process_group,
             weight=module.weight,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             *args,
             **kwargs,
         )
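
The embedding-side change removes the pop-and-repass dance around make_vocab_size_divisible_by: the value now reaches PaddingEmbedding and VocabParallelEmbedding1D through **kwargs alone, and the default lives only in the constructor. A minimal sketch of that pattern follows; Child and build are hypothetical names for illustration, not ColossalAI APIs.

# Toy illustration only; Child and build are hypothetical, not library code.
class Child:
    def __init__(self, num_embeddings, make_vocab_size_divisible_by=64, **kwargs):
        self.num_embeddings = num_embeddings
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by


def build(num_embeddings, **kwargs):
    # Before: divisor = kwargs.pop("make_vocab_size_divisible_by", 64), then
    # re-pass it explicitly. After: forward **kwargs untouched and let the
    # constructor's own default apply, so the default is defined in one place.
    return Child(num_embeddings, **kwargs)


print(build(100, make_vocab_size_divisible_by=128).make_vocab_size_divisible_by)  # 128
print(build(100).make_vocab_size_divisible_by)  # 64
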
colossalai/shardformer/layer/linear.py (43 changes: 6 additions & 37 deletions)

@@ -85,7 +85,7 @@ def __init__(
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(weight=weight, bias_=bias_, **kwargs)
 
         # Keep input parameters
         self.in_features = in_features
@@ -489,7 +489,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
         # ensure only one process group is passed
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
 
         lm_head_linear = PaddingLMHead(
             in_features=in_features,
@@ -498,7 +497,6 @@
             device=device,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             *args,
             **kwargs,
         )
@@ -551,7 +549,6 @@ def __init__(
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
-        *args,
         **kwargs,
     ):
         # create weight and bias
@@ -579,12 +576,9 @@
             process_group=process_group,
             weight=weight,
             bias_=bias_,
-            *args,
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
         )
 
         # get the length of valid embeddings
@@ -599,7 +593,7 @@
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> PaddingParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -611,8 +605,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
 
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
-
         lm_head_linear = VocabParallelLMHead1D(
             in_features=in_features,
             out_features=out_features,
@@ -621,41 +613,18 @@
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
-            *args,
             **kwargs,
         )
 
         return lm_head_linear
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-        assert (
-            input_.shape[-1] == self.weight.shape[-1]
-        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
-            input_.shape, self.weight.shape, self.weight.shape[-1]
-        )
-
-        # Set up backprop all-reduce.
-        input_parallel = input_
-
-        # Matrix multiply.
-        bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
-            )
+        if self.skip_bias_add:
+            output, _ = super().forward(input_)
         else:
-            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
-
+            output = super().forward(input_)
         if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
             output = output[..., : self.old_num_embeddings]
         else:
-            output = output_parallel
             output = output[..., : self.num_valid_embeddings_local]
-
-        if self.skip_bias_add:
-            return output, self.bias
-        else:
-            return output
+        return output
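
The rewritten forward no longer reimplements the column-parallel matmul, communication, and gather; it defers that work to the parent class via super().forward and only strips the padded vocabulary slots afterwards. Below is a simplified sketch of that control flow, using a plain nn.Linear stand-in for the real parallel parent; PaddedLMHead and the sizes are illustrative, not the library class.

# Simplified sketch; PaddedLMHead is a stand-in, not VocabParallelLMHead1D itself.
import torch
import torch.nn as nn


class PaddedLMHead(nn.Linear):
    def __init__(self, in_features: int, old_vocab: int, padded_vocab: int):
        super().__init__(in_features, padded_vocab, bias=False)
        self.old_num_embeddings = old_vocab  # true (unpadded) vocabulary size
        self.gather_output = True            # pretend the full output is gathered

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = super().forward(x)          # parent does the actual projection
        if self.gather_output:
            # drop logits that correspond to padding-only vocab slots
            output = output[..., : self.old_num_embeddings]
        return output


head = PaddedLMHead(in_features=16, old_vocab=100, padded_vocab=128)
print(head(torch.randn(2, 16)).shape)  # torch.Size([2, 100])
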
colossalai/shardformer/layer/parallel_module.py (21 changes: 10 additions & 11 deletions)

@@ -25,8 +25,8 @@
 
 
 class ParallelModule(nn.Module, ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self, **kwargs):
+        super().__init__()
 
     @abstractmethod
     def from_native_module(
@@ -176,21 +176,20 @@ def _load_from_state_dict(
                 unexpected_keys.append(key)
 
 
-class PaddingParallelModule(nn.Module, ABC):
+class PaddingParallelModule(ParallelModule):
     def __init__(
         self,
-        new_num_embeddings: int = None,
-        old_num_embeddings: int = None,
-        weight_A: Optional[nn.Parameter] = None,
-        bias_A: Optional[nn.Parameter] = None,
-        *args,
+        new_num_embeddings: int,
+        old_num_embeddings: int,
+        weight: Optional[nn.Parameter],
+        bias_: Optional[nn.Parameter] = None,
         **kwargs,
     ) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(**kwargs)
         self.new_num_embeddings = new_num_embeddings
         self.old_num_embeddings = old_num_embeddings
-        self.weight = weight_A
-        self.bias = bias_A
+        self.weight = weight
+        self.bias = bias_
 
         if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
             self.resize_embedding_weight()
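
The parallel_module.py hunk is the heart of "resolve super init": PaddingParallelModule now subclasses ParallelModule, the weight_A/bias_A aliases become weight/bias_, and callers pass the padding metadata by keyword, so each __init__ can forward the remainder up the MRO with a single super().__init__(**kwargs). A toy sketch of that cooperative chain follows; the bodies are simplified (resize_embedding_weight and the abstract from_native_module are omitted), so it is not the library code.

# Toy sketch of the cooperative __init__ chain; simplified, not the library code.
from abc import ABC
from typing import Optional

import torch
import torch.nn as nn


class ParallelModule(nn.Module, ABC):
    def __init__(self, **kwargs):
        super().__init__()  # nn.Module takes no extra constructor arguments


class PaddingParallelModule(ParallelModule):
    def __init__(
        self,
        new_num_embeddings: int,
        old_num_embeddings: int,
        weight: Optional[nn.Parameter],
        bias_: Optional[nn.Parameter] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)  # leftover kwargs flow up the MRO
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings
        self.weight = weight
        self.bias = bias_


# Subclasses now pass padding metadata by keyword instead of the old
# positional *args plus the weight_A / bias_A aliases.
w = nn.Parameter(torch.empty(128, 16))
m = PaddingParallelModule(new_num_embeddings=128, old_num_embeddings=100, weight=w)
print(m.weight.shape)  # torch.Size([128, 16])
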
