[Matrix][IR] Cap stride bitwidth at 64 #163729
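@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir
Author: Nathan Corbyn (cofibrant)
Changes: a1ef81d added overloads for `llvm.matrix.column.major.store` and `llvm.matrix.column.major.load` that allow strides to occupy an arbitrary bitwidth. CC @fhahn
Full diff: https://github.com/llvm/llvm-project/pull/163729.diff
3 Files Affected: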
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index c79a95087dbdd..6f38020cdd33b 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6479,9 +6479,14 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
NumRows->getZExtValue() * NumColumns->getZExtValue(),
"Result of a matrix operation does not fit in the returned vector!");
- if (Stride)
- Check(Stride->getZExtValue() >= NumRows->getZExtValue(),
+ if (Stride) {
+ // Stride can occupy an arbitrary bit-width, while rows and columns are
+ // always 32-bit, so zero extend to the largest common bit-width to
+ // compare.
+ unsigned BitWidth = std::max(Stride->getBitWidth(), NumRows->getBitWidth());
+ Check(Stride->getValue().zext(BitWidth).uge(NumRows->getValue().zext(BitWidth)),
"Stride must be greater or equal than the number of rows!", IF);
+ }
break;
}
diff --git a/llvm/test/Verifier/matrix-intrinsics-strides.ll b/llvm/test/Verifier/matrix-intrinsics-strides.ll
new file mode 100644
index 0000000000000..5ba324eebe090
--- /dev/null
+++ b/llvm/test/Verifier/matrix-intrinsics-strides.ll
@@ -0,0 +1,29 @@
+; RUN: opt %s -p verify -S -disable-output
+
+; This test ensures that verifier correctly handles very wide and very narrows
+; strides.
+
+define <4 x float> @column.major_load_stride_i8(ptr %m, i32 %arg) {
+ %result.1 = call <4 x float> @llvm.matrix.column.major.load.v4f32.i128(ptr %m, i8 16, i1 false, i32 2, i32 2)
+ ret <4 x float> %result.1
+}
+
+define <4 x float> @column.major_load_stride_i128(ptr %m, i32 %arg) {
+ %result.1 = call <4 x float> @llvm.matrix.column.major.load.v4f32.i128(ptr %m, i128 u0x10000000000000000, i1 false, i32 2, i32 2)
+ ret <4 x float> %result.1
+}
+
+define void @column.major_store_stride_i8(ptr %m, i64 %arg) {
+ call void @llvm.matrix.column.major.store.v4f32.i128(<4 x float> zeroinitializer, ptr %m, i8 16, i1 false, i32 2, i32 2)
+ ret void
+}
+
+define void @column.major_store_stride_i128(ptr %m, i64 %arg) {
+ call void @llvm.matrix.column.major.store.v4f32.i128(<4 x float> zeroinitializer, ptr %m, i128 u0x10000000000000000, i1 false, i32 2, i32 2)
+ ret void
+}
+
+declare <6 x float> @llvm.matrix.column.major.load.v6f32.i8(ptr, i8, i1, i32, i32)
+declare void @llvm.matrix.column.major.store.v4p0.i8(<4 x ptr>, ptr, i8, i1, i32, i32)
+declare <6 x float> @llvm.matrix.column.major.load.v6f32.i128(ptr, i64, i1, i32, i32)
+declare void @llvm.matrix.column.major.store.v4p0.i128(<4 x ptr>, ptr, i64, i1, i32, i32)
diff --git a/llvm/test/Verifier/matrix-intrinsics.ll b/llvm/test/Verifier/matrix-intrinsics.ll
index b6d5ad9a3cc49..e208d47c1d88e 100644
--- a/llvm/test/Verifier/matrix-intrinsics.ll
+++ b/llvm/test/Verifier/matrix-intrinsics.ll
@@ -1,8 +1,7 @@
-; RUN: not llvm-as < %s -o /dev/null 2>&1 | FileCheck %s
+; RUN: not opt -S %s -p verify 2>&1 | FileCheck %s
define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
-; CHECK: assembly parsed, but does not verify as correct!
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
+; CHECK: Result of a matrix operation does not fit in the returned vector!
; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
; CHECK-NEXT: immarg operand has non-immediate parameter
✅ With the latest revision this PR passed the C/C++ code formatter.
Alternatively you could also verify that the bit width is <= 64, which might save additional handling elsewhere. IIRC you don't need more than 64 bits as long as that's the maximum size_t type. No strong opinion though.
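For readers comparing the two options, here is a minimal standalone sketch of both. It does not use LLVM: `WideConst`, `strideAtLeastRowsWidened`, `checkStrideCapped`, and `Verdict` are hypothetical names invented for illustration, and the 128-bit limit is an assumption of the model only. The first function mirrors the idea of comparing at the widest common width, the second mirrors rejecting wide types up front so that a plain 64-bit comparison is always safe.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for an integer constant of up to 128 bits.
// This is NOT llvm::APInt; it only models "declared width plus value".
struct WideConst {
  unsigned bitWidth; // declared width of the integer type, e.g. 8, 64, 128
  std::uint64_t hi;  // bits 64..127 (zero when bitWidth <= 64)
  std::uint64_t lo;  // bits 0..63
};

// Option 1: compare at the widest common width, so nothing is truncated.
// Rows are 32-bit, so after zero extension they have no high bits set.
bool strideAtLeastRowsWidened(const WideConst &stride, std::uint32_t rows) {
  assert(stride.bitWidth <= 128 && "model only covers widths up to 128");
  if (stride.hi != 0)
    return true; // any set high bit already exceeds a 32-bit row count
  return stride.lo >= rows;
}

enum class Verdict { Ok, WidthTooLarge, StrideTooSmall };

// Option 2: reject types wider than 64 bits first. Every value that
// survives fits in 64 bits, so the comparison can use the low word alone.
Verdict checkStrideCapped(const WideConst &stride, std::uint32_t rows) {
  if (stride.bitWidth > 64)
    return Verdict::WidthTooLarge;
  return stride.lo >= rows ? Verdict::Ok : Verdict::StrideTooSmall;
}
```

With option 2, an `i128` stride is rejected by its type alone, regardless of its value, which is what lets later code assume the stride fits in 64 bits.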
I will defer to @fhahn on this.
It should be fine to restrict the stride to be <= 64 bits
OK, that's a simpler change. I've updated the PR to reflect this.
Could you also update LangRef together with the verifier change?
- if (Stride)
+ if (Stride) {
+ Check(Stride->getBitWidth() <= 64, "Stride bitwidth cannot exceed 64!",
This looks fine to me. Are there any cases where it could be less than 64, where we might want this to be configurable?
Yes, in practice on targets with a 32-bit index width (or lower), it may be smaller, matching the index width.
I understood Farzon's question as asking whether we might want the cap to be configurable so that, e.g., we can reject 64-bit strides when verifying IR targeting 32-bit platforms
The restriction in the verifier is mostly just so we can use getZExtValue() on it.
Regardless of the type width, stride, columns, and rows must be such that we can access all data without wrapping the address space. Restricting to i32 on 32-bit targets wouldn't really help to enforce that, as it would still be possible to provide arguments that would cause the accesses to wrap.
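To make the wrapping point concrete, here is a rough standalone sketch. It assumes the stride counts elements between the starts of consecutive columns and ignores the base pointer entirely. `accessWraps` and its parameters are hypothetical names, not anything in LLVM. It reports whether the last byte touched by a column-major access lies beyond what an index of `indexBits` bits can represent.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: true if the offset of the last byte accessed does not
// fit in an unsigned index of `indexBits` bits. Assumes column c starts at
// element offset c * stride and that the base pointer is 0.
bool accessWraps(std::uint64_t stride, std::uint32_t rows, std::uint32_t cols,
                 std::uint32_t elemBytes, unsigned indexBits) {
  assert(rows > 0 && cols > 0 && elemBytes > 0);
  assert(indexBits >= 1 && indexBits <= 64);
  const std::uint64_t maxOffset =
      indexBits == 64 ? UINT64_MAX : ((std::uint64_t{1} << indexBits) - 1);

  // Index of the last element: (cols - 1) * stride + (rows - 1),
  // computed step by step so the check itself cannot overflow.
  const std::uint64_t colSteps = cols - 1;
  if (colSteps != 0 && stride > maxOffset / colSteps)
    return true;
  std::uint64_t lastElem = colSteps * stride;
  if (lastElem > maxOffset - (rows - 1))
    return true;
  lastElem += rows - 1;

  // Last byte offset: lastElem * elemBytes + (elemBytes - 1).
  if (elemBytes - 1 > maxOffset)
    return true;
  return lastElem > (maxOffset - (elemBytes - 1)) / elemBytes;
}
```

Under these assumptions, a stride of `0x40000000` with 2 rows, 4 columns, and 4-byte elements fits comfortably in an `i32`, yet `accessWraps(0x40000000, 2, 4, 4, 32)` is true, while the same arguments with a 64-bit index do not wrap. Narrowing the stride type alone therefore does not rule out wrapping accesses.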
Makes sense, thanks!
LGTM, thanks
Test failure is unrelated: #163937
a1ef81d added overloads for `llvm.matrix.column.major.store` and `llvm.matrix.column.major.load` that allow strides to occupy an arbitrary bitwidth. This change wasn't reflected in the verifier, causing an assertion to trip when given strides overflowing 64-bit. This patch explicitly caps the bitwidth at 64, repairing the crash and avoiding future complexity dealing with strides that overflow 64 bits. PR: llvm/llvm-project#163729