From 4235c72c45a4e28664c8ce4f458b0997b62ddb3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Apr 2024 06:41:31 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/mx_quant/__init__.py | 2 +- .../torch/algorithms/mx_quant/mx.py | 24 ++- .../torch/algorithms/mx_quant/utils.py | 197 +++++++++--------- .../torch/quantization/algorithm_entry.py | 17 +- .../torch/quantization/config.py | 36 +++- test/3x/torch/quantization/test_mx_quant.py | 2 +- 6 files changed, 164 insertions(+), 114 deletions(-) diff --git a/neural_compressor/torch/algorithms/mx_quant/__init__.py b/neural_compressor/torch/algorithms/mx_quant/__init__.py index 78a84ad773a..cf224afdb87 100644 --- a/neural_compressor/torch/algorithms/mx_quant/__init__.py +++ b/neural_compressor/torch/algorithms/mx_quant/__init__.py @@ -14,4 +14,4 @@ # pylint:disable=import-error -from .mx import mx_quantize \ No newline at end of file +from .mx import mx_quantize diff --git a/neural_compressor/torch/algorithms/mx_quant/mx.py b/neural_compressor/torch/algorithms/mx_quant/mx.py index 34019575a5e..60c509fb066 100644 --- a/neural_compressor/torch/algorithms/mx_quant/mx.py +++ b/neural_compressor/torch/algorithms/mx_quant/mx.py @@ -18,14 +18,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import quantize_elemwise_op, quantize_mx_op from typing import Dict, Tuple -from neural_compressor.torch.utils import register_algo, set_module + +from neural_compressor.common.logger import Logger from neural_compressor.common.utility import MX_QUANT from neural_compressor.torch.quantization.config import MXQuantConfig -from neural_compressor.common.logger import Logger +from neural_compressor.torch.utils import register_algo, set_module + +from .utils import quantize_elemwise_op, quantize_mx_op logger = Logger().get_logger() + + class MXLinearFunction(Function): @staticmethod def forward(ctx, input, weight, bias=None, mx_specs=None): @@ -54,6 +58,7 @@ def forward(ctx, input, weight, bias=None, mx_specs=None): return output + class MXLinear(torch.nn.Linear): def __init__( self, @@ -72,14 +77,10 @@ def __init__( def apply_mx_specs(self): if self.mx_specs is not None: if self.mx_specs.get("out_dtype", "float32") != "float32": - self.weight.data = quantize_elemwise_op( - self.weight.data, mx_specs=self.mx_specs - ) + self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs) if self.bias is not None: - self.bias.data = quantize_elemwise_op( - self.bias.data, mx_specs=self.mx_specs - ) + self.bias.data = quantize_elemwise_op(self.bias.data, mx_specs=self.mx_specs) # MX quantize everything along input size self.weight.data = quantize_mx_op( @@ -96,9 +97,10 @@ def append_name(self, postfix): def forward(self, inputs): if self.mx_none: return super().forward(inputs) - + return MXLinearFunction.apply(inputs, self.weight, self.bias, self.mx_specs) + def mx_quantize( model, config={}, @@ -150,4 +152,4 @@ def mx_quantize( return new_module else: set_module(model, name, new_module) - return model \ No newline at end of file + return model diff --git a/neural_compressor/torch/algorithms/mx_quant/utils.py b/neural_compressor/torch/algorithms/mx_quant/utils.py index c5c1e3db625..aef7f100f13 100644 --- a/neural_compressor/torch/algorithms/mx_quant/utils.py +++ b/neural_compressor/torch/algorithms/mx_quant/utils.py @@ 
-20,11 +20,13 @@ from enum import Enum, IntEnum + import torch FP32_EXPONENT_BIAS = 127 FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) + class ElemFormat(Enum): int8 = 1 int4 = 2 @@ -42,50 +44,50 @@ class ElemFormat(Enum): @staticmethod def from_str(s): - assert(s != None), "String elem_format == None" + assert s != None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): return getattr(ElemFormat, s) else: raise Exception("Undefined elem format", s) - + @staticmethod def is_bf(s): if isinstance(s, str): - assert(s != None), "String elem_format == None" + assert s != None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): return getattr(ElemFormat, s).value == 10 elif isinstance(s, int): return s == 10 - + raise Exception("Undefined elem format", s) @staticmethod def is_fp(s): if isinstance(s, str): - assert(s != None), "String elem_format == None" + assert s != None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): return 4 <= getattr(ElemFormat, s).value <= 9 elif isinstance(s, int): return 4 <= s <= 9 - - raise Exception("Undefined elem format", s) + raise Exception("Undefined elem format", s) @staticmethod def is_int(s): if isinstance(s, str): - assert(s != None), "String elem_format == None" + assert s != None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): return 1 <= getattr(ElemFormat, s).value <= 3 elif isinstance(s, int): return 1 <= s <= 3 - + raise Exception("Undefined elem format", s) + class RoundingMode(IntEnum): nearest = 0 floor = 1 @@ -95,33 +97,38 @@ class RoundingMode(IntEnum): def string_enums(): return [s.name for s in list(RoundingMode)] + def _get_min_norm(ebits): - """ Valid for all float formats """ + """Valid for all float formats.""" emin = 2 - (2 ** (ebits - 1)) - return 0 if ebits == 0 else 2 ** emin + return 0 if ebits == 0 else 2**emin + def _get_max_norm(ebits, mbits): - """ Valid only for floats that define NaN """ - assert(ebits >= 5), "invalid for floats that don't define NaN" - emax = 0 if ebits==0 else 2**(ebits - 1) - 1 - return 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2) + """Valid only for floats that define NaN.""" + assert ebits >= 5, "invalid for floats that don't define NaN" + emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1 + return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) + _FORMAT_CACHE = {} + + def _get_format_params(fmt): - """ Allowed formats: - - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation - - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf - - bfloatX/bfX: 9 <= X <= 32 - - fp4, no NaN/Inf - - fp6_e3m2/e2m3, no NaN/Inf - - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior - - Returns: - ebits: exponent bits - mbits: mantissa bits: includes sign and implicit bits - emax: max normal exponent - max_norm: max normal number - min_norm: min normal number + """Allowed formats: + - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation + - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf + - bfloatX/bfX: 9 <= X <= 32 + - fp4, no NaN/Inf + - fp6_e3m2/e2m3, no NaN/Inf + - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior + + Returns: + ebits: exponent bits + mbits: mantissa bits: includes sign and implicit bits + emax: max normal exponent + max_norm: max normal number + min_norm: min normal number """ if type(fmt) is str: fmt = ElemFormat.from_str(fmt) @@ -140,30 +147,30 @@ def _get_format_params(fmt): emax = 0 elif fmt == ElemFormat.fp8_e5m2: ebits, mbits = 5, 4 - emax = 
2**(ebits - 1) - 1 + emax = 2 ** (ebits - 1) - 1 elif fmt == ElemFormat.fp8_e4m3: ebits, mbits = 4, 5 - emax = 2**(ebits - 1) + emax = 2 ** (ebits - 1) elif fmt == ElemFormat.fp6_e3m2: ebits, mbits = 3, 4 - emax = 2**(ebits - 1) + emax = 2 ** (ebits - 1) elif fmt == ElemFormat.fp6_e2m3: ebits, mbits = 2, 5 - emax = 2**(ebits - 1) + emax = 2 ** (ebits - 1) elif fmt == ElemFormat.fp4: ebits, mbits = 2, 3 - emax = 2**(ebits - 1) + emax = 2 ** (ebits - 1) elif fmt == ElemFormat.float16: ebits, mbits = 5, 12 - emax = 2**(ebits - 1) - 1 + emax = 2 ** (ebits - 1) - 1 elif fmt == ElemFormat.bfloat16: ebits, mbits = 8, 9 - emax = 2**(ebits - 1) - 1 + emax = 2 ** (ebits - 1) - 1 else: raise Exception("Unknown element format %s" % fmt) - + if fmt != ElemFormat.fp8_e4m3: - max_norm = 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2) + max_norm = 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) else: max_norm = 2**emax * 1.75 # FP8 has custom max_norm @@ -173,19 +180,22 @@ def _get_format_params(fmt): return ebits, mbits, emax, max_norm, min_norm + # Never explicitly compute 2**(-exp) since subnorm numbers have # exponents smaller than -126 def _safe_lshift(x, bits, exp): if exp is None: return x * (2**bits) else: - return x / (2 ** exp) * (2**bits) + return x / (2**exp) * (2**bits) + def _safe_rshift(x, bits, exp): if exp is None: return x / (2**bits) else: - return x / (2**bits) * (2 ** exp) + return x / (2**bits) * (2**exp) + def _round_mantissa(A, bits, round, clamp=False): """ @@ -220,9 +230,10 @@ def _round_mantissa(A, bits, round, clamp=False): A = torch.clamp(A, -max_mantissa, max_mantissa) return A + def _shared_exponents(A, method="max", axes=None, ebits=0): - """ - Get shared exponents for the passed matrix A. + """Get shared exponents for the passed matrix A. + Args: A {PyTorch tensor} -- Input tensor method {str} -- Exponent selection method. 
@@ -247,16 +258,12 @@ def _shared_exponents(A, method="max", axes=None, ebits=0): raise Exception("Unrecognized shared exponent selection method %s" % (method)) # log2(shared_exp) and truncate to integer - shared_exp = torch.floor( - torch.log2( - shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) - ) - ) + shared_exp = torch.floor(torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))) # Restrict to [-emax, emax] range if ebits > 0: - emax = 2**(ebits-1) - 1 - #shared_exp = torch.clamp(shared_exp, -emax, emax) + emax = 2 ** (ebits - 1) - 1 + # shared_exp = torch.clamp(shared_exp, -emax, emax) # Overflow to Inf shared_exp[shared_exp > emax] = float("NaN") # Underflows are set to -127 which causes them to be @@ -265,12 +272,10 @@ def _shared_exponents(A, method="max", axes=None, ebits=0): return shared_exp + def _reshape_to_blocks(A, axes, block_size): if axes is None: - raise Exception( - "axes required in order to determine which " - "dimension to apply block size to" - ) + raise Exception("axes required in order to determine which " "dimension to apply block size to") if block_size == 0: raise Exception("block_size == 0 in _reshape_to_blocks") @@ -326,6 +331,7 @@ def _reshape(shape, reshape_block_size): A = A.view(reshape) return A, axes, orig_shape, padded_shape + def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes): # Undo tile reshaping A = A.view(padded_shape) @@ -338,9 +344,9 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes): A = torch.squeeze(A, dim=axis + 1) return A -def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', - saturate_normals=False, allow_denorm=True): - """ Core function used for element-wise quantization + +def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True): + """Core function used for element-wise quantization Arguments: A {PyTorch tensor} -- A tensor to be quantized bits {int} -- Number of mantissa bits. Includes @@ -364,11 +370,10 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', out = A if exp_bits != 0: - private_exp = torch.floor(torch.log2( - torch.abs(A) + (A == 0).type(A.dtype))) + private_exp = torch.floor(torch.log2(torch.abs(A) + (A == 0).type(A.dtype))) # The minimum representable exponent for 8 exp bits is -126 - min_exp = -(2**(exp_bits-1)) + 2 + min_exp = -(2 ** (exp_bits - 1)) + 2 private_exp = private_exp.clip(min=min_exp) else: private_exp = None @@ -385,8 +390,7 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', if saturate_normals or exp_bits == 0: out = torch.clamp(out, min=-max_norm, max=max_norm) else: - out = torch.where((torch.abs(out) > max_norm), - torch.sign(out) * float("Inf"), out) + out = torch.where((torch.abs(out) > max_norm), torch.sign(out) * float("Inf"), out) # handle Inf/NaN out[A == float("Inf")] = float("Inf") @@ -395,9 +399,11 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', return out -def _quantize_fp(A, exp_bits=None, mantissa_bits=None, - round='nearest', allow_denorm=True): - """ Quantize values to IEEE fpX format. The format defines NaN/Inf + +def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True): + """Quantize values to IEEE fpX format. + + The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16. 
Arguments: exp_bits {int} -- number of bits used to store exponent @@ -409,16 +415,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, if exp_bits is None or mantissa_bits is None: return A - max_norm = _get_max_norm(exp_bits, mantissa_bits+2) + max_norm = _get_max_norm(exp_bits, mantissa_bits + 2) output = _quantize_elemwise_core( - A, bits=mantissa_bits + 2, exp_bits=exp_bits, - max_norm=max_norm, round=round, allow_denorm=allow_denorm) + A, bits=mantissa_bits + 2, exp_bits=exp_bits, max_norm=max_norm, round=round, allow_denorm=allow_denorm + ) return output -def _quantize_bfloat(A, bfloat, round='nearest', allow_denorm=True): - """ Quantize values to bfloatX format + +def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True): + """Quantize values to bfloatX format Arguments: bfloat {int} -- Total number of bits for bfloatX format, Includes 1 sign, 8 exp bits, and variable @@ -428,11 +435,12 @@ def _quantize_bfloat(A, bfloat, round='nearest', allow_denorm=True): if bfloat == 0 or bfloat == 32: return A - max_norm = _get_max_norm(8, bfloat-7) + max_norm = _get_max_norm(8, bfloat - 7) return _quantize_elemwise_core( - A, bits=bfloat-7, exp_bits=8, max_norm=max_norm, round=round, - allow_denorm=allow_denorm) + A, bits=bfloat - 7, exp_bits=8, max_norm=max_norm, round=round, allow_denorm=allow_denorm + ) + def quantize_elemwise_op(A, mx_specs): """A function used for element-wise quantization with mx_specs @@ -450,32 +458,30 @@ def quantize_elemwise_op(A, mx_specs): elem_format = ElemFormat.from_str(out_dtype) ebits, mbits, _, _, _ = _get_format_params(elem_format) if ElemFormat.is_bf(out_dtype): - A = _quantize_bfloat(A, bfloat=ebits+mbits-1, round=round, - allow_denorm=True) + A = _quantize_bfloat(A, bfloat=ebits + mbits - 1, round=round, allow_denorm=True) elif ElemFormat.is_fp(out_dtype): - A = _quantize_fp(A, exp_bits=5, mantissa_bits=ebits+mbits-1 - 6, - round=round, allow_denorm=True) + A = _quantize_fp(A, exp_bits=5, mantissa_bits=ebits + mbits - 1 - 6, round=round, allow_denorm=True) else: raise ValueError("Cannot set {} for output dtype.".format(out_dtype)) return A + def _quantize_mx( A, scale_bits, - elem_format, # can be None for no quantization + elem_format, # can be None for no quantization shared_exp_method="max", axes=None, block_size=32, round="nearest", flush_fp32_subnorms=False, ): - """Function used for MX* quantization - """ + """Function used for MX* quantization.""" # Shortcut for no quantization if elem_format == None: return A - assert(scale_bits > 0) + assert scale_bits > 0 # Make sure axes is a list of non-negative numbers axes = [axes] if type(axes) == int else axes @@ -484,9 +490,7 @@ def _quantize_mx( ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format) # Perform tiling to the hardware vector size - A, axes, orig_shape, padded_shape = _reshape_to_blocks( - A, axes, block_size - ) + A, axes, orig_shape, padded_shape = _reshape_to_blocks(A, axes, block_size) #################### # Quantize @@ -495,7 +499,10 @@ def _quantize_mx( # Get shared exponents shared_exp = _shared_exponents( - A, method=shared_exp_method, axes=shared_exp_axes, ebits=0, + A, + method=shared_exp_method, + axes=shared_exp_axes, + ebits=0, ) # Flush subnormal FP32 inputs to zero @@ -506,15 +513,13 @@ def _quantize_mx( # in the element data format shared_exp = shared_exp - emax - scale_emax = 2**(scale_bits-1) - 1 + scale_emax = 2 ** (scale_bits - 1) - 1 shared_exp[shared_exp > scale_emax] = float("NaN") shared_exp[shared_exp < -scale_emax] = -scale_emax A = A / 
(2**shared_exp) - A = _quantize_elemwise_core( - A, mbits, ebits, max_norm, round=round, - allow_denorm=True, saturate_normals=True) + A = _quantize_elemwise_core(A, mbits, ebits, max_norm, round=round, allow_denorm=True, saturate_normals=True) A = A * (2**shared_exp) @@ -523,6 +528,7 @@ def _quantize_mx( return A + def quantize_mx_op( A, elem_format: str, @@ -538,9 +544,12 @@ def quantize_mx_op( elem_format = ElemFormat.from_str(elem_format) return _quantize_mx( - A, scale_bits, - elem_format, block_size=block_size, - axes=axes, round=round, - shared_exp_method="max", - flush_fp32_subnorms=False) - + A, + scale_bits, + elem_format, + block_size=block_size, + axes=axes, + round=round, + shared_exp_method="max", + flush_fp32_subnorms=False, + ) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index c9a8aadf806..1a80b8a327b 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -18,18 +18,29 @@ import torch -from neural_compressor.common.utils import AUTOROUND, AWQ, FP8_QUANT, GPTQ, HQQ, RTN, SMOOTH_QUANT, STATIC_QUANT, TEQ, MX_QUANT +from neural_compressor.common.utils import ( + AUTOROUND, + AWQ, + FP8_QUANT, + GPTQ, + HQQ, + MX_QUANT, + RTN, + SMOOTH_QUANT, + STATIC_QUANT, + TEQ, +) from neural_compressor.torch.quantization import ( AutoRoundConfig, AWQConfig, FP8Config, GPTQConfig, HQQConfig, + MXQuantConfig, RTNConfig, SmoothQuantConfig, StaticQuantConfig, TEQConfig, - MXQuantConfig, ) from neural_compressor.torch.utils import logger, register_algo @@ -455,4 +466,4 @@ def mx_quant_entry( } model = mx_quantize(model, weight_config=weight_config) - return model \ No newline at end of file + return model diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 34b4643e66f..6532a9deed9 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -34,12 +34,12 @@ FP8_QUANT, GPTQ, HQQ, + MX_QUANT, OP_NAME_OR_MODULE_TYPE, RTN, SMOOTH_QUANT, STATIC_QUANT, TEQ, - MX_QUANT, ) from neural_compressor.torch.utils import is_hpex_available, logger from neural_compressor.torch.utils.constants import ( @@ -813,8 +813,32 @@ def __init__( def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] linear_mx_config = MXQuantConfig( - w_dtype=["int8", "int4", "int2", "fp8_e5m2", "fp8_e4m3", "fp6_e3m2", "fp6_e2m3", "fp4", "float16", "bfloat16", "float32"], - act_dtype=["int8", "int4", "int2", "fp8_e5m2", "fp8_e4m3", "fp6_e3m2", "fp6_e2m3", "fp4", "float16", "bfloat16", "float32"], + w_dtype=[ + "int8", + "int4", + "int2", + "fp8_e5m2", + "fp8_e4m3", + "fp6_e3m2", + "fp6_e2m3", + "fp4", + "float16", + "bfloat16", + "float32", + ], + act_dtype=[ + "int8", + "int4", + "int2", + "fp8_e5m2", + "fp8_e4m3", + "fp6_e3m2", + "fp6_e2m3", + "fp4", + "float16", + "bfloat16", + "float32", + ], out_dtype=["bfloat16", "float16", "float32"], blocksize=[2, 4, 8, 16, 32, 64, 128, 256, 512], round_method=["nearest", "dither", "floor", "even"], @@ -826,7 +850,10 @@ def register_supported_configs(cls) -> List[OperatorConfig]: @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: - white_list = (torch.nn.Linear, torch.nn.functional.linear,) + white_list = ( + torch.nn.Linear, + torch.nn.functional.linear, + ) filter_result = [] for op_name, module in model.named_modules(): @@ -840,6 +867,7 @@ def 
get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: def get_config_set_for_tuning(cls) -> Union[None, "MXQuantConfig", List["MXQuantConfig"]]: return MXQuantConfig(weight_only=[False, True]) + def get_default_mx_config() -> MXQuantConfig: """Generate the default mx config. diff --git a/test/3x/torch/quantization/test_mx_quant.py b/test/3x/torch/quantization/test_mx_quant.py index ff92a9d4387..9122e371235 100644 --- a/test/3x/torch/quantization/test_mx_quant.py +++ b/test/3x/torch/quantization/test_mx_quant.py @@ -77,4 +77,4 @@ def forward(self, x): output1 = fp32_model(example_inputs) output2 = q_model(example_inputs) # set a big atol to avoid random issue - assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check." \ No newline at end of file + assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
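
---

The hunks above only reformat `_quantize_mx` and its helpers; the underlying behavior — scale each block of `block_size` elements by a shared power-of-two exponent, round element-wise, then undo the scaling — is unchanged. For readers reviewing the patch, the standalone sketch below illustrates that idea in plain PyTorch for an int8-like element format. It is a simplified illustration only (last-axis blocking, no padding, no NaN/Inf handling, no scale-exponent clamping), not the library implementation, and the function name is made up for this example.

```python
import torch


def mx_block_quantize_sketch(A: torch.Tensor, block_size: int = 32, mbits: int = 8) -> torch.Tensor:
    """Toy MX-style block quantization along the last axis (int8-like element format).

    Assumes the last dimension is divisible by block_size; the real
    _quantize_mx pads and reshapes arbitrary axes via _reshape_to_blocks.
    """
    orig_shape = A.shape
    blocks = A.reshape(*orig_shape[:-1], -1, block_size)

    # Shared exponent per block: floor(log2(max |x|)), as in _shared_exponents(method="max").
    block_max = blocks.abs().amax(dim=-1, keepdim=True)
    shared_exp = torch.floor(torch.log2(block_max + (block_max == 0).to(blocks.dtype)))
    scale = 2.0**shared_exp

    # Element grid for an int-like format with `mbits` mantissa bits (sign included):
    # step 2**-(mbits - 2), max magnitude (2**(mbits - 1) - 1) / 2**(mbits - 2), cf. _get_max_norm.
    step = 2.0 ** (mbits - 2)
    max_norm = (2 ** (mbits - 1) - 1) / step

    # Scale into the block, round to the grid, saturate, then undo the scaling.
    q = torch.clamp(torch.round(blocks / scale * step) / step, -max_norm, max_norm)
    return (q * scale).reshape(orig_shape)


# Example: the quantization error on a random weight tensor stays small relative to its scale.
w = torch.randn(4, 64)
print((w - mx_block_quantize_sketch(w)).abs().max())
```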