Why is there a 2× relationship between the lora_rank set in the YAML and the lora_rank in the adapter config file, and which one is authoritative? #6672

Closed
1 task done
chuangzhidan opened this issue Jan 16, 2025 · 2 comments
Labels
duplicate This issue or pull request already exists

Comments


chuangzhidan commented Jan 16, 2025

Reminder

  • I have read the above rules and searched the existing issues.

System Info

(base) ubuntu@localhost:~$ pip list
Package Version


accelerate 1.1.1
adam_mini 1.1.1

aiofiles 23.2.1
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
anaconda-anon-usage 0.4.4
annotated-types 0.7.0
anyio 4.4.0
archspec 0.2.3
asgiref 3.8.1
attrs 24.2.0
auto_round 0.4.2
backoff 2.2.1
bcrypt 4.2.0
beautifulsoup4 4.12.3
blobfile 3.0.0
boltons 23.0.0
Brotli 1.0.9
build 1.2.2.post1
cachetools 5.5.0
cbor 1.0.0
certifi 2024.8.30
cffi 1.16.0
charset-normalizer 3.3.2
chroma-hnswlib 0.7.6
chromadb 0.5.13
click 8.1.7
click-plugins 1.1.1
cligj 0.7.2
cloudpickle 3.0.0
cn2an 0.5.22
coloredlogs 15.0.1
conda 24.7.1
conda-content-trust 0.2.0
conda-libmamba-solver 24.7.0
conda-package-handling 2.3.0
conda_package_streaming 0.10.0
cryptography 42.0.5
dashscope 1.20.14
datasets 3.2.0
deepspeed 0.16.2
Deprecated 1.2.14
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
durationpy 0.9
einops 0.8.0
et_xmlfile 2.0.0
fastapi 0.115.6
ffmpy 0.5.0
filelock 3.15.4
fiona 1.10.1
fire 0.7.0
FlagEmbedding 1.3.2
flatbuffers 24.3.25
frozendict 2.4.2
frozenlist 1.4.1
fsspec 2024.3.1
geopandas 0.14.4
gguf 0.9.1
google-auth 2.35.0
googleapis-common-protos 1.65.0
gradio 5.10.0
gradio_client 1.5.3
grpcio 1.66.2
h11 0.14.0
hjson 3.1.0
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.2
huggingface-hub 0.27.1
humanfriendly 10.0
idna 3.8
ijson 3.3.0
importlib_metadata 8.4.0
importlib_resources 6.4.5
inscriptis 2.5.0
interegular 0.3.3
ir_datasets 0.5.8
jieba 0.42.1
Jinja2 3.1.4
jiter 0.5.0
joblib 1.4.2
jsonpatch 1.33
jsonpointer 2.1
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
kubernetes 31.0.0
lark 1.2.2
libmambapy 1.5.8
liger_kernel 0.5.2
llama_models 0.0.63
llama_stack 0.0.63
llama_stack_client 0.0.63
llvmlite 0.43.0
lm-format-enforcer 0.10.6
lxml 5.3.0
lz4 4.3.3
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
menuinst 2.1.2
mistral_common 1.3.4
mmh3 5.0.1
modelscope 1.17.1
monotonic 1.6
mpmath 1.3.0
msgpack 1.0.8
msgspec 0.18.6
multidict 6.0.5
multiprocess 0.70.16
nats-py 2.9.0
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.3
niuload 0.2.4
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.68
nvidia-nvtx-cu12 12.1.105
nvitop 1.3.2
oauthlib 3.2.2
onnx 1.17.0
onnxruntime 1.19.2
openai 1.43.1
openpyxl 3.1.5
opentelemetry-api 1.27.0
opentelemetry-exporter-otlp-proto-common 1.27.0
opentelemetry-exporter-otlp-proto-grpc 1.27.0
opentelemetry-instrumentation 0.48b0
opentelemetry-instrumentation-asgi 0.48b0
opentelemetry-instrumentation-fastapi 0.48b0
opentelemetry-proto 1.27.0
opentelemetry-sdk 1.27.0
opentelemetry-semantic-conventions 0.48b0
opentelemetry-util-http 0.48b0
optimum 1.22.0
orjson 3.10.7
osmnx 1.9.4
outlines 0.0.46
overrides 7.7.0
packaging 24.1
pandas 2.2.2
partial-json-parser 0.2.1.1.post4
peft 0.13.2
pillow 10.4.0
pip 24.2
platformdirs 3.10.0
pluggy 1.0.0
posthog 3.7.0
proces 0.1.7
prometheus_client 0.20.0
prometheus-fastapi-instrumentator 7.0.0
prompt_toolkit 3.0.48
propcache 0.2.1
protobuf 4.25.5
psutil 6.0.0
py-cpuinfo 9.0.0
pyairports 2.1.1
pyaml 25.1.0
pyarrow 17.0.0
pyarrow-hotfix 0.6
pyasn1 0.6.1
pyasn1_modules 0.4.1
pycosat 0.6.6
pycountry 24.6.1
pycparser 2.21
pycryptodomex 3.21.0
pydantic 2.9.0
pydantic_core 2.23.2
pydub 0.25.1
Pygments 2.18.0
PyPika 0.48.9
pyproj 3.7.0
pyproject_hooks 1.2.0
PySocks 1.7.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.20
pytz 2024.1
PyYAML 6.0.2
pyzmq 26.2.0
ray 2.35.0
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
requests-oauthlib 2.0.0
rich 13.9.2
rpds-py 0.20.0
rsa 4.9
ruamel.yaml 0.17.21
ruff 0.8.6
safehttpx 0.1.6
safetensors 0.4.5
sageattention 1.0.6
scikit-learn 1.5.2
scipy 1.14.1
semantic-version 2.10.0
sentence-transformers 3.2.0
sentencepiece 0.2.0
setuptools 74.1.2
shapely 2.0.6
shellingham 1.5.4
six 1.16.0
sniffio 1.3.1
soupsieve 2.6
sse-starlette 2.1.3
starlette 0.41.3
sympy 1.13.2
tenacity 9.0.0
termcolor 2.5.0
threadpoolctl 3.5.0
tiktoken 0.7.0
tokenizers 0.20.3
tomlkit 0.13.2
torch 2.4.0
torchvision 0.19.0
tqdm 4.66.5
transformers 4.46.0
trec-car-tools 2.6
triton 3.0.0
trl 0.13.0
unlzw3 0.2.2
urllib3 2.2.2
uvicorn 0.30.6
uvloop 0.20.0
vllm 0.6.0
vllm-flash-attn 2.6.1

(base) ubuntu@localhost:~$ pip show torch
Name: torch
Version: 2.4.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /home/ubuntu/.local/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, setuptools, sympy, triton, typing-extensions
Required-by: accelerate, auto_round, deepspeed, FlagEmbedding, liger_kernel, niuload, optimum, peft, sentence-transformers, torchvision, vllm, vllm-flash-attn, xformers
(base) ubuntu@localhost:~$ python --version
Python 3.12.4

Reproduction

examples/extras/pissa/llama3_lora_sft.yaml sets lora_rank: 16, and the checkpoint shows 16, but the PiSSA adapter config in the parent directory shows a lora_rank of 32. Also, when lora_rank is not set, the default should be 8, right? Yet the config there shows 16.
I don't understand why there is a 2× relationship. Which one is the real rank?

Others

(two screenshots attached)

@chuangzhidan chuangzhidan added bug Something isn't working pending This problem is yet to be addressed labels Jan 16, 2025
@chuangzhidan chuangzhidan reopened this Jan 16, 2025
@hiyouga
Owner

hiyouga commented Jan 17, 2025

PiSSA automatically doubles the LoRA rank.
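The factor of two comes from the PiSSA-to-LoRA conversion step: the converted adapter has to store the weight update ΔW = AB − A₀B₀ as a single low-rank product, which it does by stacking the trained and initial factors, so the rank recorded in adapter_config.json is 2r. A minimal NumPy sketch (dimensions and values are illustrative, not taken from the actual checkpoint):

```python
import numpy as np

# Illustrative shapes only: d = hidden size, r = lora_rank from the YAML.
d, r = 64, 16
rng = np.random.default_rng(0)
A0 = rng.normal(size=(d, r))      # factors at PiSSA initialization
B0 = rng.normal(size=(r, d))
A = A0 + rng.normal(size=(d, r))  # factors after training
B = B0 + rng.normal(size=(r, d))

# The converted adapter must express delta = A @ B - A0 @ B0 as ONE
# low-rank product, so it concatenates trained and initial factors:
A_cat = np.concatenate([A, A0], axis=1)    # shape (d, 2r)
B_cat = np.concatenate([B, -B0], axis=0)   # shape (2r, d)

assert np.allclose(A_cat @ B_cat, A @ B - A0 @ B0)
print(A_cat.shape[1])  # 32: the rank written to adapter_config.json
```

The training itself still uses rank r; only the saved, converted adapter reports 2r.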

@hiyouga hiyouga closed this as completed Jan 17, 2025
@hiyouga hiyouga added duplicate This issue or pull request already exists and removed bug Something isn't working pending This problem is yet to be addressed labels Jan 17, 2025
@chuangzhidan
Author

chuangzhidan commented Jan 17, 2025

PiSSA automatically doubles the LoRA rank.

OK, so when I set lora_rank: 16 in the YAML file, I am actually training a rank-32 model? This is a surprise.

So far, PiSSA training seems effective when I load and merge the adapter from the pissa_converted folder in the usual LoRA style, although I found the merged model is slightly different from the original model.

import torch

def compare_model_weights(model1, model2):
    """
    Compare the weights of two models and return True as soon as any layer's
    weights are different (early exit). Return False if all weights are the same.
    """
    for name1, param1 in model1.named_parameters():
        if name1 in model2.state_dict():
            param2 = model2.state_dict()[name1]
            # Early exit if any weights are different
            if not torch.allclose(param1, param2):
                print(f"Layer '{name1}': Weights are DIFFERENT.")
                return True
        else:
            print(f"Layer '{name1}' not found in the second model.")
            return True

    # Return False if no differences were found
    return False

output:
Layer 'model.layers.0.self_attn.q_proj.base_layer.weight' not found in the second model.
Merging is valid.

merged_model: Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-47): 48 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((5120,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((5120,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((5120,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=5120, out_features=152064, bias=False)
)
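About the "not found" line in the output: a PEFT-wrapped model keeps the original weight under a `base_layer` prefix (e.g. `q_proj.base_layer.weight`), so comparing its parameter names directly against a merged model reports missing layers even when the underlying tensors match. A self-contained toy run of the early-exit comparison, using tiny `nn.Linear` modules as stand-ins for the real models (the helper name `weights_differ` is mine, not from the repo):

```python
import torch
import torch.nn as nn

def weights_differ(m1, m2, atol=1e-6):
    """Early-exit: True if any parameter of m1 is missing from m2 or differs."""
    sd2 = m2.state_dict()
    for name, p1 in m1.named_parameters():
        if name not in sd2:
            print(f"Layer '{name}' not found in the second model.")
            return True
        if not torch.allclose(p1, sd2[name], atol=atol):
            print(f"Layer '{name}': weights are DIFFERENT.")
            return True
    return False

torch.manual_seed(0)
a, b = nn.Linear(4, 4), nn.Linear(4, 4)
print(weights_differ(a, b))   # True: independently initialized weights differ

c = nn.Linear(4, 4)
c.load_state_dict(a.state_dict())
print(weights_differ(a, c))   # False: exact copy of a's weights
```

To compare layers across a PEFT model and a merged model by content rather than name, the `base_layer.` infix would have to be stripped from the PEFT-side names first.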
