Description
System Info
PR #1866 added __getattr__ to Params4bit for FSDP state_dict support. This works for FSDP, but it causes graph breaks under torch.compile.
Params4bit is a torch.Tensor subclass. PyTorch's Dynamo (the compiler frontend) doesn't know how to trace tensor subclasses that define __getattr__, so it creates graph breaks whenever it encounters attribute access on such objects. With activation checkpointing, these graph breaks multiply across layers, resulting in many more subgraphs than necessary and significant compilation overhead.
We noticed this when running torch.compile on a QLoRA fine-tuning workload (LLaMA 70B, HuggingFace, activation checkpointing). With __getattr__ present, we saw significant performance degradation caused by graph breaks. Removing __getattr__ (or replacing it with @property descriptors, as proposed below) restores the expected performance.
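To illustrate the lookup difference described above, here is a stdlib-only analogy. It does not run Dynamo and does not use bitsandbytes; `WithGetattr` and `WithProperty` are hypothetical stand-ins. `inspect.getattr_static` resolves attributes without executing user code, so it finds a property declared on the class but cannot see an attribute that only exists because `__getattr__` computes it at runtime.

```python
import inspect
from types import SimpleNamespace


class WithGetattr:
    """Hypothetical stand-in for the __getattr__ approach."""

    def __init__(self, quant_state):
        self.quant_state = quant_state

    def __getattr__(self, name):
        # Only called after normal lookup (instance dict, class, descriptors) fails.
        if name == "absmax":
            return self.quant_state.absmax
        raise AttributeError(name)


class WithProperty:
    """Hypothetical stand-in for the @property approach."""

    def __init__(self, quant_state):
        self.quant_state = quant_state

    @property
    def absmax(self):
        return self.quant_state.absmax


qs = SimpleNamespace(absmax=[0.5, 1.0])
a, b = WithGetattr(qs), WithProperty(qs)

# Ordinary attribute access returns the same value for both.
print(a.absmax == b.absmax)  # True

# Static lookup never calls __getattr__, so the dynamic attribute is invisible...
print(inspect.getattr_static(a, "absmax", None))  # None

# ...while the property object is found on the class itself.
print(isinstance(inspect.getattr_static(b, "absmax"), property))  # True
```

The analogy only shows where each attribute lives; how Dynamo itself handles these cases is the claim made in this issue, not something this snippet verifies.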
Reproduction
I put together a minimal repro using torch._dynamo.explain() on a small model with Linear4bit and activation checkpointing. With __getattr__ present, it reported graph breaks; after removing __getattr__, they were gone. The script is rough, so I'd appreciate it if you could verify this with a more representative test, or suggest one if you have something better suited.
Expected behavior
Replace __getattr__ and _QUANT_STATE_ATTR_MAP with @property descriptors. Properties are resolved at the class level through Python's descriptor protocol, so Dynamo handles them without graph breaks. FSDP still works because getattr(weight, "absmax") resolves to the same value either way. Example:
```python
import torch


class Params4bit(torch.nn.Parameter):
    # ... existing Params4bit methods, with __getattr__ and
    # _QUANT_STATE_ATTR_MAP removed ...

    # Each property reads quant_state directly from the instance __dict__ and
    # raises AttributeError when it is absent, as getattr()/hasattr() expect.

    @property
    def absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.absmax
        raise AttributeError("'Params4bit' object has no attribute 'absmax'")

    @property
    def code(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.code
        raise AttributeError("'Params4bit' object has no attribute 'code'")

    @property
    def quant_map(self):
        # quant_map is an alias for quant_state.code.
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.code
        raise AttributeError("'Params4bit' object has no attribute 'quant_map'")

    @property
    def offset(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.offset
        raise AttributeError("'Params4bit' object has no attribute 'offset'")

    @property
    def state2(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.state2
        raise AttributeError("'Params4bit' object has no attribute 'state2'")

    @property
    def nested_absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.absmax
        raise AttributeError("'Params4bit' object has no attribute 'nested_absmax'")

    @property
    def nested_blocksize(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.blocksize
        raise AttributeError("'Params4bit' object has no attribute 'nested_blocksize'")

    @property
    def nested_quant_map(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.code
        raise AttributeError("'Params4bit' object has no attribute 'nested_quant_map'")

    @property
    def nested_dtype(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.dtype
        raise AttributeError("'Params4bit' object has no attribute 'nested_dtype'")

    @property
    def nested_offset(self):
        # Note: reads the outer quant_state.offset, not state2.
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.offset
        raise AttributeError("'Params4bit' object has no attribute 'nested_offset'")
```
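If the ten hand-written properties feel repetitive, the same name-to-getter table could generate them so that every attribute is still a real class-level property. The sketch below is a hypothetical variant, not bitsandbytes code: `_quant_state_property`, `_ATTRS`, and `FakeParams4bit` are illustrative names, the quant state is faked with `SimpleNamespace`, and `FakeParams4bit` is a plain object rather than a tensor subclass. It checks only that FSDP-style `getattr`/`hasattr` calls behave as described; whether Dynamo treats these closure-based getters exactly like hand-written ones is untested here.

```python
from types import SimpleNamespace


def _quant_state_property(name, getter, needs_state2=False):
    """Hypothetical helper: build a read-only property that forwards to quant_state."""

    def fget(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and (not needs_state2 or qs.state2 is not None):
            return getter(qs)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    return property(fget, doc=f"Forwarded from quant_state ({name!r}).")


# name -> (getter, requires state2); mirrors the hand-written properties,
# including nested_offset reading the outer offset without a state2 check.
_ATTRS = {
    "absmax": (lambda qs: qs.absmax, False),
    "code": (lambda qs: qs.code, False),
    "quant_map": (lambda qs: qs.code, False),
    "offset": (lambda qs: qs.offset, False),
    "state2": (lambda qs: qs.state2, False),
    "nested_absmax": (lambda qs: qs.state2.absmax, True),
    "nested_blocksize": (lambda qs: qs.state2.blocksize, True),
    "nested_quant_map": (lambda qs: qs.state2.code, True),
    "nested_dtype": (lambda qs: qs.state2.dtype, True),
    "nested_offset": (lambda qs: qs.offset, False),
}


class FakeParams4bit:
    """Hypothetical plain-object stand-in for Params4bit."""

    def __init__(self, quant_state=None):
        if quant_state is not None:
            self.quant_state = quant_state


# Attach one property per table entry to the class (not to instances).
for _name, (_getter, _nested) in _ATTRS.items():
    setattr(FakeParams4bit, _name, _quant_state_property(_name, _getter, _nested))


state2 = SimpleNamespace(absmax="a2", blocksize=256, code="c2", dtype="float32")
w = FakeParams4bit(SimpleNamespace(absmax="a", code="c", offset=0.1, state2=state2))
flat = FakeParams4bit(SimpleNamespace(absmax="a", code="c", offset=None, state2=None))

print(getattr(w, "quant_map"))  # c
print(getattr(w, "nested_blocksize"))  # 256
print(hasattr(FakeParams4bit(), "absmax"))  # False
print(hasattr(flat, "nested_absmax"))  # False
```

The trade-off is readability: the explicit properties are easier to grep and review, while the table keeps the attribute list in one place.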