Question about training scripts #4

Closed
421zuoduan opened this issue Mar 18, 2024 · 7 comments

@421zuoduan

Hi,

Thanks for the excellent work!

I'm currently facing an issue while running HA-DPO training for LLaVA-1.5 on our 8x3090 server. Specifically, when setting freeze_backbone=True, freeze_mm_mlp_adapter=True, and tune_mm_mlp_adapter=False, we encounter an out-of-memory (OOM) error.

Here's the command we're using:

deepspeed ha_dpo/models/llava-v1_5/train_dpo.py \
    --lora_enable False \
    --mm_projector_lr 2e-5 \
    --deepspeed ha_dpo/models/llava-v1_5/scripts/zero3.json \
    --model_name_or_path /home/user/model/llava-v1.5-7b \
    --version v1 \
    --vg_path ha_dpo/data/VG \
    --desc_data_path ha_dpo/data/hadpo/llava-v1.5/desc_data.json \
    --pope_data_path ha_dpo/data/hadpo/llava-v1.5/pope_data.json \
    --vision_tower /home/user/model/clip-vit-large-patch14-336 \
    --freeze_backbone True \
    --tune_mm_mlp_adapter False \
    --freeze_mm_mlp_adapter True \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ha_dpo/models/llava-v1_5/checkpoints/llava-origin \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-6 \
    --weight_decay 0. \
    --warmup_steps 0 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --run_name "llava-v1.5" \
    --beta 0.1

Q1: In my understanding, I've frozen the backbone, projector, and vision tower (the vision encoder), and I haven't utilized LoRA training. Therefore, I wouldn't expect this setup to consume an excessive amount of memory. What could be causing this issue?

Q2: Furthermore, sometimes we encounter the "killing subprocess" issue after loading the model parameters. What could be causing this problem?

Q3: From my reading of the code, do I need to revise train_dpo.py in order to train lm_head? Thank you.

Log:

(hadpo) user@zhu2:~/HA-DPO$ sh llava_finetune.sh
[2024-03-18 17:34:08,131] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:09,352] [WARNING] [runner.py:196:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2024-03-18 17:34:09,352] [INFO] [runner.py:555:main] cmd = /home/user/anaconda/envs/hadpo/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None ha_dpo/models/llava-v1_5/train_dpo.py --lora_enable False --mm_projector_lr 2e-5 --deepspeed ha_dpo/models/llava-v1_5/scripts/zero3.json --model_name_or_path /home/user/model/llava-v1.5-7b --version v1 --vg_path ha_dpo/data/VG --desc_data_path ha_dpo/data/hadpo/llava-v1.5/desc_data.json --pope_data_path ha_dpo/data/hadpo/llava-v1.5/pope_data.json --vision_tower /home/user/model/clip-vit-large-patch14-336 --freeze_backbone True --tune_mm_mlp_adapter False --freeze_mm_mlp_adapter True --mm_projector_type mlp2x_gelu --mm_vision_select_layer -2 --mm_use_im_start_end False --mm_use_im_patch_token False --image_aspect_ratio pad --group_by_modality_length True --bf16 True --output_dir ha_dpo/models/llava-v1_5/checkpoints/llava-origin --num_train_epochs 1 --per_device_train_batch_size 16 --per_device_eval_batch_size 4 --gradient_accumulation_steps 4 --evaluation_strategy no --save_strategy steps --save_steps 50000 --save_total_limit 1 --learning_rate 2e-6 --weight_decay 0. --warmup_steps 0 --lr_scheduler_type cosine --logging_steps 1 --tf32 True --model_max_length 2048 --gradient_checkpointing True --dataloader_num_workers 4 --lazy_preprocess True --report_to wandb --run_name llava-v1.5 --beta 0.1
[2024-03-18 17:34:10,808] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:11,728] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]}
[2024-03-18 17:34:11,728] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=8, node_rank=0
[2024-03-18 17:34:11,728] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]})
[2024-03-18 17:34:11,728] [INFO] [launch.py:163:main] dist_world_size=8
[2024-03-18 17:34:11,728] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
[2024-03-18 17:34:13,747] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,863] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,865] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,872] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,875] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,912] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,916] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-18 17:34:13,926] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:29<00:00, 14.58s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:29<00:00, 14.58s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.42s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.19s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.14s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.21s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.18s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.31s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:13<00:00,  6.74s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:12<00:00,  6.45s/it]
Loading checkpoint shards:  50%|█████████████████████████████████████████████████████████████                                                             | 1/2 [00:10<00:10, 10.55s/it][2024-03-18 17:36:48,749] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-03-18 17:36:48,749] [INFO] [comm.py:594:init_distributed] cdb=None
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:17<00:00,  8.73s/it]
[2024-03-18 17:36:50,976] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-03-18 17:36:50,977] [INFO] [comm.py:594:init_distributed] cdb=None
Loading checkpoint shards:  50%|█████████████████████████████████████████████████████████████                                                             | 1/2 [00:16<00:16, 16.74s/it][2024-03-18 17:36:53,999] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191881
[2024-03-18 17:36:53,999] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191882
[2024-03-18 17:36:55,767] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-03-18 17:36:55,768] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-03-18 17:36:57,925] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191883
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00, 10.75s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:22<00:00, 11.30s/it]
[2024-03-18 17:37:01,179] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191884
[2024-03-18 17:37:03,516] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-03-18 17:37:03,516] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-03-18 17:37:03,518] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-03-18 17:37:03,518] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-03-18 17:37:05,097] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191886
[2024-03-18 17:37:08,367] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191891
[2024-03-18 17:37:11,961] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191892
[2024-03-18 17:37:12,008] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 191950
[2024-03-18 17:37:15,578] [ERROR] [launch.py:321:sigkill_handler] ['/home/user/anaconda/envs/hadpo/bin/python', '-u', 'ha_dpo/models/llava-v1_5/train_dpo.py', '--local_rank=7', '--lora_enable', 'False', '--mm_projector_lr', '2e-5', '--deepspeed', 'ha_dpo/models/llava-v1_5/scripts/zero3.json', '--model_name_or_path', '/home/user/model/llava-v1.5-7b', '--version', 'v1', '--vg_path', 'ha_dpo/data/VG', '--desc_data_path', 'ha_dpo/data/hadpo/llava-v1.5/desc_data.json', '--pope_data_path', 'ha_dpo/data/hadpo/llava-v1.5/pope_data.json', '--vision_tower', '/home/user/model/clip-vit-large-patch14-336', '--freeze_backbone', 'True', '--tune_mm_mlp_adapter', 'False', '--freeze_mm_mlp_adapter', 'True', '--mm_projector_type', 'mlp2x_gelu', '--mm_vision_select_layer', '-2', '--mm_use_im_start_end', 'False', '--mm_use_im_patch_token', 'False', '--image_aspect_ratio', 'pad', '--group_by_modality_length', 'True', '--bf16', 'True', '--output_dir', 'ha_dpo/models/llava-v1_5/checkpoints/llava-origin', '--num_train_epochs', '1', '--per_device_train_batch_size', '16', '--per_device_eval_batch_size', '4', '--gradient_accumulation_steps', '4', '--evaluation_strategy', 'no', '--save_strategy', 'steps', '--save_steps', '50000', '--save_total_limit', '1', '--learning_rate', '2e-6', '--weight_decay', '0.', '--warmup_steps', '0', '--lr_scheduler_type', 'cosine', '--logging_steps', '1', '--tf32', 'True', '--model_max_length', '2048', '--gradient_checkpointing', 'True', '--dataloader_num_workers', '4', '--lazy_preprocess', 'True', '--report_to', 'wandb', '--run_name', 'llava-v1.5', '--beta', '0.1'] exits with return code = -15
@JulioZhao97
Collaborator

JulioZhao97 commented Mar 19, 2024

Hello, thanks for your interest in our work!
I carefully checked memory usage with your settings:

  1. freeze backbone
    (memory-usage screenshot)
  2. freeze backbone, mm_mlp_adapter, and the language model; fine-tune only the LM head
    (memory-usage screenshot)
    Even when only the LM head is fine-tuned, it already takes at least 60GB of memory to hold two model instances (the reference model and the policy model). In this case, a 24GB 3090 is not sufficient (a rough estimate is sketched below).
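
For intuition, a rough back-of-the-envelope estimate of the weight memory alone (a minimal sketch; the 7B parameter count and bf16 precision come from the launch command above, and activation/optimizer overhead is only noted, not computed):

params_per_model = 7e9
bytes_per_param = 2                                       # bf16
gib = 1024 ** 3

weights_one = params_per_model * bytes_per_param / gib    # ~13 GiB per instance
weights_two = 2 * weights_one                             # ~26 GiB for policy + reference
print(f"weights only, one instance:  {weights_one:.1f} GiB")
print(f"weights only, two instances: {weights_two:.1f} GiB")
# Activations, gradients and optimizer states for the trained parameters, and the
# CUDA context come on top of this, which is how usage reaches the ~60GB range
# mentioned above and far exceeds a single 24GB RTX 3090.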

To utilize HA-DPO in your case, I suggest the following:

  1. Use quantization strategies, such as the 4-bit training provided in the official LLaVA scripts.
  2. Instantiate only one model: instead of two models, you can keep a single instance and use PEFT's disable_adapter or get_base_model() to obtain the reference-model output. This can roughly cut memory usage in half. For specifics, refer to ha_dpo/trainer/LlavaDPOTrainer.py.
  3. Use FSDP: unlike DDP, FSDP shards one model across the 8 GPUs, which can largely reduce memory usage.
  4. Check possible errors: the NCCL backend in DeepSpeed not yet implemented warning is strange to me and does not occur in my experiments. Could you please check your deepspeed, torch, and accelerate versions again?
  5. The Killing subprocess messages and the return code = -15 exit are also very strange; they may be related to memory usage.
  6. I checked the script: if you freeze the backbone and mm_mlp_adapter and set lora_enable to False, only the LM head will be fine-tuned (a minimal check is sketched after this list).
    If you succeed in running HA-DPO on 3090s, please let me know!
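
For reference, a minimal sketch of how to confirm that only the LM head stays trainable under these flags (this is not the repository's code; the checkpoint path is a placeholder and the import mirrors the llava_llama.py module mentioned later in this thread):

import torch
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

# Load the model (placeholder path), freeze everything except lm_head,
# then list what would actually receive gradients.
model = LlavaLlamaForCausalLM.from_pretrained(
    "/path/to/llava-v1.5-7b", torch_dtype=torch.bfloat16
)
for name, param in model.named_parameters():
    param.requires_grad = "lm_head" in name

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print("trainable tensors:", trainable)   # expect only lm_head.weight
print("trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))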

@421zuoduan
Author

I want to express my gratitude for your assistance. I've made some initial progress in alleviating the OOM issue and have successfully trained the LM head on an 8x3090 server. However, I'm looking to further optimize GPU memory usage, and I have a few questions about the solution you provided:

I couldn't locate the file ha_dpo/trainer/LlavaDPOTrainer.py. I presume it is either llava_dpo_trainer.py or base_dpo_trainer.py. Upon careful examination of the code, I noticed the latter initializes self.ref_model. If I were to use PEFT's get_base_model method, would that load additional parameters into GPU memory? If so, what would be the appropriate way to load the reference model? I'm new to PEFT and transformers; could you provide examples of this?

I also considered using ha_dpo/models/llava-v1_5/llava/model/language_model/llava_llama.py and its LlavaLlamaForCausalLM.get_model to obtain the reference model via model.get_model. However, would this share the same memory as the policy model and lead to errors?

I've implemented ZeRO3. Does this imply that I'm already utilizing FSDP?

Your insights and guidance would be greatly appreciated. Thank you for your time and assistance.

@421zuoduan
Author

Furthermore, I encountered another issue. In train_dpo.py, in the on_train_end method of the SaverCallback class, the condition for saving parameters is isinstance(kwargs['model'], PeftModelForCausalLM). However, during testing I found that kwargs['model'] is always a LlavaLlamaForCausalLM, so parameters are never saved after training completes. I replaced PeftModelForCausalLM with LlavaLlamaForCausalLM and the save then succeeded.
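
For illustration, a minimal reconstruction of the adjusted check described above (the callback body is inferred from this description, not copied from train_dpo.py; accepting either class covers both the LoRA and non-LoRA cases):

from transformers import TrainerCallback
from peft import PeftModelForCausalLM
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

class SaverCallback(TrainerCallback):
    def on_train_end(self, args, state, control, **kwargs):
        model = kwargs["model"]
        # With lora_enable=False the model stays a LlavaLlamaForCausalLM, so
        # checking only PeftModelForCausalLM would skip the save entirely.
        if isinstance(model, (PeftModelForCausalLM, LlavaLlamaForCausalLM)):
            model.save_pretrained(args.output_dir)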

@JulioZhao97
Collaborator

The initialization of the policy and reference models is in trainer/base_dpo_trainer.py:

if ref_model:
    self.ref_model = ref_model
elif self.is_peft_model:
    # The `model` with adapters turned off will be used as the reference model
    self.ref_model = None
else:
    self.ref_model = create_reference_model(model)

To instantiate only one model, you should pass ref_model=None.

And in reference model reward modeling:

all_logits = model.forward(
    inputs_embeds=batch_inputs_embeds,
    labels=None,
    attention_mask=batch_attention_mask,
).logits.to(torch.float32)

you should use a reference model, which can be obtained either with the model.disable_adapter() context manager or with model.get_base_model(), for example:

with model.disable_adapter():
    all_logits = model.forward(
        inputs_embeds=batch_inputs_embeds,
        labels=None,
        attention_mask=batch_attention_mask,
    ).logits.to(torch.float32)

or

ref_model = model.get_base_model()
all_logits = ref_model.forward(
    inputs_embeds=batch_inputs_embeds,
    labels=None,
    attention_mask=batch_attention_mask,
).logits.to(torch.float32)

Make sure that the ref_model you get does not contain any LoRA parameters!
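
A quick sanity check for this (a minimal sketch; ref_model is whichever object you obtained above):

# The reference model should expose no LoRA tensors.
lora_tensors = [n for n, _ in ref_model.named_parameters() if "lora_" in n.lower()]
assert not lora_tensors, f"unexpected LoRA parameters in ref_model: {lora_tensors}"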

As for FSDP: there are no FSDP settings in the provided scripts or in the current zero3.json, so did you add FSDP settings yourself? If not, you are probably not using FSDP.

@JulioZhao97
Collaborator

JulioZhao97 commented Mar 25, 2024

I assume that in your case, since you pass lora_enable=False, there are no LoRA parameters and the model is no longer a PeftModelForCausalLM but a LlavaLlamaForCausalLM; that is why this happens.

@421zuoduan
Author

421zuoduan commented Mar 27, 2024

Thank you for your assistance. I've made modifications to the code and successfully reduced GPU memory usage using the following command:

deepspeed --include localhost:0,1,2,3,4,5,6,7 ha_dpo/models/llava-v1_5/train_dpo_head.py \
    --lora_enable False \
    --deepspeed ha_dpo/models/llava-v1_5/scripts/zero3.json \
    --model_name_or_path model/llava-v1.5-7b \
    --version v1 \
    --vg_path ha_dpo/data/VG \
    --desc_data_path ha_dpo/data/hadpo/llava-v1.5/desc_data.json \
    --pope_data_path ha_dpo/data/hadpo/llava-v1.5/pope_data.json \
    --vision_tower model/clip-vit-large-patch14-336 \
    --freeze_backbone True \
    --tune_mm_mlp_adapter False \
    --tune_lm_head True \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ha_dpo/models/llava-v1_5/checkpoints/llava-origin \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-6 \
    --weight_decay 0. \
    --warmup_steps 0 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --run_name "llava-v1.5" \
    --beta 0.1

In llava_dpo_trainer.py, I made the following modifications:

# calculate logits
# all_logits = model.forward(
#     inputs_embeds=batch_inputs_embeds,
#     labels=None,
#     attention_mask=batch_attention_mask,
# ).logits.to(torch.float32)

if model is None:
    # reference pass: reuse the policy model with its adapter disabled
    # (self.model is assumed to be the trainer's policy model here)
    with self.model.disable_adapter():
        all_logits = self.model.forward(
            inputs_embeds=batch_inputs_embeds,
            labels=None,
            attention_mask=batch_attention_mask,
        ).logits.to(torch.float32)
else:
    # for policy model
    all_logits = model.forward(
        inputs_embeds=batch_inputs_embeds,
        labels=None,
        attention_mask=batch_attention_mask,
    ).logits.to(torch.float32)

In base_dpo_trainer.py, I made the following modifications:

# if ref_model:
#     self.ref_model = ref_model
# elif self.is_peft_model:
#     # The `model` with adapters turned off will be used as the reference model
#     self.ref_model = None
# else:
#     self.ref_model = create_reference_model(model)

if ref_model:
    self.ref_model = ref_model
else:
    self.ref_model = None
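
With this change the trainer can be constructed with a single model. A hypothetical usage sketch (argument names assumed from the TRL-style base trainer, not copied from the repository):

trainer = LlavaDPOTrainer(
    model=model,          # the single policy model; only lm_head is trainable
    ref_model=None,       # no second model instance is created
    args=training_args,
    beta=0.1,             # matches --beta 0.1 in the launch command
    tokenizer=tokenizer,
)
trainer.train()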

I've found some introductory blogs in Chinese about FSDP and ZeRO (blog1, blog2). I believe ZeRO-3 should achieve effects similar to FSDP in reducing GPU memory usage.

Regarding some warnings during training, I'd like to provide my thoughts on them:

  1. Could not estimate the number of tokens of the input, floating-point operations will not be computed: based on a blog post, I've chosen to ignore this warning.

  2. HA-DPO/ha_dpo/trainer/llava_dpo_trainer_origin.py:135: UserWarning: compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator: because no data_collator was passed in train_dpo.py, self.use_dpo_data_collator = True is set in base_dpo_trainer.py, which triggers this warning in llava_dpo_trainer.py. Since I haven't modified the related code or the data_collator input, I've chosen to ignore this warning.

  3. [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented: according to a blog, this might be due to the DeepSpeed version. I've opted to ignore this warning.

  4. hadpo/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None: this warning appeared after modifying ref_model. I suspect it's because ref_model=None is passed and none of the forwarded inputs require gradients. I've chosen to ignore this warning.

I successfully trained on an 8x3090 server, and the results before and after modifying ref_model were consistent. After training the LM head, the comparison with LLaVA-1.5-7B is as follows:

Model          Method       HA-DPO   Accuracy   Precision   Recall   F1 Score   Yes Ratio (%)
LLaVA-1.5-7B   popular      ×        86.23      83.28       90.67    86.82      54.43
LLaVA-1.5-7B   popular      √        87.03      85.45       89.27    87.32      52.23
LLaVA-1.5-7B   random       ×        89.67      88.89       90.67    89.77      51.00
LLaVA-1.5-7B   random       √        90.10      90.78       89.27    90.02      49.17
LLaVA-1.5-7B   adversarial  ×        79.73      74.40       90.67    81.73      60.93
LLaVA-1.5-7B   adversarial  √        80.80      76.34       89.27    82.30      58.47

I would like to express my gratitude once again for your assistance!

@JulioZhao97
Collaborator

Happy to see your contribution to our codebase, and congratulations on the satisfying results!

The warning hadpo/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None should still be checked carefully; maybe something needs to be modified in the reference reward modeling?

Once again thanks for your contribution!
