Description
System Info
- `transformers` version: 4.45.2
- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
- Python version: 3.11.11
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: 1.3.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi,
I am trying to use FSDP to fine-tune a 2.7B Mamba2 model on 4x H100 GPUs. I set the FSDP parameters and FSDP config in the Hugging Face Trainer as shown below, but I still see around 70-80GB per GPU across the 4x H100s with a batch size of 1; my data's sequence length is around 1-2k tokens. Even without passing any FSDP settings to the Trainer, I can train the same 2.7B model in the same setup with around 70-80GB on a single GPU. I used full_shard, but the per-GPU memory consumption suggests the parameters and optimizer states are not actually being sharded. (The memory numbers come from watching per-GPU usage during training; a small logging sketch is included right after the code below.)
# This is how I launch:
# torchrun --nproc_per_node=4 ...
import torch
import wandb
from transformers import TrainingArguments
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# MambaTrainer is my own Trainer subclass; args, the datasets, tokenizer and
# data_collator are set up earlier in the script.

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba2-2.7b",
    dtype=torch.bfloat16,
    device="cuda",
).to(torch.bfloat16)

trainer = MambaTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    args=TrainingArguments(
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=accume_steps,
        optim=args.optim,
        output_dir=args.output_path,
        logging_steps=2,
        save_steps=2500,
        remove_unused_columns=False,
        report_to="wandb",
        # max_steps=2,
        # Validation
        eval_strategy="steps",  # or "epoch"
        eval_steps=20,
        prediction_loss_only=True,
        do_eval=True,
        # LR schedule
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        weight_decay=0.1,
        adam_beta1=0.9,
        adam_beta2=0.95,
        max_grad_norm=1.0,  # clip_grad=1.0
        # Hugging Face FSDP integration:
        fsdp=["full_shard", "auto_wrap"],
        # fsdp="shard_grad_op",
        # fsdp=True,
        fsdp_config={
            "min_num_params": 100000,
            # "transformer_layer_cls_to_wrap": ["MambaLMHeadModel", "MixerModel", "Block", "Mamba2"],
            "use_orig_params": False,
            "activation_checkpointing": False,
            "mixed_precision": "bf16",
            "sync_module_states": False,
            "forward_prefetch": False,
            # "backward_prefetch": "backward_post",
            "cpu_offload": False,
            # "xla": True,
            "cpu_ram_efficient_loading": False,
            "limit_all_gathers": True,
            # "sharding_strategy": "FULL_SHARD",
        },
        # ddp_find_unused_parameters=False,  # Disable unused parameter check
        # fsdp=False,  # Disable Trainer's internal FSDP
    ),
    data_collator=data_collator,
)

print(f"trainer.model: {trainer.model}")
print(f"[train_mamba.py] torch.distributed.get_rank(): {torch.distributed.get_rank()}")
if torch.distributed.get_rank() == 0:
    wandb.init(name=f"batch={args.batch_size} lr={args.learning_rate} accum={args.gradient_accumulation_steps}")

# 6) Train
trainer.train()
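For reference, a minimal way to watch per-rank memory during training could look like the sketch below. MemoryLoggerCallback is just an illustrative helper built on the standard TrainerCallback hooks, not something from my actual script:

import torch
from transformers import TrainerCallback

class MemoryLoggerCallback(TrainerCallback):
    """Print peak CUDA memory per rank at every logging step (illustrative only)."""

    def on_log(self, args, state, control, **kwargs):
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        peak_gib = torch.cuda.max_memory_allocated() / 2**30
        print(f"[rank {rank}] step {state.global_step}: peak allocated {peak_gib:.1f} GiB")

# trainer.add_callback(MemoryLoggerCallback())  # register before calling trainer.train()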
I printed trainer.model after constructing the trainer, and it comes out as follows. I believe this tells us that the model is not being wrapped correctly (a quick check for this is sketched right after the printout).
MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(256000, 4096)
    (layers): ModuleList(
      (0-55): 56 x Block(
        (norm): RMSNorm()
        (mixer): Mamba2(
          (in_proj): Linear(in_features=4096, out_features=18560, bias=False)
          (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)
          (act): SiLU()
          (norm): RMSNorm()
          (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
        )
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=256000, bias=False)
)
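A quick sanity check along these lines (illustrative code, not from my script) should show whether anything was wrapped in FSDP and how many parameter elements each rank actually stores; with FULL_SHARD and use_orig_params=False I would expect roughly the total parameter count divided by the world size on each rank. Note that the Trainer only applies its FSDP wrapping once train() is running, so this is most meaningful inside a callback or after training has started:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

rank = torch.distributed.get_rank()

# How many submodules ended up with their own FSDP wrapper?
wrapped = [name for name, m in trainer.model.named_modules() if isinstance(m, FSDP)]
print(f"[rank {rank}] FSDP-wrapped submodules: {len(wrapped)}")

# How many parameter elements does this rank actually hold?
local_numel = sum(p.numel() for p in trainer.model.parameters())
print(f"[rank {rank}] parameter elements stored on this rank: {local_numel / 1e9:.2f}B")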
Then I tried another approach: wrapping the model with FSDP directly and then passing it to the Trainer; the printed model is shown further below. Previously, when the model was loaded from the pretrained checkpoint, each GPU held around 8GB and the usage kept increasing (presumably as optimizer states, gradients, etc. were allocated). With this explicit wrapping, each GPU's memory first drops by about half and then gradually climbs back up to 70-80GB. I have tried many different fsdp_config parameters and a whole bunch of other variations, but each GPU still ends up at around 70-80GB, the same as fine-tuning on a single GPU. I also tried keeping only my own FSDP wrapping without setting fsdp in the Trainer, but that gives me a gradient shape mismatch error whose origin I cannot trace. (A tally of what actually gets wrapped is sketched after the second printout.)
# Directly wrap the model with FSDP myself before handing it to the Trainer.
import functools
import os

import torch
import wandb
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import TrainingArguments
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    args.model,
    dtype=torch.bfloat16,
    device=torch.cuda.current_device(),
).to(torch.bfloat16)

if torch.distributed.get_rank() == 0:
    print(model)

local_rank = int(os.getenv("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=2000000
)
model = FSDP(
    module=model,
    auto_wrap_policy=my_auto_wrap_policy,
    # sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    forward_prefetch=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_POST,
    limit_all_gathers=True,
)

# HF Trainer:
trainer = MambaTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    args=TrainingArguments(
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=accume_steps,
        optim=args.optim,
        output_dir=args.output_path,
        logging_steps=2,
        save_steps=2500,
        remove_unused_columns=False,
        report_to="wandb",
        # max_steps=2,
        # Validation
        eval_strategy="steps",  # or "epoch"
        eval_steps=20,
        prediction_loss_only=True,
        do_eval=True,
        # LR schedule
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        weight_decay=0.1,
        adam_beta1=0.9,
        adam_beta2=0.95,
        max_grad_norm=1.0,  # clip_grad=1.0
        # Hugging Face FSDP integration:
        fsdp=["full_shard", "auto_wrap"],
        # fsdp="shard_grad_op",
        # fsdp=True,
        fsdp_config={
            "min_num_params": 100000,
            # "transformer_layer_cls_to_wrap": ["MambaLMHeadModel", "MixerModel", "Block", "Mamba2"],
            "use_orig_params": False,
            "activation_checkpointing": False,
            "mixed_precision": "bf16",
            "sync_module_states": False,
            "forward_prefetch": False,
            # "backward_prefetch": "backward_post",
            "cpu_offload": False,
            # "xla": True,
            "cpu_ram_efficient_loading": False,
            "limit_all_gathers": True,
            # "sharding_strategy": "FULL_SHARD",
        },
        # ddp_find_unused_parameters=False,  # Disable unused parameter check
        # fsdp=False,  # Disable Trainer's internal FSDP
    ),
    data_collator=data_collator,
)

print(f"trainer.model: {trainer.model}")
print(f"[train_mamba.py] torch.distributed.get_rank(): {torch.distributed.get_rank()}")
if torch.distributed.get_rank() == 0:
    wandb.init(name=f"batch={args.batch_size} lr={args.learning_rate} accum={args.gradient_accumulation_steps}")

# Chrome trace
# trainer.add_callback(
#     ProfilerCallback(trace_file="no_soup_trace_3steps.json")
# )

# 6) Train
trainer.train()
Now the printed model looks like this:
FullyShardedDataParallel(
  (_fsdp_wrapped_module): MambaLMHeadModel(
    (backbone): FullyShardedDataParallel(
      (_fsdp_wrapped_module): MixerModel(
        (embedding): FullyShardedDataParallel(
          (_fsdp_wrapped_module): Embedding(50288, 2560)
        )
        (layers): ModuleList(
          (0-63): 64 x Block(
            (norm): RMSNorm()
            (mixer): Mamba2(
              (in_proj): FullyShardedDataParallel(
                (_fsdp_wrapped_module): Linear(in_features=2560, out_features=10576, bias=False)
              )
              (conv1d): Conv1d(5376, 5376, kernel_size=(4,), stride=(1,), padding=(3,), groups=5376)
              (act): SiLU()
              (norm): RMSNorm()
              (out_proj): FullyShardedDataParallel(
                (_fsdp_wrapped_module): Linear(in_features=5120, out_features=2560, bias=False)
              )
            )
          )
        )
        (norm_f): RMSNorm()
      )
    )
    (lm_head): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Linear(in_features=2560, out_features=50288, bias=False)
    )
  )
)
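For completeness, a small tally of which module classes end up with their own FSDP unit (again illustrative code, not from my script) lines up with the printout above: the size-based policy only gives the Linear layers, the Embedding, the MixerModel backbone, and the root MambaLMHeadModel their own units, while the conv1d and RMSNorm modules should remain managed by the parent units' flat parameters:

from collections import Counter

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Tally the classes of the modules that received their own FSDP wrapper.
wrapped_types = Counter(
    type(m.module).__name__ for m in model.modules() if isinstance(m, FSDP)
)
print(wrapped_types)
# Expected here, based on the printout above, something like:
# Counter({'Linear': 129, 'MambaLMHeadModel': 1, 'MixerModel': 1, 'Embedding': 1})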
Expected behavior
Everything should work out of the box with the Hugging Face Trainer, but it does not.