
Commit

get rid of tmp linear and use module instead
mobicham committed Mar 14, 2024
1 parent c98b8e2 commit b7ced31
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions src/transformers/quantizers/quantizer_hqq.py
@@ -21,17 +21,14 @@
 from ..modeling_utils import PreTrainedModel
 
 from ..integrations import prepare_for_hqq_linear
-from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging
+from ..utils import is_hqq_available, is_torch_available, logging
 from ..utils.hqq_utils import find_parent
 from .quantizers_utils import get_module_from_name
 
 
 if is_torch_available():
     import torch
 
-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
 if is_hqq_available():
     from hqq.core.quantize import HQQLinear
 else:
@@ -116,23 +113,19 @@ def create_quantized_param(
 
         # Step 1: Check if the state_dict of the module already contains quantized parameters
         if ("W_q" in module_state_dict) and ("meta" in module_state_dict):
-            module = HQQLinear(
+            module_hqq = HQQLinear(
                 linear_layer=None, quant_config=None, compute_dtype=self.torch_dtype, device=target_device
             )
-            module.load_state_dict(module_state_dict)
+            module_hqq.load_state_dict(module_state_dict)
+            setattr(parent_module, node, module_hqq)
             return
 
-        # Step 2: Create tmp linear layer on meta then feed the dictionary.
-        with init_empty_weights():
-            tmp_linear_layer = torch.nn.Linear(
-                in_features=module.in_features, out_features=module.out_features, bias=module.bias
-            )
-
+        # Step 2: populate module with weight/bias from module state dict
         for key in module_state_dict:
-            setattr(tmp_linear_layer, key, torch.nn.Parameter(module_state_dict[key]))
+            setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
 
         """
-        Step 3: Replace tmp_linear_layer with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
+        Step 3: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         directly doesn't work.
         """
 
@@ -141,17 +134,15 @@ def create_quantized_param(
                 parent_module,
                 node,
                 HQQLinear(
-                    tmp_linear_layer,
+                    module,
                     module.quant_config,
                     compute_dtype=self.torch_dtype,
                     device=target_device,
                     del_orig=True,
                 ),
             )
         else:
-            setattr(parent_module, node, tmp_linear_layer.to(self.torch_dtype).to(target_device))
-
-        del tmp_linear_layer
+            setattr(parent_module, node, module.to(self.torch_dtype).to(target_device))
 
         torch.cuda.empty_cache()

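For reference, the substance of the change: instead of building a temporary meta-device nn.Linear with accelerate's init_empty_weights and feeding it the state dict, the quantizer now writes the weight/bias parameters onto the module that is already in the model and then swaps that module in through its parent. Below is a minimal standalone sketch of that pattern in plain torch (no hqq, no transformers); the get_parent_and_node helper and the toy Sequential model are illustrative assumptions, not the library's API: the real code resolves the parent and attribute name with get_module_from_name/find_parent and builds an HQQLinear when a quant_config is present.

# Standalone sketch (not part of the commit): populate an existing module in place,
# then replace it on its parent via setattr.
import torch
import torch.nn as nn


def get_parent_and_node(model: nn.Module, name: str):
    """Hypothetical stand-in for get_module_from_name/find_parent: return the
    parent module and the attribute name ("node") of the submodule `name`."""
    *parent_path, node = name.split(".")
    parent = model.get_submodule(".".join(parent_path)) if parent_path else model
    return parent, node


model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
module_state_dict = {"weight": torch.randn(8, 4), "bias": torch.randn(8)}

parent_module, node = get_parent_and_node(model, "0")
module = getattr(parent_module, node)

# Populate the existing module directly from the state dict
# (this replaces the old temporary meta-device nn.Linear).
for key, value in module_state_dict.items():
    setattr(module, key, nn.Parameter(value))

# Swap the populated/converted module in via setattr on the parent; rebinding the
# local variable `module` would not update the reference held by the model.
setattr(parent_module, node, module.to(torch.float16))

print(type(model[0]), model[0].weight.dtype)  # nn.Linear, torch.float16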
