[BUG]: HybridParallelOptimizer holds unsharded model parameters after sharding #5539
Comments
Hi, thanks for the issue.
This happens only when sequence parallelism is on and ZeRO is off. We are rebuilding the sequence parallel API with ring attention etc., so as a quick fix I've set it to False in `enable_all_optimization`.
@Edenzzzz, thank you for your time looking into this issue. I am not sure if this fix works; I tested with the configuration below (Edit: this is the plugin configuration I used):

```python
plugin = HybridParallelPlugin(
    tp_size=4,
    pp_size=1,
    num_microbatches=None,
    microbatch_size=1,
    enable_all_optimization=False,
    enable_sequence_parallelism=False,
    enable_sequence_overlap=False,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)
```
Looks like these preprocess hooks are the culprits:

- ColossalAI/colossalai/shardformer/policies/bert.py, lines 39 to 51 in 341263d
- ColossalAI/colossalai/shardformer/policies/gpt2.py, lines 32 to 43 in 341263d

Although all policies share the same resize logic, each model has a different default vocabulary size, so only bert and gpt2 in your tests need their embeddings resized, which creates a new embedding and fails:

```python
from transformers import AutoConfig

def test_vocab_size_divisible_to_tp_size(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name)
    vocab_size = config.vocab_size
    print(f"model {model_name} vocab_size: {vocab_size}. Need to resize embeddings for tp degree {tp_size}? {vocab_size % tp_size != 0}")

test_vocab_size_divisible_to_tp_size("gpt2", 8)
test_vocab_size_divisible_to_tp_size("bert-base-uncased", 8)
test_vocab_size_divisible_to_tp_size("facebook/opt-125m", 8)
test_vocab_size_divisible_to_tp_size("tiiuae/falcon-rw-1b", 8)
```
It creates a completely new embedding module, so any existing references (such as the ones held by the optimizer) still point at the old, unsharded parameters.
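To illustrate what goes wrong, here is a minimal sketch against plain Hugging Face (not the ColossalAI code path); the target size 50264 is just 50257 rounded up to the next multiple of 8:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters())
old_weight = model.get_input_embeddings().weight

# HF's resize allocates a brand-new nn.Embedding and copies the old weights into it.
model.resize_token_embeddings(50264)
new_weight = model.get_input_embeddings().weight

print(old_weight is new_weight)  # False: a different Parameter object

# The optimizer still tracks the old, now-orphaned Parameter.
opt_param_ids = {id(p) for g in optimizer.param_groups for p in g["params"]}
print(id(new_weight) in opt_param_ids)  # False
print(id(old_weight) in opt_param_ids)  # True
```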
A quick potential patch is not to use HF's `resize_token_embeddings`, but to resize the embedding in place:

```python
import torch.nn as nn

def resize_token_embedding_inplace(new_num_tokens: int, embedding: nn.Embedding):
    # In-place resize of the token embeddings: keep the same Parameter object
    # and zero-pad the weight up to the new vocabulary size.
    embedding.num_embeddings = new_num_tokens
    embedding.weight.data = nn.functional.pad(
        embedding.weight.data,
        (0, 0, 0, new_num_tokens - embedding.weight.size(0)),
        "constant",
        0,
    )


# In the policy
def preprocess(self):
    r"""
    Reshape the embedding layer to make the vocabulary size divisible by world_size.
    """
    if self.shard_config.enable_tensor_parallelism:
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            resize_token_embedding_inplace(new_vocab_size, self.model.get_input_embeddings())
            # self.model.resize_token_embeddings(new_vocab_size)
    return self.model
```

@Edenzzzz Could you please check if it works? Thanks
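As a quick sanity check of the in-place approach (a standalone sketch with made-up sizes, reusing the `resize_token_embedding_inplace` helper above), the parameter object, and hence the optimizer's reference to it, survives the resize:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(50257, 768)
opt = torch.optim.AdamW(emb.parameters())
before = emb.weight

resize_token_embedding_inplace(50264, emb)

assert emb.weight is before                    # same Parameter object as before
assert emb.weight.shape == (50264, 768)        # padded up to the new vocab size
assert any(p is emb.weight for g in opt.param_groups for p in g["params"])
```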
Maybe it is related to #5489?
Thanks for the nice catch! This worked for both gpt2 and bert. Yes, some fix appears to be in progress; I will touch base with them tomorrow.
Nice to hear that the fix will be merged very soon. Thank you!
🐛 Describe the bug
When tensor parallelism is used, model parameters are sharded across GPUs to reduce memory consumption and enable parallel execution.
However, the optimizer still holds references to the unsharded model parameters, which prevents the old unsharded parameters from being released and increases memory usage.
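One way to observe this is to compare parameter identities after boosting (a diagnostic sketch; I'm assuming `model` and `optimizer` are the objects returned by `booster.boost(...)` in the finetune example below, and that the wrapped optimizer exposes `param_groups` like a plain torch optimizer):

```python
# Parameters the optimizer will update vs. parameters the sharded model actually holds.
model_param_ids = {id(p) for p in model.parameters()}
opt_params = [p for group in optimizer.param_groups for p in group["params"]]
stale = [p for p in opt_params if id(p) not in model_param_ids]
print(f"{len(stale)}/{len(opt_params)} optimizer parameters are not parameters of the sharded model")
```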
Example code (adapted from `examples/language/gpt2/hybridparallelism/finetune.py`).

This also affects `MixedPrecisionOptimizer.master_to_working_map` and `MixedPrecisionOptimizer.working_to_master_map`.
Because of this, it seems that only a portion of the parameters (i.e. the unsharded ones) get trained, as `MixedPrecisionOptimizer.step()` skips the sharded parameters because their gradients are not stored on the mismatched unsharded parameters:

ColossalAI/colossalai/amp/naive_amp/mixed_precision_optimizer.py, lines 173 to 175 in df5e9c5
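To spell out the mechanism as I understand it, here is a conceptual sketch with made-up tensors (not ColossalAI's actual classes): the master-to-working map is built from the pre-sharding parameters, so after Shardformer swaps in a sharded parameter, the mapped working tensor never receives a gradient and the step is skipped:

```python
import torch
import torch.nn as nn

# Map built before sharding: fp32 master copy -> working (model) parameter.
working = nn.Parameter(torch.randn(8, 4))
master = nn.Parameter(working.detach().clone().float())
master_to_working = {id(master): working}

# Shardformer later replaces the module's parameter with a sharded one;
# gradients accumulate on the new object, not on the mapped one.
sharded = nn.Parameter(working.detach()[:4].clone())
sharded.grad = torch.ones_like(sharded)

mapped = master_to_working[id(master)]
print(mapped is sharded)  # False: the map still points at the old parameter
print(mapped.grad)        # None -> a step() that checks this grad skips the update
```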
Environment
PyTorch 2.2.1 / CUDA 12.1