WeightOnlyLinear keeps self.weight after recover (#1539)
Signed-off-by: xin3he <xin3.he@intel.com>
xin3he committed Jan 16, 2024
1 parent f1def17 commit 2835bdb
Showing 1 changed file with 11 additions and 9 deletions.
neural_compressor/adaptor/torch_utils/model_wrapper.py (20 changes: 11 additions & 9 deletions)
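In short: the optimum-format transposes of scales, qweight, and qzeros move out of __init__ and into pack(), and forward() now runs the slow recover() path only when self.weight has not been cached yet, instead of recovering on every call and caching the result only under DEBUG logging.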
```diff
@@ -259,23 +259,20 @@ def __init__(
                     dtype=self.float_type,
                 ).to(device),
             )
-            self.scales = self.scales.T
             self.register_buffer(
                 "qweight",
                 torch.zeros(
                     (math.ceil(in_features / self.n_pack), out_features),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qweight = self.qweight.T
             self.register_buffer(
                 "qzeros",
                 torch.zeros(
                     (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
                     dtype=self.compression_dtype,
                 ).to(device),
             )
-            self.qzeros = self.qzeros.T
             self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
             self.compression_dtype = compression_dtype
```
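For orientation, here is a minimal, self-contained sketch of the kind of buffer layout this hunk touches. The class name, sizes, and the bits parameter are hypothetical, assuming 4-bit quantization so that n_pack = 32 // 4 = 8 values fit in one int32 word; register_buffer keeps the packed tensors in the module's state dict without making them trainable:

```python
import math
import torch

class PackedLinearSketch(torch.nn.Module):
    """Hypothetical, simplified analogue of the buffers registered above."""

    def __init__(self, in_features, out_features, groupsize=32, bits=4):
        super().__init__()
        self.n_pack = 32 // bits  # values packed per int32 word
        # Packed weight: each int32 row packs n_pack input channels.
        self.register_buffer(
            "qweight",
            torch.zeros(math.ceil(in_features / self.n_pack), out_features, dtype=torch.int32),
        )
        # One scale per (group, output channel).
        self.register_buffer(
            "scales",
            torch.zeros(math.ceil(in_features / groupsize), out_features, dtype=torch.float32),
        )

m = PackedLinearSketch(4096, 4096)
print(m.qweight.shape, m.scales.shape)  # torch.Size([512, 4096]) torch.Size([128, 4096])
```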
```diff
@@ -329,6 +326,10 @@ def __init__(
             self.g_idx = None
 
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
+        if self.use_optimum_format:
+            self.scales = self.scales.T
+            self.qweight = self.qweight.T
+            self.qzeros = self.qzeros.T
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
```
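The transposes now happen once, at pack time, rather than at construction. This works because assigning a tensor to an attribute name that is already a registered buffer simply updates that buffer, and .T is a view, so the reassignment only flips the layout the rest of the code expects. A small illustration of that PyTorch behavior (not the library's code):

```python
import torch

class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("qweight", torch.zeros(512, 4096, dtype=torch.int32))

    def pack(self):
        # Reassigning a registered buffer keeps it registered; .T is a view.
        self.qweight = self.qweight.T

h = Holder()
h.pack()
print(h.qweight.shape)                        # torch.Size([4096, 512])
print("qweight" in dict(h.named_buffers()))   # True: still a buffer after reassignment
```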
```diff
@@ -468,12 +469,13 @@ def recover(self):
         return fp32_weight
 
     def forward(self, input):
-        weight = self.recover()
-        device = self.scales.device
-        if weight.dtype == torch.float16 and device.type == "cpu":
-            weight = weight.float()
-            self.bias = self.bias.float() if self.bias is not None else None
-        if level == DEBUG:
+        if not hasattr(self, "weight"):
+            weight = self.recover()
+            device = self.scales.device
+            if weight.dtype == torch.float16 and device.type == "cpu":
+                weight = weight.float()
+                self.bias = self.bias.float() if self.bias is not None else None
+        if True:  # keep reusing self.weight due to recover is too slow.
             if not hasattr(self, "weight"):
                 self.weight = weight
             input = input.type(self.weight.dtype)
```
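The effect of the forward() change, condensed: recover() (unpack plus dequantize, the slow step) now runs only on the first call, and the result is cached on the module as self.weight. A hedged sketch of the new control flow, with F.linear standing in for the actual matmul, since the real body continues past the hunk:

```python
import torch
import torch.nn.functional as F

def forward(self, input):
    if not hasattr(self, "weight"):
        weight = self.recover()  # slow path: unpack and dequantize, first call only
        if weight.dtype == torch.float16 and weight.device.type == "cpu":
            weight = weight.float()  # fp16 matmul support on CPU is limited
            self.bias = self.bias.float() if self.bias is not None else None
        self.weight = weight  # cache so later calls skip recover()
    input = input.type(self.weight.dtype)
    return F.linear(input, self.weight, self.bias)  # assumed final op, not shown in the hunk
```

The trade-off implied by the unconditional `if True:` cache: the recovered full-precision weight stays resident alongside the packed buffers, spending memory to avoid repeated recover() calls.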

0 comments on commit 2835bdb

Please sign in to comment.