Fix gptq quantization for models without bias
B-201 committed Apr 24, 2024
1 parent 56aabbe commit fe96b7c
Showing 1 changed file with 3 additions and 2 deletions.
optimum/gptq/quantizer.py (3 additions, 2 deletions)
@@ -278,19 +278,20 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
                 elif isinstance(layer, Conv1D):
                     in_features = layer.weight.shape[0]
                     out_features = layer.weight.shape[1]
+                bias = True if layer.bias else False
                 if not (self.desc_act) or self.group_size == -1:
                     new_layer = QuantLinear(
                         self.bits,
                         self.group_size,
                         in_features,
                         out_features,
-                        True,
+                        bias,
                         use_cuda_fp16=self.use_cuda_fp16,
                         weight_dtype=layer.weight.dtype,
                     )
                 else:
                     new_layer = QuantLinear(
-                        self.bits, self.group_size, in_features, out_features, True, weight_dtype=layer.weight.dtype
+                        self.bits, self.group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
                     )
                 new_layer.device = device
                 setattr(module, attr, new_layer.to(device))
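For context, the change stops hard-coding True for the bias argument of QuantLinear and instead derives the flag from the layer being replaced, so models whose projection layers carry no bias are not given one after quantization. The sketch below is a hypothetical helper illustrating that derivation (the function name is made up for illustration); it uses layer.bias is not None, a slightly more defensive check than the committed one-liner, since truth-testing a multi-element tensor raises a RuntimeError in PyTorch.

import torch.nn as nn
from transformers.pytorch_utils import Conv1D


def infer_quantlinear_args(layer):
    """Hypothetical helper: derive (in_features, out_features, bias) for the
    QuantLinear that will replace `layer`, mirroring the logic this commit touches."""
    if isinstance(layer, nn.Linear):
        in_features, out_features = layer.in_features, layer.out_features
    elif isinstance(layer, Conv1D):
        # transformers' Conv1D stores its weight as (in_features, out_features)
        in_features, out_features = layer.weight.shape[0], layer.weight.shape[1]
    else:
        raise TypeError(f"unsupported layer type: {type(layer).__name__}")
    # `layer.bias is not None` avoids calling bool() on a multi-element tensor
    bias = layer.bias is not None
    return in_features, out_features, bias


# Example: a bias-free projection such as those this commit fixes
print(infer_quantlinear_args(nn.Linear(4096, 4096, bias=False)))  # (4096, 4096, False)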
