
Question: after calling the nunchaku.merge_safetensors API, loading the quantized model with nunchaku fails: assert state_dict[k].dtype == model_state_dict[k].dtype #5

@moveforever

Description

Convert to nunchaku format:

python -m deepcompressor.backend.nunchaku.convert --quant-path qwen-image1 --output-root qwen-image-svdq --model-name qwenimage

Call the nunchaku API to merge into a single file:

cp co*json qwen-image-svdq/qwenimage/
python -m nunchaku.merge_safetensors -i qwen-image-svdq/qwenimage -o svdq-int4_r32-qwen-image-edit-2511-lightningv2.0-4steps.safetensors

The exact failure:

key transformer_blocks.0.attn.to_qkv.wscales is in model_state_dict, state_dict[k].dtype = torch.float8_e4m3fn, model_state_dict[k].dtype = torch.bfloat16
assert state_dict[k].dtype == model_state_dict[k].dtype
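
As a quick diagnostic, the dtype recorded for each tensor in the merged file can be read straight from the safetensors header (an 8-byte little-endian length prefix followed by a JSON header) without loading any weights or requiring torch. A minimal stdlib-only sketch; the filename in the comment is the one produced by the merge step above:

```python
import json
import struct

def safetensors_dtypes(path):
    """Return {tensor_name: dtype_string} from a safetensors file's JSON header."""
    with open(path, "rb") as f:
        # The file begins with an unsigned 64-bit little-endian header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional non-tensor entry and carries no dtype.
    return {k: v["dtype"] for k, v in header.items() if k != "__metadata__"}

# Example: list every *.wscales entry to spot F8_E4M3 vs BF16 mismatches.
# for name, dtype in safetensors_dtypes(
#         "svdq-int4_r32-qwen-image-edit-2511-lightningv2.0-4steps.safetensors").items():
#     if name.endswith("wscales"):
#         print(name, dtype)
```

This only inspects metadata, so it is safe to run on a multi-gigabyte checkpoint.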

The quantization command was:

conf_file=examples/diffusion/configs/model/qwenimage.yaml
python -m deepcompressor.app.diffusion.ptq \
        ${conf_file} \
        examples/diffusion/configs/svdquant/int4.yaml

Full log:

metadata: {'comfy_config': '{\n  "model_class": "QwenImageTransformer",\n  "model_config": {\n    "axes_dim": [\n      16,\n      56,\n      56\n    ],\n    "context_in_dim": 3584,\n    "depth": 60,\n    "disable_unet_model_creation": true,\n    "guidance_embed": false,\n    "hidden_size": 3072,\n    "image_model": "qwen",\n    "in_channels": 16,\n    "mlp_ratio": 4.0,\n    "num_heads": 24,\n    "out_channels": 16,\n    "patch_size": 2,\n    "qkv_bias": true,\n    "theta": 10000\n  }\n}', 'quantization_config': '{"method": "svdquant", "weight": {"dtype": "int4", "scale_dtype": null, "group_size": 64}, "activation": {"dtype": "int4", "scale_dtype": null, "group_size": 64}, "rank": 32}', 'config': '{\n  "_class_name": "QwenImageTransformer2DModel",\n  "_diffusers_version": "0.35.0.dev0",\n  "attention_head_dim": 128,\n  "axes_dims_rope": [\n    16,\n    56,\n    56\n  ],\n  "guidance_embeds": false,\n  "in_channels": 64,\n  "joint_attention_dim": 3584,\n  "num_attention_heads": 24,\n  "num_layers": 60,\n  "out_channels": 16,\n  "patch_size": 2\n}', 'model_class': 'NunchakuFluxTransformer2dModel'}
key time_text_embed.timestep_embedder.linear_1.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key time_text_embed.timestep_embedder.linear_1.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key time_text_embed.timestep_embedder.linear_2.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key time_text_embed.timestep_embedder.linear_2.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key txt_norm.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key img_in.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key img_in.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key txt_in.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key txt_in.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.img_mod.1.qweight is in model_state_dict, state_dict[k].dtype = torch.int32, model_state_dict[k].dtype = torch.int32
key transformer_blocks.0.img_mod.1.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.img_mod.1.wscales is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.img_mod.1.wzeros is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.norm_q.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.norm_k.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.norm_added_q.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.norm_added_k.weight is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.to_qkv.qweight is in model_state_dict, state_dict[k].dtype = torch.int8, model_state_dict[k].dtype = torch.int8
key transformer_blocks.0.attn.to_qkv.bias is in model_state_dict, state_dict[k].dtype = torch.bfloat16, model_state_dict[k].dtype = torch.bfloat16
key transformer_blocks.0.attn.to_qkv.wscales is in model_state_dict, state_dict[k].dtype = torch.float8_e4m3fn, model_state_dict[k].dtype = torch.bfloat16
Traceback (most recent call last):
  File "/root/autodl-tmp/lib/deepcompressor/test_infer.py", line 62, in <module>
    model = QwenImageModel()
            ^^^^^^^^^^^^^^^^
  File "/root/autodl-tmp/lib/deepcompressor/test_infer.py", line 44, in __init__
    self.transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(lora_file)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/lib/nunchaku/nunchaku/models/transformers/transformer_qwenimage.py", line 418, in from_pretrained
    assert t1 == t2
           ^^^^^^^^
AssertionError
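
For what it's worth: if the root cause turns out to be only a scale-dtype mismatch between what the converter emitted (float8_e4m3fn) and what the model expects (bfloat16), note that every finite float8_e4m3fn value is exactly representable in bfloat16 (e4m3 has a 3-bit mantissa and a much smaller exponent range than bfloat16), so casting the wscales tensors would be numerically lossless. A stdlib sketch of the e4m3fn decoding, useful for checking individual scale bytes by hand (a hypothetical helper, not part of nunchaku, and not a claim that a plain cast is the supported fix):

```python
import math

def decode_e4m3fn(byte):
    """Decode one float8_e4m3fn byte: 1 sign, 4 exponent (bias 7), 3 mantissa.
    The "fn" variant has no infinities; the only NaN is S.1111.111."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at -6.
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)
```

The decoded range (max ±448, smallest subnormal ±2**-9) fits exactly within bfloat16, which is why a cast cannot lose precision; whether nunchaku's int4 kernels accept the resulting scales for this model is a separate question for the maintainers.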
