Spec the behavior of precision_config in DotGeneralOp #755
Here are the results of my investigation:
For StableHLO, we currently do not have
In other words, the multiplication/sum (intermediate results) is calculated using
(C1) lhs and rhs have the same element type. Currently, (C1) does not allow the lhs and rhs element types to differ. This prevents specifying distinct precision config values for lhs and rhs, which means that having two config values provides no extra information: the same behavior can be captured with a single config value. This constraint should also be removed while this issue is addressed. We should also document that when lhs and rhs have different element types, the inputs are converted (precision reduced to no less than the specified value) before the multiplication is applied; this is currently not documented in the spec.
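One possible reading of "convert inputs before the multiplication" can be sketched as follows. This is a speculative NumPy illustration, assuming a hypothetical mapping from the enum to a storage dtype (`PRECISION_DTYPE`) and a hypothetical helper name (`apply_precision_config`); pinning down the real mapping is exactly what this issue asks for:

```python
import numpy as np

# Hypothetical mapping from a precision_config value to a dtype; the actual
# semantics of DEFAULT/HIGH/HIGHEST are what the spec still needs to define.
PRECISION_DTYPE = {"DEFAULT": np.float16, "HIGH": np.float32, "HIGHEST": np.float64}

def apply_precision_config(lhs, rhs, precision_config):
    """Convert each operand per its own config value, one value per operand,
    then promote both to a common type and multiply."""
    lhs_prec, rhs_prec = precision_config
    lhs = lhs.astype(PRECISION_DTYPE[lhs_prec])
    rhs = rhs.astype(PRECISION_DTYPE[rhs_prec])
    # Promote to the wider of the two operand types for the multiplication.
    common = np.promote_types(lhs.dtype, rhs.dtype)
    return lhs.astype(common) @ rhs.astype(common)

# Distinct per-operand values only matter once lhs/rhs types may differ (C14 relaxed).
out = apply_precision_config(np.ones((2, 3), np.float32),
                             np.ones((3, 2), np.float32),
                             ("HIGH", "HIGHEST"))
```

Under this reading, two config values carry real information precisely because each operand can be converted independently.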
Hi. Just to add here that, from the #1413 discussion thread, it would be useful if the accumulator range for integer accumulation in quantized operations were also specified.
Given that an op's inputs and outputs already carry information about the expected input and output precision, I think a separate precision config makes little sense, especially for on-device use cases. And DEFAULT, HIGH, and HIGHEST are vague in terms of what they represent.
Thanks @zichuan-wei for the input!
I agree with you. I recently learned from @reedwm how involved the semantics of these are, and I believe fixing the vagueness is a priority for this ticket.
We have the following non-quantization-related constraints (excluding C13, C15-C20) in the spec:

```
(I1) lhs tensor.
(I2) rhs tensor.
(I3) lhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I4) rhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I5) lhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I6) rhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I7) precision_config variadic number of enum of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) size(`lhs_batching_dimensions`) = size(`rhs_batching_dimensions`).
(C2) size(`lhs_contracting_dimensions`) = size(`rhs_contracting_dimensions`).
(C3) `lhs_batching_dimensions` and `lhs_contracting_dimensions` combined are unique.
(C4) `rhs_batching_dimensions` and `rhs_contracting_dimensions` combined are unique.
(C5) 0 <= `lhs_batching_dimensions[i]` < rank(`lhs`) for all `i` in [0, size(`lhs_batching_dimensions`)).
(C6) 0 <= `lhs_contracting_dimensions[i]` < rank(`lhs`) for all `i` in [0, size(`lhs_contracting_dimensions`)).
(C7) 0 <= `rhs_batching_dimensions[i]` < rank(`rhs`) for all `i` in [0, size(`rhs_batching_dimensions`)).
(C8) 0 <= `rhs_contracting_dimensions[i]` < rank(`rhs`) for all `i` in [0, size(`rhs_contracting_dimensions`)).
(C9) dim(`lhs`, `lhs_batching_dimensions[i]`) = dim(`rhs`, `rhs_batching_dimensions[i]`) for all `i` in [0, size(`lhs_batching_dimensions`)).
(C10) dim(`lhs`, `lhs_contracting_dimensions[i]`) = dim(`rhs`, `rhs_contracting_dimensions[i]`) for all `i` in [0, size(`lhs_contracting_dimensions`)).
(C11) size(`precision_config`) = 2.
(C12) shape(`result`) = dim(`lhs`, `lhs_batching_dimensions`) + dim(`lhs`, `lhs_result_dimensions`) + dim(`rhs`, `rhs_result_dimensions`).
(C14) element_type(`lhs`) = element_type(`rhs`).
```

These constraints will be comprehensively covered by the following tests:

```
I1: a) lhs is not a tensor. (Covered by ODS).
I2: a) rhs is not a tensor. (Covered by ODS).
I3: a) lhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS). b) element_type(lhs_batching_dimensions) != `si64`. (Covered by ODS).
I4: a) rhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS). b) element_type(rhs_batching_dimensions) != `si64`. (Covered by ODS).
I5: a) lhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS). b) element_type(lhs_contracting_dimensions) != `si64`. (Covered by ODS).
I6: a) rhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS). b) element_type(rhs_contracting_dimensions) != `si64`. (Covered by ODS).
I7: a) precision_config does not have variadic number of enum of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2: a) size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3: a) lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4: a) rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5: a) lhs_batching_dimensions[i] < 0 for any i. b) lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6: a) lhs_contracting_dimensions[i] < 0 for any i. b) lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7: a) rhs_batching_dimensions[i] < 0 for any i. b) rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8: a) rhs_contracting_dimensions[i] < 0 for any i. b) rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9: a) dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10: a) dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11: a) size(precision_config) != 2.
C12: no negative test needed since it's just inferring the shape.
C14: a) element_type(lhs) != element_type(rhs).
```

If we drop the "Covered by ODS" pieces, this will leave us with the following test cases:

```
C1a: size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2a: size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3a: lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4a: rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5a: lhs_batching_dimensions[i] < 0 for any i.
C5b: lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6a: lhs_contracting_dimensions[i] < 0 for any i.
C6b: lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7a: rhs_batching_dimensions[i] < 0 for any i.
C7b: rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8a: rhs_contracting_dimensions[i] < 0 for any i.
C8b: rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9a: dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10a: dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11a: size(precision_config) != 2.
C14a: element_type(lhs) != element_type(rhs).
```

Notes:
* (C14) currently does not have a test; consider removing it: #755
* (C11) currently does not have a test due to #755 and #879.

closes #336
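The non-ODS constraints above are all checkable from shapes and attributes alone. As a cross-check of the list, here is a minimal Python sketch (function name and argument style are my own, not StableHLO's verifier API) that enforces C1-C11 and infers the C12 result shape:

```python
def verify_dot_general(lhs_shape, rhs_shape,
                       lhs_batch, rhs_batch,
                       lhs_contract, rhs_contract,
                       precision_config):
    """Check constraints C1-C11 and return the inferred result shape (C12)."""
    assert len(lhs_batch) == len(rhs_batch)                                           # C1
    assert len(lhs_contract) == len(rhs_contract)                                     # C2
    assert len(set(lhs_batch + lhs_contract)) == len(lhs_batch) + len(lhs_contract)   # C3
    assert len(set(rhs_batch + rhs_contract)) == len(rhs_batch) + len(rhs_contract)   # C4
    assert all(0 <= d < len(lhs_shape) for d in lhs_batch)                            # C5
    assert all(0 <= d < len(lhs_shape) for d in lhs_contract)                         # C6
    assert all(0 <= d < len(rhs_shape) for d in rhs_batch)                            # C7
    assert all(0 <= d < len(rhs_shape) for d in rhs_contract)                         # C8
    assert all(lhs_shape[a] == rhs_shape[b]
               for a, b in zip(lhs_batch, rhs_batch))                                 # C9
    assert all(lhs_shape[a] == rhs_shape[b]
               for a, b in zip(lhs_contract, rhs_contract))                           # C10
    assert len(precision_config) == 2                                                 # C11
    # C12: result = batching dims + leftover (result) dims of lhs, then of rhs.
    lhs_result = [d for d in range(len(lhs_shape)) if d not in lhs_batch + lhs_contract]
    rhs_result = [d for d in range(len(rhs_shape)) if d not in rhs_batch + rhs_contract]
    return (tuple(lhs_shape[d] for d in lhs_batch)
            + tuple(lhs_shape[d] for d in lhs_result)
            + tuple(rhs_shape[d] for d in rhs_result))

# A batched matmul: batch dim 0, contracting dims 2 (lhs) and 1 (rhs).
shape = verify_dot_general((2, 3, 4), (2, 4, 5),
                           [0], [0], [2], [1],
                           ("DEFAULT", "DEFAULT"))
```

Note that C14 is deliberately omitted here, since (per the discussion above) it is a candidate for removal.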
AI: context: #2050 (comment)
See #307 (comment).