Sub-channel quantized type implementation #120172

Open · sdasgup3 wants to merge 4 commits into main from subchannel-quant-impl
Conversation

sdasgup3
Contributor

This is an implementation of the RFC: Supporting Sub-Channel Quantization in MLIR.

To make the review process easier, the PR has been divided into the following commits:

  1. Add implementation for sub-channel type: Includes the class design for UniformQuantizedSubChannelType, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
  2. Add implementation for sub-channel type: Lowering of the quant.qcast and quant.dcast operations to Linalg operations.
  3. Adding C/Python APIs: We first define the C-APIs and build the Python-APIs on top of those.
  4. Add pass to normalize generic ....: This pass normalizes sub-channel quantized types to per-tensor and per-axis types, if possible (see the sketch after this list).
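As a rough illustration of the normalization in commit 4, here is a minimal sketch of the intended effect on degenerate block sizes (the types are hypothetical and the exact pass output may differ):

```
// Every block spans its full dimension (block sizes 6 and 2 on a 6x2
// tensor), leaving a single scale/zero-point pair. This is effectively
// per-tensor quantization and can be normalized as such.
tensor<6x2x!quant.uniform<i8:f32:{0:6, 1:2}, {{2.0:10}}>>
  ==> tensor<6x2x!quant.uniform<i8:f32, 2.0:10>>

// Axis 1 uses block size 1 while axis 0 spans its full dimension, leaving
// one scale per element along axis 1. This is effectively per-axis
// quantization along axis 1.
tensor<6x2x!quant.uniform<i8:f32:{0:6, 1:1}, {{2.0:10, 3.0:20}}>>
  ==> tensor<6x2x!quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>>
```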

A design note:

  • Explicitly storing the quantized_dimensions, even when they can be derived for ranked tensors.
    While it's possible to infer the quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked
    data tensors (ref for background), there are cases where this can lead to ambiguity and issues with round-tripping.
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>

The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. This inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), but it can cause problems with round-tripping. Therefore, even for ranked tensors, we explicitly store the quantized dimensions. Suggestions welcome!
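To make the ambiguity concrete, here is a hypothetical pair of types that shape-based inference would conflate (scales and zero points abbreviated as before):

```
// Axes 0 and 1 both quantized; the block size for axis 0 is degenerate
// (2, the full dimension size).
tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>

// Only axis 1 quantized; axis 0 is left unspecified.
tensor<2x4x!quant.uniform<i8:f32:{1:2}, {{s00:z00, s01:z01}}>>
```

Both types carry a scales tensor of shape [1, 2], so inferring the quantized dimensions from that shape alone cannot distinguish them after printing. Explicitly storing the quantized dimensions keeps the two distinct across print/parse round trips.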

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.

@llvmbot
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-quant

Author: Sandeep Dasgupta (sdasgup3)


Patch is 127.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120172.diff

25 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Quant.h (+41)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantBase.td (+183-9)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td (+21-9)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h (+131)
  • (modified) mlir/include/mlir/Dialect/Quant/Transforms/Passes.td (+33)
  • (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+74)
  • (modified) mlir/lib/CAPI/Dialect/Quant.cpp (+56)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp (+1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantOps.cpp (+123-24)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+119-2)
  • (modified) mlir/lib/Dialect/Quant/IR/TypeDetail.h (+122)
  • (modified) mlir/lib/Dialect/Quant/IR/TypeParser.cpp (+278-40)
  • (modified) mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (+197-84)
  • (added) mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp (+179)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi (+21-1)
  • (modified) mlir/test/CAPI/quant.c (+126)
  • (modified) mlir/test/Dialect/Quant/Bytecode/types.mlir (+9)
  • (modified) mlir/test/Dialect/Quant/invalid.mlir (+68)
  • (modified) mlir/test/Dialect/Quant/lower-quant-ops.mlir (+64)
  • (added) mlir/test/Dialect/Quant/normalize-quant-types.mlir (+51)
  • (modified) mlir/test/Dialect/Quant/ops.mlir (+19)
  • (modified) mlir/test/Dialect/Quant/parse-uniform-invalid.mlir (+95-5)
  • (modified) mlir/test/Dialect/Quant/parse-uniform.mlir (+18)
  • (modified) mlir/test/python/dialects/quant.py (+46)
diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
index a7d98dc3c1a775..dc0989e53344ea 100644
--- a/mlir/include/mlir-c/Dialect/Quant.h
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
 MLIR_CAPI_EXPORTED bool
 mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedSubChannel.
+MLIR_CAPI_EXPORTED bool
+mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
+
+/// Creates a UniformQuantizedSubChannelType with the given parameters.
+///
+/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
+/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` each point to
+/// `blockSizeInfoLength` elements, describing respectively the quantization
+/// axes and the corresponding block sizes.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
+    intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
+    int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of block sizes provided in the type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
+
+/// Returns the quantized dimension at the given position.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+                                                        intptr_t pos);
+
+/// Returns the block size at the given position.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
+
+/// Returns the scales of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
+
+/// Returns the zero-points of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
+
 //===---------------------------------------------------------------------===//
 // CalibratedQuantizedType
 //===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 791cb9de48d058..0d97889960019c 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
     encodes the necessary information for (lossy) round-trip conversion between
     an expressed and a stored value.
 
-    The `quant.uniform` type has two variants: per-layer quantization and
-    per-channel (or per-axis) quantization. In per-layer quantization, the
-    quantization information affects an entire tensor uniformly. Conversely, in
-    per-channel quantization, the data type encodes the specific tensor axis
-    that serves as the channel and includes quantization information for each
-    individual channel within the tensor. Below are the specific syntactic and
-    semantic considerations for each modality.
+    The `quant.uniform` type has three variants: per-layer quantization,
+    per-channel (or per-axis) quantization, and sub-channel (or blockwise)
+    quantization.  In per-layer quantization, the quantization information
+    affects an entire tensor uniformly. Conversely, in per-channel
+    quantization, the data type encodes the specific tensor axis that serves
+    as the channel and includes quantization information for each individual
+    channel within the tensor. Sub-channel quantization is a generalization
+    of per-tensor and per-channel quantization, where the quantization
+    parameters are defined for blocks of elements along one or more
+    dimensions of the tensor. Below are the specific syntactic and semantic
+    considerations for each modality.
 
 
     ### Per-layer quantization
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
     ```
     // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
     // floats. Dimension 1 of the tensor acts as the channel dimension. Its
-    // size 3 matches the number of provided scale values. Tensor elemenets at
+    // size 3 matches the number of provided scale values. Tensor elements at
     // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
     // 5.0, respectively.
     tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
     tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
     ```
 
+    ### Sub-channel quantization
+
+    Sub-channel quantization, also known as blockwise quantization, provides
+    finer-grained control than per-tensor or per-channel quantization. It
+    divides a tensor into blocks of elements, each with its own quantization
+    parameters (scale and zero point). This is particularly useful when
+    different regions of a tensor exhibit distinct value ranges.
+
+    The `!quant.uniform` type represents sub-channel quantization with the
+    following syntax:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:` blockSizeInfo
+      scaleZeroTensor `>`
+
+    blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
+    axisBlock ::= axis `:` blockSize
+    scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
+    scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
+    scaleZeroList  ::= scaleZero (`,` scaleZero)*
+    scaleZero ::= scale (`:` zeroPoint)?
+    ```
+
+    The `blockSize` field specifies the size of the blocks along dimension
+    `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
+    quantization parameters for a particular block. Specifically, the tensor
+    element at position [i0...iN] uses
+    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].scale` and
+    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].zeroPoint` as scale
+    and zeroPoint respectively.
+
+    Here are some examples:
+
+    ```
+    // A 3x4 tensor of i8 values representing f32 values, quantized 
+    // along axis-0 and axis-1 with block sizes 1 and 2,
+    // respectively. As a result, the shape of the scales (or zero-points) will
+    // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
+    // blocks along each axis. Tensor elements at positions 
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    // [2][0] and [2][1] use scale `s20` and zero point `z20`,
+    // [2][2] and [2][3] use scale `s21` and zero point `z21`.
+    tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>
+
+    // A 2D dynamically sized tensor contains u16 values
+    // representing f32 values. Since the shape of the quantization
+    // parameters (i.e. scales and zero-points) is given as [2,2] and
+    // the blocks-sizes are given as [1,2], the shape of the tensor is expected
+    // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
+    ```
 
     ## Per-axis quantization integrity
 
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
     respected in any context in which the `!quant.uniform` data type is used,
     such as the header of a `func.func` op, or the input of an arithmetic
     operation.
- 
+
     - A quantized type with per-channel quantization information must be the
       element type of a tensor container type, and may not occur directly as
       the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
     // Correct. The quantized type now includes 3 scale values, matching the
     // size of dimension 1 of the result tensor.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+    ## Sub-channel quantization integrity
+
+    When type `!quant.uniform` contains sub-channel quantization information,
+    the following rules are enforced.  For efficiency, these rules are actively
+    enforced by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+
+    - A quantized type with sub-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+    // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+    // in a `tensor` type.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The tensor containing the sub-channel quantized type must be ranked.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // unranked tensor type.
+    %result = quant.qcast %input : tensor<*xf32> to
+                tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The axis for which a block size is specified should be valid for a tensor
+    of a given rank. Block sizes can be specified for a subset of axes. 
+    Any unspecified block size for an axis i defaults to the tensor dimension
+    size of that axis (shape(tensor)[i]).
+
+    ```
+    // Incorrect. The block-size is specified for axis 2 which is greater than
+    // the rank of the tensor.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Incorrect. The block-size is specified for a negative axis.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Correct. The block size for axis 1 is omitted, so it defaults to 2,
+    // the dimension size of the tensor at axis 1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+    // Correct. The block sizes for all axes are omitted, making the
+    // sub-channel type essentially a per-tensor type.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+    ```
+
+    - The block size for a particular axis must be a positive integer and must
+      not exceed the dimension size of the tensor along that axis.
+
+    ```
+    // Incorrect. The block size for axis 0 is -1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+    // Incorrect. The block size for axis 0 is 8 which is greater than the
+    // dimension size of tensor at axis 0 (which is 6).
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(tensor)[i] % blockSizes[i] = 0 for each axis i in
+      [0, 1, ..., rank(tensor)-1], where blockSizes[i] is the block size along axis i.
+
+    ```
+    // Incorrect. The block size for axis 0 is 4 and the corresponding
+    // dimension size is 6 and 6 % 4 != 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+    ```
+    // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+    // shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+    // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+    // shape(scales) equals [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
     ```
   }];
   let cppNamespace = "::mlir::quant";
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
index bd9cdf82382275..8c74dbef5d94a3 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
@@ -13,6 +13,7 @@
 #ifndef QUANT_BYTECODE
 #define QUANT_BYTECODE
 
+include "mlir/IR/BuiltinDialectBytecode.td"
 include "mlir/IR/BytecodeBase.td"
 
 def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
   }];
 }
 
+def UniformQuantizedSubChannelType
+    : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+          SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+          Array<SignedVarIntList>:$quantizedDimensions,
+          Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+          DenseElementsAttr:$zeroPoints)> {
+  // Note: builder order differs from bytecode.
+  let cBuilder = [{
+      get<$_resultType>(context, flags, storageType, expressedType, scales,
+        zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+        [](int64_t dim) { return static_cast<int32_t>(dim);})), blockSizes,
+        storageTypeMin, storageTypeMax)
+  }];
+}
+
 /// This enum contains marker codes used to indicate which attribute is
 /// currently being decoded, and how it should be decoded. The order of these
 /// codes should generally be unchanged, as any changes will inevitably break
 /// compatibility with older bytecode.
 
 def QuantDialectTypes : DialectTypes<"Quant"> {
-  let elems = [
-    ReservedOrDead,
-    AnyQuantizedType,
-    AnyQuantizedTypeWithExpressedType,
-    CalibratedQuantizedType,
-    UniformQuantizedType,
-    UniformQuantizedPerAxisType
-  ];
+  let elems = [ReservedOrDead, AnyQuantizedType,
+               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+               UniformQuantizedType, UniformQuantizedPerAxisType,
+               UniformQuantizedSubChannelType];
 }
 
-#endif // QUANT_BYTECODE
\ No newline at end of file
+#endif // QUANT_BYTECODE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 43440ba623b9c1..44062fe376ec0d 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -23,6 +23,7 @@ namespace detail {
 
 struct QuantizedTypeStorage;
 struct AnyQuantizedTypeStorage;
+struct UniformQuantizedSubChannelTypeStorage;
 struct UniformQuantizedTypeStorage;
 struct UniformQuantizedPerAxisTypeStorage;
 struct CalibratedQuantizedTypeStorage;
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// Represents sub-channel (also known as blockwise) quantization.
+///
+/// Syntax synopsis:
+///   UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+///       storageType ('<' storageMin ':' storageMax '>')? ':'
+///       expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+///   BlockSizeInfo: '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+///   AxisBlock ::= AxisSpec ':' BlockSizeSpec
+///   ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+///   ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+///   ScaleZeroList  ::= ScaleZero (',' ScaleZero)*
+///   ScaleZero ::= Scale (':' ZeroPoint)?
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   AxisSpec: An integer value
+///   BlockSizeSpec: An integer value
+///   Scale: An attribute (usually floating-point value)
+///   ZeroPoint: An attribute (usually integer value)
+class UniformQuantizedSubChannelType
+    : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+                            detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+  using Base::Base;
+  using Base::getChecked;
+
+  static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedSubChannelType
+  get(unsigned flags, Type storageType, Type expressedType,
+      DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+      ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedSubChannelType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, DenseElementsAttr scales,
+             DenseElementsAttr zeroPoints,
+             ArrayRef<int32_t> quantizedDimensions,
+             ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+             int64_t storageTypeMax);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+                   ArrayRef<int32_t> quantizedDimensions,
+                   ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
+
+  /// Gets the quantization scales. The scales are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the scales tensor
+  /// is determined by the number of blocks along the corresponding dimension in
+  /// the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+  /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The scale value for a specific element in the quantized data tensor at
+  /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+  /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+  DenseElementsAttr getScales() const;
+
+  /// Gets the quantization zero-points. The zero-points are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the zero-point
+  /// tensor is determined by the number of blocks along the corresponding
+  /// dimension in the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+  /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The zero-point value for a specific element in the quantized data tensor
+  /// at position [i0, i1, ..., iR-1] is determined by accessing the
+  /// c...
[truncated]

@sdasgup3 changed the title from "Subchannel quant impl" to "Sub-channel quantized type implementation" on Dec 17, 2024

github-actions bot commented Dec 17, 2024

✅ With the latest revision this PR passed the Python code formatter.

@sdasgup3 force-pushed the subchannel-quant-impl branch from 3a6a0a7 to 0f4147e on December 17, 2024 at 18:18