Fix Params4bit attribute access for FSDP state_dict traversal #1866

Open
TimDettmers wants to merge 2 commits into main from fix/issue-1405

Conversation

@TimDettmers
Collaborator

Summary

Fixes #1405

  • Add __getattr__ to Params4bit that proxies known QuantState attributes (including the quant_map→code alias used by as_dict serialization) so that FSDP's _get_fqns() can resolve dotted FQN paths like weight.absmax and weight.quant_map
  • Add __getattr__ to QuantState that handles the packed bitsandbytes__* keys (e.g. bitsandbytes__nf4) so that the FQN path weight.quant_state.bitsandbytes__nf4 also resolves
  • Add unit tests covering attribute proxy behavior and simulated FQN traversal for all state_dict keys
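A minimal sketch of the proxy the first bullet describes, using stand-in classes rather than the real bitsandbytes types (which carry tensor data and many more fields); the map name _QUANT_STATE_ATTR_MAP mirrors the PR, but the accessor set shown here is abbreviated:

```python
class QuantState:
    """Stand-in for bitsandbytes.functional.QuantState."""
    def __init__(self, absmax, code, offset=None):
        self.absmax = absmax
        self.code = code
        self.offset = offset

class Params4bit:
    # Class-level map: state_dict key suffix -> accessor on quant_state.
    # Being class-level, it never appears in instance __dict__, so the
    # Params4bit(data, **old.__dict__) round-trip is unaffected.
    _QUANT_STATE_ATTR_MAP = {
        "absmax": lambda qs: qs.absmax,
        "quant_map": lambda qs: qs.code,        # quant_map -> code alias
        "nested_offset": lambda qs: qs.offset,
    }

    def __init__(self, quant_state=None):
        self.quant_state = quant_state

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so existing
        # attribute access patterns are untouched.
        accessor = type(self)._QUANT_STATE_ATTR_MAP.get(name)
        qs = self.__dict__.get("quant_state")
        if accessor is not None and qs is not None:
            try:
                return accessor(qs)
            except AttributeError:
                pass  # e.g. nested attrs when state2 is absent
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

p = Params4bit(QuantState(absmax=[1.0], code=[0.5]))
print(p.absmax)      # proxied from quant_state -> [1.0]
print(p.quant_map)   # alias for quant_state.code -> [0.5]
```

When quant_state is None (before quantization), the proxy falls through to a normal AttributeError, so hasattr() stays truthful.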

Context

PyTorch's FSDP state_dict machinery (get_model_state_dict with cpu_offload=True) calls _get_fqns(), which resolves dotted key paths via getattr(). For 4-bit quantized models, Linear4bit._save_to_state_dict creates keys like weight.absmax, weight.quant_map, and weight.quant_state.bitsandbytes__nf4. The _get_fqns traversal calls getattr(params4bit_obj, "absmax"), which previously failed because absmax lives inside quant_state, not directly on the parameter.
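A minimal illustration (not PyTorch's actual _get_fqns implementation) of how a dotted state_dict key is resolved via repeated getattr, and why "weight.absmax" fails when absmax lives on quant_state; all class names here are illustrative stubs:

```python
class QuantStateStub:
    def __init__(self):
        self.absmax = [0.25]

class Params4bitWithoutProxy:       # stand-in for the pre-fix parameter
    def __init__(self):
        self.quant_state = QuantStateStub()

class Linear4bitStub:
    def __init__(self):
        self.weight = Params4bitWithoutProxy()

def resolve_fqn(root, dotted_key):
    """Walk a dotted key one attribute at a time, like _get_fqns does."""
    obj = root
    for part in dotted_key.split("."):
        obj = getattr(obj, part)    # AttributeError aborts the traversal
    return obj

m = Linear4bitStub()
print(resolve_fqn(m, "weight.quant_state.absmax"))  # [0.25]
try:
    resolve_fqn(m, "weight.absmax")  # absmax is not on the parameter itself
except AttributeError as e:
    print("FAILED:", e)
```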

Verified with single-GPU FSDP integration test

Without fix:

FAILED: 'Params4bit' object has no attribute 'absmax'

With fix:

SUCCESS: got state_dict with 7 keys: ['base.weight', 'base.weight.absmax', 'base.weight.quant_map', 'base.weight.nested_absmax', 'base.weight.nested_quant_map', 'base.weight.quant_state.bitsandbytes__nf4', 'adapter.weight']

Test plan

  • All 482 existing tests in test_linear4bit.py pass (zero regressions)
  • 8 new test_params4bit_quant_state_attr_access parametrized tests pass (nf4/fp4 × compress_statistics × cpu/cuda)
  • Single-GPU FSDP integration test with get_model_state_dict(cpu_offload=True) succeeds
  • CI full test suite
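A hedged sketch of the simulated-FQN-traversal check described in the test plan; the real tests parametrize over nf4/fp4, compress_statistics, and cpu/cuda against actual bitsandbytes objects, whereas this uses a self-contained stub:

```python
class FakeQuantState:
    absmax = [0.5]
    code = [0.1]

class FakeParams4bit:
    quant_state = FakeQuantState()

    def __getattr__(self, name):
        # Abbreviated version of the proxy added by this PR.
        alias = {"absmax": "absmax", "quant_map": "code"}
        if name in alias:
            return getattr(self.quant_state, alias[name])
        raise AttributeError(name)

def simulate_fqn_traversal(param, suffixes):
    # Mirrors what _get_fqns does for each "weight.<suffix>" key.
    return {s: getattr(param, s) for s in suffixes}

resolved = simulate_fqn_traversal(FakeParams4bit(), ["absmax", "quant_map"])
print(resolved)  # {'absmax': [0.5], 'quant_map': [0.1]}
```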

🤖 Generated with Claude Code

…sal (#1405)

PyTorch's FSDP state_dict machinery (_get_fqns) resolves dotted FQN paths
via getattr. For 4-bit quantized models, state_dict keys like
"weight.absmax" and "weight.quant_state.bitsandbytes__nf4" require attribute
access on Params4bit and QuantState objects that previously didn't exist.

Add __getattr__ to Params4bit that proxies known QuantState attributes
(including the quant_map→code alias used by as_dict serialization), and
add __getattr__ to QuantState that handles the packed "bitsandbytes__*"
keys. This allows FSDP's get_model_state_dict with cpu_offload=True to
traverse the full FQN namespace without AttributeError.

Verified with single-GPU FSDP integration test: without fix, fails with
"'Params4bit' object has no attribute 'absmax'"; with fix, successfully
produces all 7 state_dict keys.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@github-actions

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Add a subprocess-based pytest test that launches a single-GPU FSDP
process via torchrun to verify get_model_state_dict with cpu_offload=True
works for QLoRA-style models with frozen Params4bit base weights.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Collaborator Author

@TimDettmers left a comment


PR Review: #1866 — Fix Params4bit attribute access for FSDP state_dict traversal

[bug-fix] [test] — Adds __getattr__ to both Params4bit and QuantState so that PyTorch's FSDP _get_fqns() traversal can resolve dotted FQN paths like weight.absmax and weight.quant_state.bitsandbytes__nf4 via getattr(). This fixes #1405, which has been open since November 2024 and affects QLoRA + FSDP users who need cpu_offload=True for large-model checkpointing.

No blocking issues.

The root cause is correctly identified and the fix is well-targeted. FSDP's _get_fqns() walks state dict key paths like weight.absmax by calling getattr(params4bit_obj, "absmax"), which previously failed because absmax lives inside params4bit.quant_state, not on the parameter itself. The __getattr__ approach is clean: it is only invoked when normal attribute lookup fails, so it does not interfere with existing attribute access, the __dict__ round-trip used by PEFT/Accelerate, __getstate__/__setstate__ serialization, or __deepcopy__.

Key observations from the review:

  1. Attribute map correctness. The _QUANT_STATE_ATTR_MAP on Params4bit correctly covers all tensor keys produced by QuantState.as_dict(packed=True): absmax, quant_map (aliased to code), nested_absmax, nested_quant_map, and nested_offset (aliased to offset). The QuantState __getattr__ correctly handles the bitsandbytes__* packed keys by delegating to as_dict(packed=True).

  2. Downstream safety. The __getattr__ is a Python fallback — it only fires when normal attribute lookup fails. This means:

    • Params4bit.__dict__ is unchanged, so the PEFT/Accelerate Params4bit(data, **old.__dict__) round-trip is unaffected.
    • isinstance checks, string-based class name checks, and all existing attribute access patterns from Transformers, PEFT, Accelerate, TGI, and vLLM continue to work exactly as before.
    • _QUANT_STATE_ATTR_MAP is a class-level attribute, so it does not appear in instance __dict__ and will not leak into the constructor round-trip.
  3. Edge case handling. When quant_state is None (before quantization), the proxy correctly falls through to AttributeError. When state2 is None (non-nested mode), accessing nested attributes like nested_absmax correctly raises AttributeError via the try/except AttributeError guard around the lambda accessor.

  4. Test quality. The unit test test_params4bit_quant_state_attr_access is thorough: it parametrizes over nf4/fp4, compress_statistics on/off, and CPU/CUDA devices. It verifies the attribute proxy, the quant_map→code alias, packed key access on QuantState, full FQN traversal simulation, hasattr correctness, unknown-attribute error raising, and non-interference with normal Params4bit attributes. The FSDP integration test exercises the real torchrun + get_model_state_dict(cpu_offload=True) path.

  5. Minor note on QuantState.__getattr__ performance. Each call to getattr(quant_state, "bitsandbytes__nf4") invokes self.as_dict(packed=True), which allocates a dict and clones nested_quant_map. This is fine for FSDP state_dict traversal (infrequent), but worth knowing if anyone ever calls it in a hot loop.

  • Security: Clear
  • Downstream impact: None (additive __getattr__ fallback — no existing attribute access patterns are affected)
  • Tests: Adequate (8 parametrized unit tests + subprocess-based FSDP integration test)
  • CI: All checks pass (lint, CPU builds across platforms, CUDA builds on L40S/T4, Windows)
  • Serialization: No changes to serialization format — as_dict(), from_dict(), _save_to_state_dict all unchanged
  • Cross-PR conflicts: File overlaps with PRs #1871, #1865, #1864, #1863, #1861, #1860, #1859, #1858 on functional.py, modules.py, or test_linear4bit.py. All are additive changes (new methods, new tests) so merge conflicts should be trivially resolvable.
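The QuantState side of the fix (observations 1 and 5) can be sketched as follows; the class body and as_dict behavior here are illustrative stubs, not the exact bitsandbytes implementation, with a bytes blob standing in for the packed tensor payload:

```python
class QuantState:
    def __init__(self, quant_type="nf4", packed_blob=b"\x00"):
        self.quant_type = quant_type
        self._packed_blob = packed_blob

    def as_dict(self, packed=False):
        """Simplified stand-in: the real as_dict packs quantization tensors."""
        d = {"quant_type": self.quant_type}
        if packed:
            d[f"bitsandbytes__{self.quant_type}"] = self._packed_blob
        return d

    def __getattr__(self, name):
        # Serve the packed "bitsandbytes__*" state_dict keys by delegating
        # to as_dict(packed=True). Note each call rebuilds the packed dict,
        # which is the performance caveat flagged in observation 5.
        if name.startswith("bitsandbytes__"):
            packed = self.as_dict(packed=True)
            if name in packed:
                return packed[name]
        raise AttributeError(name)

qs = QuantState()
print(qs.bitsandbytes__nf4)  # b'\x00'
```

An unknown packed key (e.g. bitsandbytes__fp4 on an nf4 state) still raises AttributeError, so hasattr and _get_fqns error reporting behave normally.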



Development

Successfully merging this pull request may close these issues.

Error when saving FSDP weights with cpu_offload=True [rank1]: AttributeError: 'Params4bit' object has no attribute 'absmax'
