@@ -15,7 +15,10 @@
 from ..quantized_func_wrapper import QuantizedFuncWrapperBase, OP_TYPE, QuantizedFuncWrapperFactory
 
 import torch
-import torchao
+from torchao.quantization.quant_primitives import (
+    _quantize_affine_float8,
+    _dequantize_affine_float8,
+)
 
 from abc import ABCMeta
 
@@ -32,7 +35,7 @@ def __init__(self, scale_format, is_dynamic=False):
 class QuantizedCPUQuant(QuantizedCPUFuncWrapperBase):
 
     def get_default_quantized_func(self):
-        return torch.ops.torchao.quantize_affine_float8
+        return _quantize_affine_float8
 
     def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
         return self._quantized_func_(tensor=input, scale=torch.tensor(scale), float8_dtype=dtype)
@@ -41,7 +44,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
 class QuantizedCPUQuantPC(QuantizedCPUFuncWrapperBase):
 
     def get_default_quantized_func(self):
-        return torch.ops.torchao.quantize_affine_float8
+        return _quantize_affine_float8
 
     def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
         return self._quantized_func_(tensor=input, scale=scale.view((-1, 1)), float8_dtype=dtype)
@@ -50,7 +53,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
 class QuantizedCPUDeQuant(QuantizedCPUFuncWrapperBase):
 
     def get_default_quantized_func(self):
-        return torch.ops.torchao.dequantize_affine_float8
+        return _dequantize_affine_float8
 
     def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
         return self._quantized_func_(tensor=input, scale=torch.tensor(scale), output_dtype=out_dtype)
@@ -59,7 +62,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
 class QuantizedCPUDeQuantPC(QuantizedCPUFuncWrapperBase):
 
     def get_default_quantized_func(self):
-        return torch.ops.torchao.dequantize_affine_float8
+        return _dequantize_affine_float8
 
     def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
         return self._quantized_func_(tensor=input, scale=scale.view((1, -1)), output_dtype=out_dtype)
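
For context only, a minimal sketch of how the primitives imported in this diff might be exercised per-tensor, assuming they can be called standalone with the same keyword arguments the wrappers above use (tensor=, scale=, float8_dtype=, output_dtype=); this is illustrative and not part of the change:

# Illustrative sketch (not part of the PR): round-trip a tensor through the
# torchao float8 primitives with the keyword arguments used in the wrappers above.
import torch
from torchao.quantization.quant_primitives import (
    _quantize_affine_float8,
    _dequantize_affine_float8,
)

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(0.05)  # per-tensor scale, as in QuantizedCPUQuant

x_fp8 = _quantize_affine_float8(tensor=x, scale=scale, float8_dtype=torch.float8_e4m3fn)
x_back = _dequantize_affine_float8(tensor=x_fp8, scale=scale, output_dtype=torch.bfloat16)
# x_back approximates x up to float8 rounding error.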