Accelerate FSDP always removes the {'model.norm.weight'} layer of the model when saving #2155

Closed
gauss5930 opened this issue Nov 15, 2023 · 11 comments

@gauss5930

gauss5930 commented Nov 15, 2023

System Info

Accelerate: 0.24.1
OS: Linux-5.4.0-148-generic-x86_64-with-glibc2.35
Python version: 3.10.12
Numpy version: 1.24.1
PyTorch: 2.1.0+cu118

Accelerate Configuration

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 0
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Run my own fine-tuning code with the Accelerate FSDP config above on 2 × A100 80GB GPUs. I also used use_flash_attention_2=True and gradient_checkpointing=True. The following command and code were used for fine-tuning. I originally set the number of epochs to 3 and max_step to 'max', but I reduced these hyperparameters so that the error message would show up sooner.

!accelerate launch --config_file=accelerate_configs/fsdp_config.yaml --num_processes=2 finetuning/finetune.py \
    --model_path beomi/llama-2-koen-13b \
    --data_path Cartinoe5930/KoRAE_filtered_12k \
    --output_dir finetuning/result/llama2/ \
    --wandb_project example_1 \
    --wandb_run_name example_1 \
    --hf_hub_path HUGGINGFACE_HUB_PATH_TO_UPLOAD \
    --hf_token MY_HUGGINGFACE_TOKEN \
    --num_epochs 1 
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling
from accelerate import Accelerator
from datasets import load_dataset, Dataset
import huggingface_hub

from trl import SFTTrainer

from utils.prompter import Prompter

import argparse

def args_parse():
    parser = argparse.ArgumentParser()

    parser.add_argument("--hf_token", type=str, help="Required to upload models to hub.")
    parser.add_argument("--model_path", type=str, default="beomi/llama-2-koen-13b")
    parser.add_argument("--data_path", type=str, default="Cartinoe5930/KoRAE_filtered_12k")

    parser.add_argument("--seq_length", type=int, default=4096)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--micro_batch_size", type=int, default=4)
    parser.add_argument("--val_set_size", type=float, default=0)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--save_strategy", type=str, default="epoch", help="You can choose the strategy of saving model.")
    parser.add_argument("--gradient_checkpointing", type=bool, default=True)
    parser.add_argument("--group_by_length", type=bool, default=False)
    parser.add_argument("--packing", type=bool, default=False)

    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--weight_decay", type=float, default=0)
    
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_run_name", type=str)
    parser.add_argument("--num_workers", type=int, required=True)

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True
    )
    parser.add_argument(
        "--hf_hub_path",
        type=str,
        required=True,
        help="The hub path to upload the model"
    )

    return parser.parse_args()

def process_dataset(dataset):
    prompter = Prompter("KoRAE_template")

    list_data = dataset.to_list()
    
    for data in list_data:
        data["prompted_input"] = prompter.generate_prompt(
            data["instruction"],
            data["prompt"],
            data["input"],
            data["output"])

    result_data = Dataset.from_list(list_data)

    return result_data

def create_datasets(args):
    dataset = load_dataset(
        args.data_path,
        split="train",
        num_proc=args.num_workers
    )

    if args.val_set_size > 0:
        train_val = dataset.train_test_split(test_size=args.val_set_size, seed=42)

        train_data = train_val["train"]
        val_data = train_val["test"]
    else:
        train_data = dataset
        val_data = None

    return train_data, val_data


if __name__ == "__main__":
    args = args_parse()

    huggingface_hub.login(args.hf_token)

    gradient_accumulation_steps = args.batch_size // args.micro_batch_size

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map={"": Accelerator().process_index},
        torch_dtype=torch.bfloat16,
        use_auth_token=args.hf_token,
        use_flash_attention_2=True
    )
    model.config.use_cache = False
    model.enable_input_require_grads()

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        use_auth_token=args.hf_token,
    )

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Check if parameter passed or if set within environ
    use_wandb = len(args.wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(args.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    train_dataset, eval_dataset = create_datasets(args)

    train_dataset = process_dataset(train_dataset)
    eval_dataset = process_dataset(eval_dataset) if eval_dataset else None
    
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.micro_batch_size,
        per_device_eval_batch_size=args.micro_batch_size if eval_dataset else None,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps if args.save_strategy == "steps" else None,
        evaluation_strategy="epoch" if eval_dataset else "no",
        group_by_length=args.group_by_length,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        bf16=True,
        save_total_limit=2,
        remove_unused_columns=False,
        report_to="wandb" if use_wandb else None,
        run_name=args.wandb_run_name if use_wandb else None,
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="prompted_input",
        data_collator=data_collator,
        packing=args.packing,
        max_seq_length=args.seq_length,
        tokenizer=tokenizer,
        args=training_args
    )

    trainer.train()

    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    trainer.model.push_to_hub(args.hf_hub_path)
    trainer.tokenizer.push_to_hub(args.hf_hub_path)

    trainer.save_model(args.output_dir)

Expected behavior

The training itself went very well, but a problem occurred when saving and uploading the model to the HuggingFace Hub. The execution log showed the message Removed shared tensor {'model.norm.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading just before the model was uploaded.

... training message

Removed shared tensor {'model.norm.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

... model uploading

Although that message bothered me, I just moved on. However, I ran into a serious problem when loading my fine-tuned model. The error message is as follows.

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("MY_UPLOAD_MODEL_PATH")
model = AutoModelForCausalLM.from_pretrained("MY_UPLOAD_MODEL_PATH")
RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
	size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([46336, 5120]).
	size mismatch for model.norm.weight: copying a param with shape torch.Size([2560]) from checkpoint, the shape in current model is torch.Size([5120]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

I spent a lot of time trying to solve this by googling and asking on the HuggingFace Forum, but I could not find any solution or anything that helped me get past this obstacle. Please let me know how to solve this problem! I really want to save the model completely.

@muellerzr
Collaborator

cc @pacman100

@pacman100
Contributor

Hello @gauss5930, what are the contents of the final output directory?

@gauss5930
Author

Hello @pacman100 !
The model.safetensors shards and the tokenizer configuration files were saved in the final output directory. However, the saved model's model.safetensors.index.json shows that there is no model.norm.weight entry, and the total size differs from the base model's.

Unfortunately, I don't have any screenshots of how the model was uploaded. I hope this provides enough information; if you have any further questions, please feel free to ask! Thank you.

@gauss5930
Author

@pacman100 !

There is something else I just discovered. When I checked model.safetensors.index.json, the shape of lm_head.weight in the uploaded model was [0], and the shape of model.embed_tokens.weight was [237240320], which is the product of the vocab size (46,336) and the hidden size (5,120). Other than that and the missing model.norm.weight, everything matched the base model.

I hope this information helps!
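(For reference, the index and the on-disk tensor shapes can be checked without loading the full model with something like the sketch below; the directory name is a placeholder, matching the snippets above.)

import json
from safetensors import safe_open

ckpt_dir = "MY_UPLOAD_MODEL_PATH"  # placeholder

# The index maps each tensor name to the shard file that contains it.
with open(f"{ckpt_dir}/model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

print("model.norm.weight present:", "model.norm.weight" in weight_map)

# Read the on-disk shape of a few suspicious tensors directly from the shards.
for name in ("model.embed_tokens.weight", "lm_head.weight", "model.norm.weight"):
    shard = weight_map.get(name)
    if shard is None:
        print(name, "-> missing from checkpoint")
    else:
        with safe_open(f"{ckpt_dir}/{shard}", framework="pt") as shard_file:
            print(name, "->", shard_file.get_slice(name).get_shape())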

@DableUTeeF

Something to add: _tied_weights_keys was previously called _keys_to_ignore_on_load_missing.

Could it be that Accelerate simply removes the weights based on just that? I could not quite follow this function: https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/other.py#L147
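
For what it's worth, safetensors cannot serialize two entries that point at the same storage, so any save path built on it keeps only one copy of a tied/shared tensor and relies on the module re-tying it at load time. A minimal sketch of that behavior with plain safetensors (not the exact Accelerate code path linked above):

import torch.nn as nn
from safetensors.torch import save_model, load_model, load_file

# Two parameters sharing storage, like tied input/output embeddings in a causal LM.
class TinyTiedLM(nn.Module):
    def __init__(self, vocab=8, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights

model = TinyTiedLM()
save_model(model, "tied.safetensors")        # keeps only one of the two shared names

print(load_file("tied.safetensors").keys())  # a single tensor is on disk

fresh = TinyTiedLM()
load_model(fresh, "tied.safetensors")        # fine: the module re-ties the other name

The open question in this issue is why, under FSDP, model.norm.weight ends up being treated as shared with something else in the first place.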

@jd445

jd445 commented Dec 29, 2023

I also ran into this problem; here is my config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 5
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: false
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@pranaydeeps

Also running into a similar issue with the run_mlm_no_trainer example.

21 01/02/2024 17:47:58 - WARNING - accelerate.utils.other - Removed shared tensor {'lm_head.decoder.weight', 'lm_head.decoder.bias'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

Then running into errors when trying to use accelerator.load_state

@muellerzr
Collaborator

muellerzr commented Jan 5, 2024

@pranaydeeps probably related to huggingface/transformers#27293 & huggingface/transformers#27972 for your part there


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Feb 6, 2024
@cameronfr

cameronfr commented Feb 21, 2024

Had a similar issue with a GRU in my model, where I was getting Missing key(s) in state_dict: "ref_enc.gru.weight_hh_l0", "ref_enc.gru.bias_ih_l0", "ref_enc.gru.bias_hh_l0". after trying to load a checkpoint created with accelerator.save_state("checkpoint"). Looking at the checkpoint, only one of the GRU weights was in it. As mentioned in the other issues, accelerator.save_state("checkpoint", safe_serialization=False) seems to fix it.
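
For anyone landing here, the workaround in context looks roughly like this (a sketch, assuming an Accelerate version whose save_state accepts safe_serialization):

from accelerate import Accelerator

accelerator = Accelerator()
# ... prepare model / optimizer / dataloaders and train as usual ...

# Fall back to torch.save (pickle) checkpoints so shared tensors are not
# dropped by the safetensors path.
accelerator.save_state("checkpoint", safe_serialization=False)

# Later, restore everything that was prepared:
accelerator.load_state("checkpoint")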

@hungphongtrn

hungphongtrn commented Jul 19, 2024

@gauss5930, I just want to ask: did you manage to solve your issue? Thank you
