12 changes: 12 additions & 0 deletions bitsandbytes/functional.py
@@ -487,6 +487,18 @@ def __init__(
self.state2 = state2
self.nested = state2 is not None

def __getattr__(self, name):
# Support attribute access for packed state_dict keys like "bitsandbytes__nf4".
# PyTorch's FSDP state_dict traversal (_get_fqns) resolves dotted FQN paths via
# getattr. The packed key "quant_state.bitsandbytes__nf4" causes it to call
# getattr(quant_state_obj, "bitsandbytes__nf4"), which we handle here.
if name.startswith("bitsandbytes__"):
qs_dict = self.as_dict(packed=True)
packed_key = "quant_state." + name
if packed_key in qs_dict:
return qs_dict[packed_key]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __getitem__(self, idx):
"""
ensures compatibility with older quant state scheme with nested lists.
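The comment in the hunk above describes the traversal this enables: FSDP's _get_fqns() walks a dotted state_dict key one segment at a time, calling getattr on the result of the previous segment. A minimal sketch of that access pattern, assuming an already-quantized layer; resolve_fqn is an illustrative helper, not part of this PR or of PyTorch's API:

def resolve_fqn(root, dotted_key):
    # Resolve a dotted state_dict key segment by segment via getattr,
    # mirroring the traversal pattern described in the hunk above.
    obj = root
    for part in dotted_key.split("."):
        obj = getattr(obj, part)
    return obj

# With QuantState.__getattr__ in place, the packed key no longer raises:
# resolve_fqn(layer.weight, "quant_state.bitsandbytes__nf4")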
37 changes: 37 additions & 0 deletions bitsandbytes/nn/modules.py
@@ -256,6 +256,43 @@ def __setstate__(self, state):
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]

# Map from state_dict key names (as produced by QuantState.as_dict) to
# the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
# dotted FQN keys via getattr, so "weight.quant_map" becomes
# getattr(weight, "quant_map") — we must map that to quant_state.code.
_QUANT_STATE_ATTR_MAP = {
# Direct QuantState attributes
"absmax": lambda qs: qs.absmax,
"code": lambda qs: qs.code,
"blocksize": lambda qs: qs.blocksize,
"dtype": lambda qs: qs.dtype,
"shape": lambda qs: qs.shape,
"offset": lambda qs: qs.offset,
"state2": lambda qs: qs.state2,
# as_dict serializes code → "quant_map"
"quant_map": lambda qs: qs.code,
"quant_type": lambda qs: qs.quant_type,
# as_dict serializes nested state2 attributes under "nested_*" keys
"nested_absmax": lambda qs: qs.state2.absmax,
"nested_blocksize": lambda qs: qs.state2.blocksize,
"nested_quant_map": lambda qs: qs.state2.code,
"nested_dtype": lambda qs: qs.state2.dtype,
"nested_offset": lambda qs: qs.offset,
}

def __getattr__(self, name):
# Proxy known QuantState attributes so that PyTorch's FSDP state_dict
# machinery (which traverses FQN paths via getattr) can find them.
accessor = self._QUANT_STATE_ATTR_MAP.get(name)
if accessor is not None:
quant_state = self.__dict__.get("quant_state")
if quant_state is not None:
try:
return accessor(quant_state)
except AttributeError:
pass
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
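To see which flat key names the _QUANT_STATE_ATTR_MAP above has to cover, the packed dict can be inspected directly. A short sketch, assuming a CUDA device and bitsandbytes installed; the layer here is a throwaway example, and the exact key set depends on compress_statistics and quant_type:

import torch
import bitsandbytes as bnb

# Params4bit quantizes when the module is moved to the GPU.
layer = bnb.nn.Linear4bit(64, 64, bias=False, compress_statistics=True, quant_type="nf4").cuda()
w = layer.weight

# These names appear as state_dict keys under "...weight.":
for key in w.quant_state.as_dict(packed=True):
    print(key)  # e.g. "absmax", "quant_map", "nested_absmax", "quant_state.bitsandbytes__nf4"

# With the proxy above, each flat name is reachable with plain getattr:
assert torch.equal(w.quant_map, w.quant_state.code)
assert torch.equal(w.absmax, w.quant_state.absmax)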
80 changes: 80 additions & 0 deletions tests/fsdp_state_dict_save.py
@@ -0,0 +1,80 @@
"""FSDP state_dict save integration test for 4-bit quantized models (#1405).

This script must be launched via torchrun (not directly):
torchrun --nproc_per_node=1 tests/fsdp_state_dict_save.py

It wraps a QLoRA-style model (frozen 4-bit base + trainable adapter) in FSDP
and calls get_model_state_dict with cpu_offload=True, which exercises the
_get_fqns() getattr traversal that previously crashed with:
AttributeError: 'Params4bit' object has no attribute 'absmax'
"""

import sys

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.nn as nn

import bitsandbytes as bnb


class SimpleQLoRAModel(nn.Module):
"""Minimal model with a frozen 4-bit base layer and a trainable adapter."""

def __init__(self, quant_type="nf4"):
super().__init__()
self.base = bnb.nn.Linear4bit(64, 64, bias=False, quant_type=quant_type)
self.adapter = nn.Linear(64, 64, bias=False)

def forward(self, x):
return self.base(x) + self.adapter(x)


def main():
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

errors = []

for quant_type in ("nf4", "fp4"):
model = SimpleQLoRAModel(quant_type=quant_type)
model = model.to("cuda")

# Freeze quantized base weights (as in real QLoRA)
for p in model.base.parameters():
p.requires_grad = False

# Tell FSDP to ignore the frozen quantized params (can't flatten int dtypes)
ignored = list(model.base.parameters())
fsdp_model = FSDP(model, device_id=rank, ignored_states=ignored, use_orig_params=True)

options = StateDictOptions(full_state_dict=True, cpu_offload=True)
try:
state_dict = get_model_state_dict(fsdp_model, options=options)

# Verify expected keys are present
expected_substrings = ["base.weight", "absmax", "quant_map", "adapter.weight"]
for substr in expected_substrings:
if not any(substr in k for k in state_dict.keys()):
errors.append(f"{quant_type}: missing key containing '{substr}' in {list(state_dict.keys())}")

print(f"{quant_type}: SUCCESS ({len(state_dict)} keys)", flush=True)
except Exception as e:
errors.append(f"{quant_type}: {type(e).__name__}: {e}")
print(f"{quant_type}: FAILED: {e}", flush=True)

dist.destroy_process_group()

if errors:
print("\nFAILURES:\n" + "\n".join(errors), file=sys.stderr, flush=True)
sys.exit(1)
else:
print("\nAll FSDP state_dict tests passed.", flush=True)
sys.exit(0)


if __name__ == "__main__":
main()
95 changes: 95 additions & 0 deletions tests/test_linear4bit.py
@@ -1,7 +1,9 @@
import copy
import os
import pathlib
import pickle
import platform
import subprocess
import sys
from tempfile import TemporaryDirectory

@@ -431,3 +433,96 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
grad_compiled = x.grad.clone()

torch.testing.assert_close(grad_compiled, grad_ref)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_quant_state_attr_access(device, quant_type, compress_statistics):
"""Test that Params4bit proxies QuantState attributes for FSDP state_dict traversal (#1405).

PyTorch's FSDP state_dict machinery traverses FQN paths like
'model.layers.0.weight.absmax' using getattr(). This test verifies
that Params4bit and QuantState expose the attributes that appear as
state_dict keys so that _get_fqns() traversal succeeds.
"""
if device == "hpu" and not is_supported_on_hpu(quant_type):
pytest.skip("This configuration is not supported on HPU.")

layer = bnb.nn.Linear4bit(
64,
64,
bias=False,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
layer = layer.to(device)
w = layer.weight

assert w.quant_state is not None, "quant_state should be set after quantization"

# Direct QuantState attributes proxied through Params4bit
assert torch.equal(w.absmax, w.quant_state.absmax)
assert torch.equal(w.code, w.quant_state.code)

# "quant_map" is how as_dict() serializes "code" — FSDP uses this key name
assert torch.equal(w.quant_map, w.quant_state.code)

# QuantState packed key: as_dict(packed=True) produces "quant_state.bitsandbytes__<type>"
# FSDP resolves this as getattr(quant_state_obj, "bitsandbytes__<type>")
packed_attr = f"bitsandbytes__{quant_type}"
assert hasattr(w.quant_state, packed_attr)
packed_val = getattr(w.quant_state, packed_attr)
assert isinstance(packed_val, torch.Tensor)

# Simulate the full FSDP _get_fqns traversal for all state_dict keys
state_dict_keys = list(w.quant_state.as_dict(packed=True).keys())
for key in state_dict_keys:
# Each key is relative to "weight.", e.g. "absmax" or "quant_state.bitsandbytes__nf4"
parts = key.split(".")
obj = w
for part in parts:
obj = getattr(obj, part)
assert obj is not None

# hasattr should return True for proxied attrs, False for unknown ones
assert hasattr(w, "absmax")
assert hasattr(w, "code")
assert hasattr(w, "quant_map")
assert not hasattr(w, "nonexistent_attribute")

# Unknown attributes must still raise AttributeError
with pytest.raises(AttributeError, match="nonexistent_attribute"):
_ = w.nonexistent_attribute

# Verify that normal Params4bit attributes are unaffected by __getattr__
assert isinstance(w.quant_state, bnb.functional.QuantState)
assert isinstance(w.bnb_quantized, bool)
assert w.bnb_quantized is True


@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
not torch.distributed.is_nccl_available(),
reason="FSDP test requires NCCL backend",
)
def test_fsdp_state_dict_save_4bit():
"""Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

Launches a single-GPU FSDP process via torchrun to exercise the real
_get_fqns() code path that previously crashed with:
AttributeError: 'Params4bit' object has no attribute 'absmax'
"""
script = pathlib.Path(__file__).with_name("fsdp_state_dict_save.py")
result = subprocess.run(
["torchrun", "--nproc_per_node=1", str(script)],
capture_output=True,
text=True,
timeout=120,
)
if result.returncode != 0:
pytest.fail(
f"FSDP state_dict test failed (exit {result.returncode}):\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)