Fix Params4bit attribute access for FSDP state_dict traversal#1866
TimDettmers wants to merge 2 commits into `main`
Conversation
Fix Params4bit attribute access for FSDP state_dict traversal (#1405)

PyTorch's FSDP state_dict machinery (`_get_fqns`) resolves dotted FQN paths via `getattr`. For 4-bit quantized models, state_dict keys like `weight.absmax` and `weight.quant_state.bitsandbytes__nf4` require attribute access on `Params4bit` and `QuantState` objects that previously didn't exist. Add `__getattr__` to `Params4bit` that proxies known `QuantState` attributes (including the `quant_map` → `code` alias used by `as_dict` serialization), and add `__getattr__` to `QuantState` that handles the packed `bitsandbytes__*` keys. This allows FSDP's `get_model_state_dict` with `cpu_offload=True` to traverse the full FQN namespace without `AttributeError`.

Verified with a single-GPU FSDP integration test: without the fix, it fails with `'Params4bit' object has no attribute 'absmax'`; with the fix, it successfully produces all 7 state_dict keys.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add a subprocess-based pytest test that launches a single-GPU FSDP process via `torchrun` to verify that `get_model_state_dict` with `cpu_offload=True` works for QLoRA-style models with frozen `Params4bit` base weights.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
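For context, a minimal sketch of how such a subprocess launch could be assembled. The worker script path is illustrative, and this is not the actual test code from the commit:

```python
import sys

def build_torchrun_cmd(script_path: str, nproc: int = 1) -> list[str]:
    # `torchrun` is the console entry point for `python -m torch.distributed.run`,
    # so invoking the module keeps the launch tied to the current interpreter.
    return [
        sys.executable, "-m", "torch.distributed.run",
        f"--nproc-per-node={nproc}",
        script_path,
    ]

# A subprocess-based test would pass this command to subprocess.run(...)
# and assert that the worker exited with return code 0.
cmd = build_torchrun_cmd("tests/fsdp_worker.py")  # hypothetical worker script
print(cmd)
```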
TimDettmers
left a comment
PR Review: #1866 — Fix Params4bit attribute access for FSDP state_dict traversal
[bug-fix] [test] — Adds `__getattr__` to both `Params4bit` and `QuantState` so that PyTorch's FSDP `_get_fqns()` traversal can resolve dotted FQN paths like `weight.absmax` and `weight.quant_state.bitsandbytes__nf4` via `getattr()`. This fixes #1405, which has been open since November 2024 and affects QLoRA + FSDP users who need `cpu_offload=True` for large-model checkpointing.
No blocking issues.
The root cause is correctly identified and the fix is well-targeted. FSDP's `_get_fqns()` walks state dict key paths like `weight.absmax` by calling `getattr(params4bit_obj, "absmax")`, which previously failed because `absmax` lives inside `params4bit.quant_state`, not on the parameter itself. The `__getattr__` approach is clean: it is only invoked when normal attribute lookup fails, so it does not interfere with existing attribute access, the `__dict__` round-trip used by PEFT/Accelerate, `__getstate__`/`__setstate__` serialization, or `__deepcopy__`.
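A minimal standalone sketch of the proxy pattern described here. The `_QUANT_STATE_ATTR_MAP` name and the aliases follow the PR description, but the classes below are toy stand-ins, not the actual bitsandbytes implementation:

```python
# Toy stand-ins for Params4bit/QuantState; illustrative only.

class ToyQuantState:
    def __init__(self, absmax, code, offset=None):
        self.absmax = absmax
        self.code = code
        self.offset = offset

class ToyParams4bit:
    # Class-level map from FQN attribute name to an accessor on the quant
    # state. Being class-level, it never appears in instance __dict__, so a
    # Params4bit(data, **old.__dict__) style round-trip is unaffected.
    _QUANT_STATE_ATTR_MAP = {
        "absmax": lambda qs: qs.absmax,
        "quant_map": lambda qs: qs.code,        # alias used by as_dict()
        "nested_offset": lambda qs: qs.offset,  # alias for offset
    }

    def __init__(self, quant_state=None):
        self.quant_state = quant_state

    def __getattr__(self, name):
        # Only called after normal attribute lookup has already failed.
        accessor = type(self)._QUANT_STATE_ATTR_MAP.get(name)
        if accessor is not None and self.__dict__.get("quant_state") is not None:
            try:
                return accessor(self.quant_state)
            except AttributeError:
                pass  # e.g. nested fields missing: fall through to the error below
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

p = ToyParams4bit(ToyQuantState(absmax=[1.0], code=[0.5]))
print(p.absmax)      # proxied from quant_state
print(p.quant_map)   # alias resolved to quant_state.code
```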
Key observations from the review:
- **Attribute map correctness.** The `_QUANT_STATE_ATTR_MAP` on `Params4bit` correctly covers all tensor keys produced by `QuantState.as_dict(packed=True)`: `absmax`, `quant_map` (aliased to `code`), `nested_absmax`, `nested_quant_map`, and `nested_offset` (aliased to `offset`). The `QuantState.__getattr__` correctly handles the `bitsandbytes__*` packed keys by delegating to `as_dict(packed=True)`.
- **Downstream safety.** The `__getattr__` is a Python fallback — it only fires when normal attribute lookup fails. This means: `Params4bit.__dict__` is unchanged, so the PEFT/Accelerate `Params4bit(data, **old.__dict__)` round-trip is unaffected; `isinstance` checks, string-based class name checks, and all existing attribute access patterns from Transformers, PEFT, Accelerate, TGI, and vLLM continue to work exactly as before; and `_QUANT_STATE_ATTR_MAP` is a class-level attribute, so it does not appear in instance `__dict__` and will not leak into the constructor round-trip.
- **Edge case handling.** When `quant_state` is `None` (before quantization), the proxy correctly falls through to `AttributeError`. When `state2` is `None` (non-nested mode), accessing nested attributes like `nested_absmax` correctly raises `AttributeError` via the `try/except AttributeError` guard around the lambda accessor.
- **Test quality.** The unit test `test_params4bit_quant_state_attr_access` is thorough: it parametrizes over `nf4`/`fp4`, `compress_statistics` on/off, and CPU/CUDA devices. It verifies the attribute proxy, the `quant_map` → `code` alias, packed key access on `QuantState`, full FQN traversal simulation, `hasattr` correctness, unknown-attribute error raising, and non-interference with normal `Params4bit` attributes. The FSDP integration test exercises the real `torchrun` + `get_model_state_dict(cpu_offload=True)` path.
- **Minor note on `QuantState.__getattr__` performance.** Each call to `getattr(quant_state, "bitsandbytes__nf4")` invokes `self.as_dict(packed=True)`, which allocates a dict and clones `nested_quant_map`. This is fine for FSDP state dict traversal (infrequent), but worth knowing if anyone ever calls it in a hot loop.
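The packed-key delegation (and the per-call `as_dict(packed=True)` cost noted above) can be sketched with a toy stand-in for `QuantState`; the field names below are illustrative, not the real class:

```python
# Toy stand-in for QuantState; illustrative only.

class ToyQuantState:
    def __init__(self, quant_type, packed_bytes):
        self.quant_type = quant_type   # e.g. "nf4"
        self._packed = packed_bytes

    def as_dict(self, packed=False):
        # Each call builds a fresh dict, which is why the review flags the
        # (minor) allocation cost of routing __getattr__ through it.
        d = {"quant_type": self.quant_type}
        if packed:
            d[f"bitsandbytes__{self.quant_type}"] = self._packed
        return d

    def __getattr__(self, name):
        # Serve packed "bitsandbytes__*" keys by delegating to as_dict.
        if name.startswith("bitsandbytes__"):
            packed = self.as_dict(packed=True)
            if name in packed:
                return packed[name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

qs = ToyQuantState("nf4", b"\x00\x01")
print(qs.bitsandbytes__nf4)   # resolved via as_dict(packed=True)
```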
- Security: Clear
- Downstream impact: None (additive `__getattr__` fallback — no existing attribute access patterns are affected)
- Tests: Adequate (8 parametrized unit tests + a subprocess-based FSDP integration test)
- CI: All checks pass (lint, CPU builds across platforms, CUDA builds on L40S/T4, Windows)
- Serialization: No changes to the serialization format — `as_dict()`, `from_dict()`, and `_save_to_state_dict` are all unchanged
- Cross-PR conflicts: File overlaps with PRs #1871, #1865, #1864, #1863, #1861, #1860, #1859, and #1858 on `functional.py`, `modules.py`, or `test_linear4bit.py`. All are additive changes (new methods, new tests), so merge conflicts should be trivially resolvable.
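The "additive fallback" claim rests on standard Python semantics: `__getattr__` runs only after normal attribute lookup fails. A self-contained demonstration:

```python
class Proxy:
    def __init__(self):
        self.real = "stored normally"

    def __getattr__(self, name):
        # Reached only for names that normal lookup cannot find.
        return f"fallback:{name}"

obj = Proxy()
print(obj.real)       # normal lookup wins; __getattr__ never runs
print(obj.missing)    # only missing names hit the fallback
print(obj.__dict__)   # instance __dict__ is untouched by the fallback
```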
Summary
Fixes #1405
- Add `__getattr__` to `Params4bit` that proxies known `QuantState` attributes (including the `quant_map` → `code` alias used by `as_dict` serialization) so that FSDP's `_get_fqns()` can resolve dotted FQN paths like `weight.absmax` and `weight.quant_map`
- Add `__getattr__` to `QuantState` that handles the packed `bitsandbytes__*` keys (e.g. `bitsandbytes__nf4`) so that the FQN path `weight.quant_state.bitsandbytes__nf4` also resolves

Context
PyTorch's FSDP state_dict machinery (`get_model_state_dict` with `cpu_offload=True`) calls `_get_fqns()`, which resolves dotted key paths via `getattr()`. For 4-bit quantized models, `Linear4bit._save_to_state_dict` creates keys like `weight.absmax`, `weight.quant_map`, and `weight.quant_state.bitsandbytes__nf4`. The `_get_fqns` traversal calls `getattr(params4bit_obj, "absmax")`, which failed because `absmax` lives inside `quant_state`, not directly on the parameter.

Verified with a single-GPU FSDP integration test
Without the fix: fails with `'Params4bit' object has no attribute 'absmax'`
With the fix: successfully produces all 7 state_dict keys
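The dotted-FQN traversal that `_get_fqns()` performs can be approximated with plain `getattr` over a toy object graph (the helper below is a sketch, not PyTorch's implementation):

```python
from types import SimpleNamespace

def resolve_fqn(root, fqn):
    """Walk a dotted FQN like "weight.quant_state.bitsandbytes__nf4" via getattr."""
    obj = root
    for segment in fqn.split("."):
        # An AttributeError on any segment is exactly the failure this PR fixes.
        obj = getattr(obj, segment)
    return obj

# Toy module tree mirroring the state_dict key structure from the PR.
quant_state = SimpleNamespace(bitsandbytes__nf4=b"packed")
weight = SimpleNamespace(absmax=[0.25], quant_state=quant_state)
module = SimpleNamespace(weight=weight)

print(resolve_fqn(module, "weight.absmax"))
print(resolve_fqn(module, "weight.quant_state.bitsandbytes__nf4"))
```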
Test plan
- All existing tests in `test_linear4bit.py` pass (zero regressions)
- New `test_params4bit_quant_state_attr_access` parametrized tests pass (nf4/fp4 × compress_statistics × cpu/cuda)
- FSDP integration test with `get_model_state_dict(cpu_offload=True)` succeeds

🤖 Generated with Claude Code