
Could the code perhaps be simplified? #6

Open
WuNein opened this issue Nov 14, 2023 · 8 comments

Comments
@WuNein

WuNein commented Nov 14, 2023

Here:

orthogonal_loss = 0.

Since the orthogonality here is computed against the old LoRA, which carries no gradient, couldn't we just save the old LoRA to a .pth file in the previous step and avoid modifying the peft library altogether?

import torch

# Assume self.model is your model
stacked_params = {}

for name, param in self.model.named_parameters():
    if "lora_" in name:
        stacked_params[name] = param.data.clone()  # clone() copies the tensor and avoids sharing memory

# Save the stacked parameters to a file
torch.save(stacked_params, "path/to/stacked_params.pth")

Then load it in the trainer class:

# Load the saved old-task LoRA parameters
matched_modules = torch.load("path/to/stacked_params.pth")  # load pth

orthogonal_loss = 0.
for name, param in self.model.named_parameters():
    if "lora_A" in name:
        # old parameter matched by module name
        param_ = matched_modules[name].to(param.device)

        orthogonal_loss += torch.abs(torch.mm(param, param_.T)).sum()  # [r * dim] * [dim * r]

Roughly along these lines.

Wouldn't that avoid having to modify the PEFT code and make things a lot more convenient?
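As a rough illustration, here is a minimal sketch of how the loading and the penalty could sit inside a transformers.Trainer subclass, assuming the old-adapter weights were saved as above and that the parameter names in the running model match the keys in the .pth; the class name, file path, and lamda_1 value are placeholders, not anything from this repo:

import torch
from transformers import Trainer

class OrthLoRATrainer(Trainer):  # hypothetical name
    def __init__(self, *args, old_lora_path="path/to/stacked_params.pth", lamda_1=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        # Old-task LoRA weights saved in the previous step; they carry no gradient.
        self.matched_modules = torch.load(old_lora_path, map_location="cpu")
        self.lamda_1 = lamda_1

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss

        # Penalize overlap between the current lora_A and the frozen old lora_A.
        orthogonal_loss = 0.
        for name, param in model.named_parameters():
            if "lora_A" in name and name in self.matched_modules:
                param_ = self.matched_modules[name].to(param.device)
                orthogonal_loss += torch.abs(torch.mm(param, param_.T)).sum()

        loss = loss + self.lamda_1 * orthogonal_loss
        return (loss, outputs) if return_outputs else loss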

@cmnfriend
Owner

Yes, that works! 👍

@WuNein
Author

WuNein commented Nov 14, 2023

Oh right, there's something I don't understand, so I'll just ask :) I couldn't be bothered to dig through your modified PEFT code again (just kidding).
Since the current LoRA is said to be updated in directions orthogonal to the previous LoRA, the current LoRA is presumably trained on top of a model that already has the previous LoRA merged in, right? Am I understanding that correctly?

@DumoeDss

Oh right, there's something I don't understand, so I'll just ask :) I couldn't be bothered to dig through your modified PEFT code again (just kidding). Since the current LoRA is said to be updated in directions orthogonal to the previous LoRA, the current LoRA is presumably trained on top of a model that already has the previous LoRA merged in, right? Am I understanding that correctly?

The merge is done after training finishes.
#5 (comment)
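For reference, with an unmodified peft this merge step could look roughly like the following; merge_and_unload() folds the adapter into the base weights, and the output path is just a placeholder:

# After finishing training on the current task, fold the LoRA adapter into the
# base model so that the next task's adapter is trained on top of it.
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("path/to/merged_base_for_next_task")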

@WuNein
Author

WuNein commented Nov 17, 2023

Oh right, there's something I don't understand, so I'll just ask :) I couldn't be bothered to dig through your modified PEFT code again (just kidding). Since the current LoRA is said to be updated in directions orthogonal to the previous LoRA, the current LoRA is presumably trained on top of a model that already has the previous LoRA merged in, right? Am I understanding that correctly?

The merge is done after training finishes. #5 (comment)

My confusion is about how the LoRA for the new task is initialized. Since the merge only happens at the end, I'll assume it is randomly initialized, given that the loss in the code has to keep the two lora_A matrices orthogonal.
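For what it's worth, peft's default LoRA initialization does start the new adapter from scratch: lora_A is randomly initialized and lora_B is zero, so the adapter is a no-op until training updates it. A quick sanity check along those lines (the base model and target_modules here are just an arbitrary example):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

for name, param in peft_model.named_parameters():
    if "lora_A" in name:
        assert torch.count_nonzero(param) > 0   # lora_A: random init
    if "lora_B" in name:
        assert torch.count_nonzero(param) == 0  # lora_B: zeros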

@DumoeDss

By the way, with your modification, does the original l2_loss just go away?
Is the final loss then loss = loss + orthogonal_loss * lamda_1?

@WuNein
Author

WuNein commented Nov 18, 2023

By the way, with your modification, does the original l2_loss just go away? Is the final loss then loss = loss + orthogonal_loss * lamda_1?

Just add it back yourself, it doesn't conflict with anything... I simply didn't bother to write it out.

@DumoeDss

DumoeDss commented Nov 18, 2023

By the way, with your modification, does the original l2_loss just go away? Is the final loss then loss = loss + orthogonal_loss * lamda_1?

Just add it back yourself, it doesn't conflict with anything... I simply didn't bother to write it out.

Do you compute it directly with matched_modules?

l2_loss = 0.
for name, param in matched_modules.items():
    l2_loss += torch.norm(param, p=2)

@WuNein
Author

WuNein commented Nov 19, 2023

Do you compute it directly with matched_modules?

l2_loss = 0.
for name, param in matched_modules.items():
    l2_loss += torch.norm(param, p=2)

That's not right at all.

# l2-normalization for loranew_A/B
l2_loss = 0.
for name, param in self.model.named_parameters():
    if "loranew_" in name:
        l2_loss += torch.norm(param, p=2)

The original code applies it to the new loranew_ parameters, so after the simplification the target is:

# l2-normalization for the current task's lora_A/B
l2_loss = 0.
for name, param in self.model.named_parameters():
    if "lora_" in name:
        l2_loss += torch.norm(param, p=2)

In the simplified version, lora_ plays the role of the original loranew_; the L2 regularization is of course applied to the current task's parameters, not the old ones.
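Putting the pieces together, the extra terms could then be added to the task loss in one place, e.g. inside the compute_loss sketch earlier in this thread. lamda_1 matches the name used above; lamda_2 is an assumed analogous coefficient for the L2 term:

# L2 penalty on the current task's trainable "lora_" parameters
l2_loss = 0.
for name, param in model.named_parameters():
    if "lora_" in name:
        l2_loss += torch.norm(param, p=2)

loss = loss + lamda_1 * orthogonal_loss + lamda_2 * l2_loss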
