
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 11, 2024
1 parent 94429b8 commit 4235c72
Showing 6 changed files with 164 additions and 114 deletions.
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/mx_quant/__init__.py
@@ -14,4 +14,4 @@
 
 # pylint:disable=import-error
 
-from .mx import mx_quantize
\ No newline at end of file
+from .mx import mx_quantize
24 changes: 13 additions & 11 deletions neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -18,14 +18,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .utils import quantize_elemwise_op, quantize_mx_op
 from typing import Dict, Tuple
-from neural_compressor.torch.utils import register_algo, set_module
+
+from neural_compressor.common.logger import Logger
 from neural_compressor.common.utility import MX_QUANT
 from neural_compressor.torch.quantization.config import MXQuantConfig
-from neural_compressor.common.logger import Logger
+from neural_compressor.torch.utils import register_algo, set_module
+
+from .utils import quantize_elemwise_op, quantize_mx_op
+
 logger = Logger().get_logger()
 
+
 class MXLinearFunction(Function):
     @staticmethod
     def forward(ctx, input, weight, bias=None, mx_specs=None):
@@ -54,6 +58,7 @@ def forward(ctx, input, weight, bias=None, mx_specs=None):
 
         return output
 
+
 class MXLinear(torch.nn.Linear):
     def __init__(
         self,
@@ -72,14 +77,10 @@ def __init__(
     def apply_mx_specs(self):
         if self.mx_specs is not None:
             if self.mx_specs.get("out_dtype", "float32") != "float32":
-                self.weight.data = quantize_elemwise_op(
-                    self.weight.data, mx_specs=self.mx_specs
-                )
+                self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs)
 
                 if self.bias is not None:
-                    self.bias.data = quantize_elemwise_op(
-                        self.bias.data, mx_specs=self.mx_specs
-                    )
+                    self.bias.data = quantize_elemwise_op(self.bias.data, mx_specs=self.mx_specs)
 
         # MX quantize everything along input size
         self.weight.data = quantize_mx_op(
@@ -96,9 +97,10 @@ def append_name(self, postfix):
     def forward(self, inputs):
         if self.mx_none:
             return super().forward(inputs)
 
         return MXLinearFunction.apply(inputs, self.weight, self.bias, self.mx_specs)
 
+
 def mx_quantize(
     model,
     config={},
@@ -150,4 +152,4 @@ def mx_quantize(
            return new_module
        else:
            set_module(model, name, new_module)
-    return model
\ No newline at end of file
+    return model
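
For orientation, the entry point touched above replaces eligible torch.nn.Linear modules in a model with the MXLinear wrapper defined in this file. Below is a minimal usage sketch, not part of this commit: the import path follows the re-export in __init__.py above, while the toy Sequential model and the empty default config are illustrative assumptions (the real config schema is not visible in this diff).

import torch

from neural_compressor.torch.algorithms.mx_quant import mx_quantize

# A toy model; any module containing torch.nn.Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# Swap matching Linear modules for MX-quantized equivalents in place.
# config={} mirrors the default in the signature shown in the diff;
# in practice a populated config would select which modules to convert.
model = mx_quantize(model, config={})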
