Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
flybird11111 committed Apr 2, 2024
1 parent fc52763 commit 8f2db47
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 15 deletions.
15 changes: 9 additions & 6 deletions colossalai/shardformer/layer/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,10 @@ def __init__(
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
*args,
**kwargs,
):
super().__init__()
super().__init__(*args, **kwargs)

# Keep input parameters
self.in_features = in_features
Expand Down Expand Up @@ -509,7 +511,7 @@ def forward(self, input: Tensor) -> Tensor:
return output


class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
Expand Down Expand Up @@ -570,10 +572,6 @@ def __init__(
new_out_features = out_features + multiple - (out_features % multiple)

super().__init__(
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
weight_A=weight,
bias_A=bias_,
in_features=in_features,
out_features=new_out_features,
bias=bias,
Expand All @@ -583,7 +581,12 @@ def __init__(
bias_=bias_,
*args,
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
weight_A=weight,
bias_A=bias_,
)

# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/layer/parallel_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@


class ParallelModule(nn.Module, ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
Expand Down
9 changes: 0 additions & 9 deletions colossalai/shardformer/policies/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def preprocess(self):
self.tie_weight = self.tie_weight_check()
return self.model

def tie_weight_check(self):
input_embedding = self.model.get_input_embeddings()
output_embedding = self.model.get_output_embeddings()
return (
input_embedding is not None
and output_embedding is not None
and id(input_embedding.weight) == id(output_embedding.weight)
)

def module_policy(self):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
Expand Down

0 comments on commit 8f2db47

Please sign in to comment.