
AdaLoRA: where does the RankAllocator work? #432

Closed · louislau1129 opened this issue May 11, 2023 · 8 comments

@louislau1129

Hi, I have been using AdaLoRA for parameter-efficient fine-tuning for a while. It works very well and significantly better than vanilla LoRA in my case. But recently I found two issues.

  1. AdaLoRA does not apply the orthogonal regularization, since its forward function in https://github.com/huggingface/peft/blob/main/src/peft/tuners/adalora.py#L216 is never called. This happens when I use PeftModel to wrap AdaLoRA as below:
    config = AdaLoraConfig(target_r=args.rank, init_r=init_r, lora_alpha=64,
                           target_modules=args.target_modules, lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)

I found the reason: self.get_base_model() in the forward() function of PeftModel does not point to the AdaLoraModel as the base model. (https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L300)

The original code is as follows:

    def get_base_model(self):
        """
        Returns the base model.
        """
        return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model

After modifying it as follows, the AdaLoraModel forward function is called as expected:

    def get_base_model(self):
        """
        Returns the base model.
        """
        return self.base_model if isinstance(self.active_peft_config, (PromptLearningConfig, AdaLoraConfig)) else self.base_model.model
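
A quick sanity check for the patch (a hypothetical snippet, not from the PEFT docs; `model` is the PeftModel created above):

    from peft import AdaLoraModel

    # With the patch, get_base_model() should return the AdaLoraModel wrapper,
    # so its forward() (which adds the orthogonal regularization) is used.
    assert isinstance(model.get_base_model(), AdaLoraModel)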
  2. The second issue is that I cannot find where the RankAllocator in AdaLoRA is called. If I understand correctly, it is invoked via def update_and_allocate(self, global_step): in https://github.com/huggingface/peft/blob/main/src/peft/tuners/adalora.py#L284. However, when I add a breakpoint there, the program never stops at it. Maybe for this reason, I cannot find any masked ranks (elements at pruned rank positions equal to 0) in the lora_E of the saved PEFT model.

Does anyone have an idea about this issue? I really appreciate any help you can provide.

@louislau1129
Author

I have double-checked the second issue: the RankAllocator is indeed not called in the current PEFT version. I manually call it to perform the adaptive rank allocation.
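
For concreteness, a minimal sketch of the manual call in a custom training loop (untested; `dataloader` and `optimizer` are assumed names). Note the ordering: update_and_allocate() reads parameter gradients to update the importance scores, so it has to run after backward()/step() but before the gradients are cleared:

    global_step = 0
    for batch in dataloader:
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        # RankAllocator: update importance estimates and mask pruned ranks
        model.base_model.update_and_allocate(global_step)
        optimizer.zero_grad()
        global_step += 1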

@moritzunseld

Thanks for pointing this out. I've also encountered some weird behavior when benchmarking for my Bachelor's Thesis.

@moritzunseld

> I have double-checked the second issue: the RankAllocator is indeed not called in the current PEFT version. I manually call it to perform the adaptive rank allocation.

Where/how do you manually call it?

@louislau1129
Author

louislau1129 commented May 12, 2023

> Where/how do you manually call it?

I add the following code at https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/trainer.py#L2013 to explicitly call the allocation function:

    from peft import PeftModel

    if isinstance(model, PeftModel):
        if getattr(model.base_model, "update_and_allocate", None) is not None:
            model.base_model.update_and_allocate(total_batched_samples)

Besides this, you should also set the corresponding tinit, tfinal, deltaT, and total_step in AdaLoraConfig (a sketch follows below).
I am not sure whether there is a more elegant way to do this, but it works.
Also, I did not find much difference in fine-tuning performance after fixing these two issues this way; I will investigate further.
I hope the developers can give some ideas/comments about these issues.
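
For reference, a sketch of an AdaLoraConfig with the rank-allocation schedule filled in (all values and module names are illustrative placeholders, not recommendations):

    from peft import AdaLoraConfig

    config = AdaLoraConfig(
        init_r=12,        # initial rank of each adapter before pruning
        target_r=4,       # average rank budget after allocation
        tinit=200,        # warmup steps before any rank pruning starts
        tfinal=500,       # final steps during which the budget stays fixed
        deltaT=10,        # interval (in steps) between budget updates
        total_step=3000,  # total training steps, required by the schedule
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )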

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@PluszZh

PluszZh commented Oct 12, 2023

> I add the following code at https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/trainer.py#L2013 to explicitly call the allocation function. […]

I found that this issue still exists in the current version. Can you share your code? Thank you very much!

@geoffvdr

This issue still exists. Any suggestions on how to make the RankAllocator / update_and_allocate work with PEFT?

@QingruZhang maybe?

@BenjaminBossan
Member

Note that PEFT does not contain training code; as such, calling update_and_allocate is out of scope for PEFT. When using Trainer, this could maybe be solved with a callback, but I have only a little experience with Trainer. When running a custom training loop, call this method manually, as shown e.g. in this AdaLoRA training script.
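
For anyone who wants to try the callback route, a rough sketch (an untested assumption, not part of PEFT or Transformers) could look like this. One caveat: update_and_allocate() reads parameter gradients, and depending on the Trainer version, on_step_end may fire after gradients have already been zeroed, so verify the step order for your version:

    from transformers import TrainerCallback

    class AdaLoraAllocatorCallback(TrainerCallback):
        # Forward each completed optimizer step to AdaLoRA's rank allocator.
        def on_step_end(self, args, state, control, model=None, **kwargs):
            base = getattr(model, "base_model", None)
            if base is not None and hasattr(base, "update_and_allocate"):
                base.update_and_allocate(state.global_step)

    # usage: trainer.add_callback(AdaLoraAllocatorCallback())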
