
Lora PISSA init: not support gpt2 #2103

Closed
2 of 4 tasks
suyang160 opened this issue Sep 26, 2024 · 4 comments

Comments

@suyang160
Contributor

System Info

peft 0.13.0
transformers 4.44.2
torch 2.4.0
Python 3.12.4

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import os
os.environ["WANDB_DISABLED"] = "true"

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from trl import SFTTrainer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    target_modules=["attn.c_attn"],
    init_lora_weights="pissa",
    fan_in_fan_out=True,
    bias="none",
)

model = get_peft_model(model, lora_config)

dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
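For context (a minimal sketch, not part of the original report): GPT-2's attn.c_attn is a transformers Conv1D layer, which stores its weight as (in_features, out_features), i.e. transposed relative to nn.Linear. That layout is exactly what fan_in_fan_out=True signals to PEFT, and it is why an init routine that ignores the flag hits a dimension mismatch:

```python
import torch
from transformers.pytorch_utils import Conv1D

# Conv1D(nf, nx): nf = output features, nx = input features
conv = Conv1D(nf=12, nx=4)
lin = torch.nn.Linear(in_features=4, out_features=12)

# Conv1D stores weight as (in, out); nn.Linear stores it as (out, in)
print(tuple(conv.weight.shape))  # (4, 12)
print(tuple(lin.weight.shape))   # (12, 4)
```

An SVD-based init written against the nn.Linear convention therefore produces factors with swapped dimensions unless the weight is transposed first.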

Expected behavior

Hello, I found that the current PiSSA init code forgets to take the fan_in_fan_out parameter into account and transpose the weight matrix, which makes GPT-2 training fail with a dimension mismatch. I have fixed the bug with the following code:

def pissa_init(self, adapter_name, init_lora_weights):
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, : self.r[adapter_name]]
        Sr = S[: self.r[adapter_name]]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[: self.r[adapter_name]]
    elif len(init_lora_weights.split("_niter_")) == 2:
        Vr, Sr, Ur = svd_lowrank(
            weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
        )
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )

    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
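The core of the fix can be checked on a toy example (a standalone sketch with made-up shapes, independent of the PEFT internals above): transpose a Conv1D-style weight into the (out, in) convention, perform the PiSSA SVD split there, then transpose the residual back before storing it. The adapter product plus the residual must reconstruct the original weight.

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, scaling = 8, 6, 2, 1.0

# Conv1D-style storage: (in_features, out_features), hence fan_in_fan_out=True
w_stored = torch.randn(in_f, out_f)

# Step 1: transpose to the (out, in) convention the SVD code expects
w = w_stored.T
V, S, Uh = torch.linalg.svd(w, full_matrices=False)
Sr = S[:r] / scaling
lora_A = torch.diag(torch.sqrt(Sr)) @ Uh[:r]    # (r, in_f)
lora_B = V[:, :r] @ torch.diag(torch.sqrt(Sr))  # (out_f, r)

# Step 2: subtract the low-rank part, then transpose back to Conv1D layout
residual = w - scaling * lora_B @ lora_A
residual_stored = residual.T

# Residual keeps the storage shape, and residual + adapter == original weight
print(residual_stored.shape == w_stored.shape)  # True
reconstructed = residual_stored + scaling * (lora_B @ lora_A).T
print(torch.allclose(reconstructed, w_stored, atol=1e-5))  # True
```

Without the two transposes, lora_A and lora_B come out with in_f and out_f swapped, which is the dimension mismatch reported for GPT-2.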
@BenjaminBossan
Member

Thanks for reporting this bug and providing a potential solution. Would you be interested in creating a PR with your fix?

@suyang160
Contributor Author

> Thanks for reporting this bug and providing a potential solution. Would you be interested in creating a PR with your fix?

Thanks! I'd be happy to submit a PR with my fix.

suyang160 pushed a commit to suyang160/peft that referenced this issue Sep 26, 2024
…SA initialization (huggingface#2103)

Previously, the weight matrix was converted to float32 without considering the need for transposition. This update ensures that the weight matrix is transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.
suyang160 pushed a commit to suyang160/peft that referenced this issue Oct 8, 2024
…SA initialization (huggingface#2103)

This update addresses an issue where the weight matrix was converted to float32 without considering the need for transposition. The weight matrix is now transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.

To ensure this fix is robust, tests have been updated to include parameterized cases for different devices and bit configurations. Additionally, the isinstance checks have been modified to include Conv1D layers, ensuring all relevant layers are processed correctly.
BenjaminBossan pushed a commit that referenced this issue Oct 8, 2024
Transpose weight matrix based on fan_in_fan_out condition in PiSSA
initialization.

Co-authored-by: Yang Su <suyang360@gmail.com>
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this issue Oct 22, 2024
…ce#2104)

Transpose weight matrix based on fan_in_fan_out condition in PiSSA
initialization.

Co-authored-by: Yang Su <suyang360@gmail.com>

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Copy link
Member

This issue is resolved via #2104.
