Spec the behavior of precision_config in DotGeneralOp #755

Open
burmako opened this issue Dec 13, 2022 · 6 comments
burmako commented Dec 13, 2022

See #307 (comment).

ghpvnist commented:

Here are the results of my investigation:
On the JAX side, this is what the enum values mean (a small JAX sketch follows the list):

  • DEFAULT: uses bf16 for computation, with accumulation done in at least the precision of the output type. preferred_element_type (Consider adding preferred_element_type to DotGeneralOp/ConvolutionOp #600) can override this, in which case the accumulation has more or less precision depending on whether the overriding type is larger or smaller.
  • HIGH: uses three bf16 passes, or the tf32 type (8 exponent bits, 10 mantissa bits) if available, for computation. Accumulation is done in a similar fashion.
  • HIGHEST: uses f32 or f64, as applicable, for computation. Accumulation is done in a similar fashion.
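
A minimal JAX-level sketch of how these values surface in user code (shapes and dtypes are illustrative only, not from the spec):

```python
import jax
import jax.numpy as jnp

lhs = jnp.ones((128, 256), dtype=jnp.bfloat16)
rhs = jnp.ones((256, 64), dtype=jnp.bfloat16)

# DEFAULT: fastest option; the backend may use low-precision (e.g. bf16) units.
out_default = jnp.matmul(lhs, rhs, precision=jax.lax.Precision.DEFAULT)

# HIGHEST: most precise multiplications available, with accumulation steered by
# preferred_element_type (here f32).
out_highest = jax.lax.dot_general(
    lhs, rhs,
    dimension_numbers=(((1,), (0,)), ((), ())),  # contract lhs dim 1 with rhs dim 0
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)
```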

For StableHLO, we currently do not have preferred_element_type as an (optional) attribute and rely on the result element type instead. This means that we always have a defined preferred_element_type, so accumulation would always happen using the result element type. Also, we do not have tf32, nor do we model passes that compute the multiplication by splitting the bits (as this is more hardware-specific). Therefore, I propose we use the following for computation and accumulation based on the precision_config values:

  • DEFAULT: use bf16 for computation, accumulating values using the result element type as preferred_element_type.
  • HIGH: use f32 for computation, accumulating values using the result element type as preferred_element_type.
  • HIGHEST: use f64 for computation, accumulating values using the result element type as preferred_element_type.

In other words, the multiplications and sums (intermediate results) are calculated using bf16, f32, or f64, and the resulting sum (final value) is cast to the result element type.
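
A hedged reference sketch of this proposal, using jax.numpy and covering only the plain matmul case; the dtype mapping below is the proposed one, not something already in the spec:

```python
import jax.numpy as jnp

# Proposed mapping from precision_config values to computation dtypes.
_COMPUTE_DTYPE = {
    "DEFAULT": jnp.bfloat16,
    "HIGH": jnp.float32,
    "HIGHEST": jnp.float64,  # f64 requires jax_enable_x64=True
}

def dot_general_reference(lhs, rhs, precision, result_dtype):
    compute_dtype = _COMPUTE_DTYPE[precision]
    # Intermediate multiplications and sums happen in the computation dtype...
    acc = jnp.matmul(lhs.astype(compute_dtype), rhs.astype(compute_dtype))
    # ...and only the final value is cast to the result element type.
    return acc.astype(result_dtype)
```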


ghpvnist commented Mar 24, 2023

(C1) lhs and rhs have the same element type.

Currently, (C1) does not allow lhs and rhs element types to differ. This prevents specifying distinct precision_config values for lhs and rhs, which means that having two config values provides no extra information, since the behavior can be captured with a single config value. This constraint should also be removed while this issue is addressed.

Also, we should document that when lhs and rhs have different element types, the precision of the inputs should be converted/reduced to no less than the specified value before the multiplication is applied; this is currently not documented in the spec.
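
A hedged sketch of what that conversion rule could look like (hypothetical helper; the bit-width comparison is a simplification, not spec wording):

```python
import jax.numpy as jnp

# Hypothetical minimum computation dtype per precision value.
_MIN_DTYPE = {"DEFAULT": jnp.bfloat16, "HIGH": jnp.float32, "HIGHEST": jnp.float64}

def promote_operand(x, precision):
    """Widen an operand to at least the requested precision before multiplying."""
    min_dtype = _MIN_DTYPE[precision]
    # Comparing total bit width is a simplification (e.g. it treats f16 and bf16
    # as equivalent); the spec wording would need to be more precise.
    if jnp.finfo(x.dtype).bits >= jnp.finfo(min_dtype).bits:
        return x
    return x.astype(min_dtype)
```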

dominicsymes commented:

Hi. Just to add here that, from the #1413 discussion thread, it would also be useful if the accumulator range for integer accumulation in quantized operations were specified.
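
For illustration only (not from #1413), a worked calculation of the accumulator width an si8 dot product of length k needs before overflow becomes possible:

```python
import math

def min_accumulator_bits(k, input_bits=8):
    # Worst-case |product| of two signed inputs is (2**(input_bits-1))**2,
    # and k such products are summed.
    max_abs_sum = k * (2 ** (input_bits - 1)) ** 2
    # Smallest signed width n with 2**(n-1) - 1 >= max_abs_sum.
    return math.ceil(math.log2(max_abs_sum + 1)) + 1

print(min_accumulator_bits(1024))  # 26 -> fits comfortably in an si32 accumulator
```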


zichuan-wei commented May 10, 2023

Given that the inputs and outputs of an op already carry information about the expected input and output precision, I think a separate precision config does not make much sense, especially for on-device use cases. And DEFAULT, HIGH, & HIGHEST are kind of vague in terms of what they represent.


sdasgup3 commented May 11, 2023

Thanks @zichuan-wei for the input!
IMO, one use case of precision_config is to override the precision specified by the input parameters, allowing a choice between low-precision (and fast) and high-precision (and slow) computation. I think this param is used on server GPUs and TPUs.

And DEFAULT, HIGH, & HIGHEST are kind of vague in terms of what they represent.

I agree with you. I recently learned from @reedwm how involved the semantics of these are, and I believe fixing the vagueness is a priority for this ticket.

ghpvnist added a commit that referenced this issue Jun 9, 2023
We have the following non-quantization-related constraints (excluding
C13, C15-C20) in the spec:

```
(I1) lhs tensor.
(I2) rhs tensor.
(I3) lhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I4) rhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I5) lhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I6) rhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I7) precision_config variadic number of enum of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) size(`lhs_batching_dimensions`) = size(`rhs_batching_dimensions`).
(C2) size(`lhs_contracting_dimensions`) =
size(`rhs_contracting_dimensions`).
(C3) `lhs_batching_dimensions` and `lhs_contracting_dimensions` combined are
unique.
(C4) `rhs_batching_dimensions` and `rhs_contracting_dimensions` combined are
unique.
(C5) 0 <= `lhs_batching_dimensions[i]` < rank(`lhs`) for all `i`
in [0, size(`lhs_batching_dimensions`)).
(C6) 0 <= `lhs_contracting_dimensions[i]` < rank(`lhs`) for all `i`
in [0, size(`lhs_contracting_dimensions`)).
(C7) 0 <= `rhs_batching_dimensions[i]` < rank(`rhs`) for all `i`
in [0, size(`rhs_batching_dimensions`)).
(C8) 0 <= `rhs_contracting_dimensions[i]` < rank(`rhs`) for all `i`
in [0, size(`rhs_contracting_dimensions`)).
(C9) dim(`lhs`, `lhs_batching_dimensions[i]`) =
dim(`rhs`, `rhs_batching_dimensions[i]`) for all `i` in [0,
size(`lhs_batching_dimensions`)).
(C10) dim(`lhs`, `lhs_contracting_dimensions[i]`) =
dim(`rhs`, `rhs_contracting_dimensions[i]`) for all `i` in [0,
size(`lhs_contracting_dimensions`)).
(C11) size(`precision_config`) = 2.
(C12) shape(`result`) = dim(`lhs`, `lhs_batching_dimensions`) +
dim(`lhs`, `lhs_result_dimensions`) + dim(`rhs`, `rhs_result_dimensions`).
(C14) element_type(`lhs`) = element_type(`rhs`).
```

These constraints will be comprehensively covered by the following
tests:

```
I1: a) lhs is not a tensor. (Covered by ODS).
I2: a) rhs is not a tensor. (Covered by ODS).
I3: a) lhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(lhs_batching_dimensions) != `si64`. (Covered by ODS).
I4: a) rhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(rhs_batching_dimensions) != `si64`. (Covered by ODS).
I5: a) lhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(lhs_contracting_dimensions) != `si64`. (Covered by ODS).
I6: a) rhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(rhs_contracting_dimensions) != `si64`. (Covered by ODS).
I7: a) precision_config does not have variadic number of enum of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2: a) size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3: a) lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4: a) rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5: a) lhs_batching_dimensions[i] < 0 for any i.
    b) lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6: a) lhs_contracting_dimensions[i] < 0 for any i.
    b) lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7: a) rhs_batching_dimensions[i] < 0 for any i.
    b) rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8: a) rhs_contracting_dimensions[i] < 0 for any i.
    b) rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9: a) dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10: a) dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11: a) size(precision_config) != 2.
C12: no negative test needed since it's just inferring the shape.
C14: a) element_type(lhs) != element_type(rhs).
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
C1a: size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2a: size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3a: lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4a: rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5a: lhs_batching_dimensions[i] < 0 for any i.
C5b: lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6a: lhs_contracting_dimensions[i] < 0 for any i.
C6b: lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7a: rhs_batching_dimensions[i] < 0 for any i.
C7b: rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8a: rhs_contracting_dimensions[i] < 0 for any i.
C8b: rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9a: dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10a: dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11a: size(precision_config) != 2.
C14a: element_type(lhs) != element_type(rhs).
```

Notes:
* (C14) currently does not have a test; consider removing it (#755).
* (C11) currently does not have a test due to #755 and #879.

closes #336
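
For concreteness, a hedged example (a JAX call with made-up shapes, not part of the commit) of inputs that satisfy the non-quantization constraints listed above, including two precision_config values (C11) and equal element types (C14):

```python
import jax
import jax.numpy as jnp

lhs = jnp.ones((4, 8, 16), dtype=jnp.float32)   # [batch, m, k]
rhs = jnp.ones((4, 16, 32), dtype=jnp.float32)  # [batch, k, n]

out = jax.lax.dot_general(
    lhs, rhs,
    # ((lhs_contracting, rhs_contracting), (lhs_batching, rhs_batching))
    dimension_numbers=(((2,), (1,)), ((0,), (0,))),
    precision=(jax.lax.Precision.DEFAULT, jax.lax.Precision.HIGHEST),
)
# C12: result shape = batch dims + lhs result dims + rhs result dims.
assert out.shape == (4, 8, 32)
```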

sdasgup3 commented Mar 4, 2024

AI: precision_config and quantization do not go hand in hand, and their behavior should be defined for when they co-exist (for example, convolution with precision_config).

context: #2050 (comment)
