Skip to content

Commit 2b858f5

Browse files
committed
8328938: C2 SuperWord: disable vectorization for large stride and scale
Reviewed-by: epeter, simonis Backport-of: 2931458711244e20eb7845a1aefcf6ed4206bce1
1 parent 9159882 commit 2b858f5

File tree

2 files changed

+272
-0
lines changed

2 files changed

+272
-0
lines changed

src/hotspot/share/opto/superword.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4200,6 +4200,25 @@ SWPointer::SWPointer(MemNode* mem, SuperWord* slp, Node_Stack *nstack, bool anal
42004200
NOT_PRODUCT(if(_slp->is_trace_alignment()) _tracer.restore_depth();)
42014201
NOT_PRODUCT(_tracer.ctor_6(mem);)
42024202

4203+
// In the pointer analysis, and especially the AlignVector analysis, we assume that
4204+
// stride and scale are not too large. For example, we multiply "scale * stride",
4205+
// and assume that this does not overflow the int range. We also take "abs(scale)"
4206+
// and "abs(stride)", which would overflow for min_int = -(2^31). Still, we want
4207+
// to at least allow small and moderately large stride and scale. Therefore, we
4208+
// allow values up to 2^30, which is only a factor 2 smaller than the max/min int.
4209+
// Normal performance relevant code will have much lower values. And the restriction
4210+
// allows us to keep the rest of the autovectorization code much simpler, since we
4211+
// do not have to deal with overflows.
4212+
jlong long_scale = _scale;
4213+
jlong long_stride = slp->lp()->stride_is_con() ? slp->iv_stride() : 0;
4214+
jlong max_val = 1 << 30;
4215+
if (abs(long_scale) >= max_val ||
4216+
abs(long_stride) >= max_val ||
4217+
abs(long_scale * long_stride) >= max_val) {
4218+
assert(!valid(), "adr stride*scale is too large");
4219+
return;
4220+
}
4221+
42034222
_base = base;
42044223
_adr = adr;
42054224
assert(valid(), "Usable");
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
/*
2+
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test id=vanilla
26+
* @bug 8328938
27+
* @summary Test autovectorization with large scale and stride
28+
* @modules java.base/jdk.internal.misc
29+
* @library /test/lib /
30+
* @run main compiler.loopopts.superword.TestLargeScaleAndStride
31+
*/
32+
33+
/*
34+
* @test id=AlignVector
35+
* @bug 8328938
36+
* @modules java.base/jdk.internal.misc
37+
* @library /test/lib /
38+
* @requires vm.compiler2.enabled
39+
* @run main/othervm -XX:+AlignVector compiler.loopopts.superword.TestLargeScaleAndStride
40+
*/
41+
42+
package compiler.loopopts.superword;
43+
44+
import jdk.internal.misc.Unsafe;
45+
46+
public class TestLargeScaleAndStride {
47+
static final Unsafe UNSAFE = Unsafe.getUnsafe();
48+
static int RANGE = 100_000;
49+
50+
public static void main(String[] args) {
51+
byte[] a = new byte[100];
52+
fill(a);
53+
54+
byte[] gold1a = a.clone();
55+
byte[] gold1b = a.clone();
56+
byte[] gold2a = a.clone();
57+
byte[] gold2b = a.clone();
58+
byte[] gold2c = a.clone();
59+
byte[] gold2d = a.clone();
60+
byte[] gold3 = a.clone();
61+
test1a(gold1a);
62+
test1b(gold1b);
63+
test2a(gold2a);
64+
test2b(gold2b);
65+
test2c(gold2c);
66+
test2d(gold2d);
67+
test3(gold3);
68+
69+
for (int i = 0; i < 100; i++) {
70+
byte[] c = a.clone();
71+
test1a(c);
72+
verify(c, gold1a);
73+
}
74+
75+
for (int i = 0; i < 100; i++) {
76+
byte[] c = a.clone();
77+
test1b(c);
78+
verify(c, gold1b);
79+
}
80+
81+
for (int i = 0; i < 100; i++) {
82+
byte[] c = a.clone();
83+
test2a(c);
84+
verify(c, gold2a);
85+
}
86+
87+
for (int i = 0; i < 100; i++) {
88+
byte[] c = a.clone();
89+
test2b(c);
90+
verify(c, gold2b);
91+
}
92+
93+
for (int i = 0; i < 100; i++) {
94+
byte[] c = a.clone();
95+
test2c(c);
96+
verify(c, gold2c);
97+
}
98+
99+
for (int i = 0; i < 100; i++) {
100+
byte[] c = a.clone();
101+
test2d(c);
102+
verify(c, gold2d);
103+
}
104+
105+
for (int i = 0; i < 100; i++) {
106+
byte[] c = a.clone();
107+
test3(c);
108+
verify(c, gold3);
109+
}
110+
}
111+
112+
static void fill(byte[] a) {
113+
for (int i = 0; i < a.length; i++) {
114+
a[i] = (byte)i;
115+
}
116+
}
117+
118+
static void verify(byte[] a, byte[] b) {
119+
for (int i = 0; i < a.length; i++) {
120+
if (a[i] != b[i]) {
121+
throw new RuntimeException("wrong value: " + i + ": " + a[i] + " != " + b[i]);
122+
}
123+
}
124+
}
125+
126+
// Kernel 1a: forward loop (i += 2) with scale = 1 << 31, i.e. Integer.MIN_VALUE.
// i is always even, so the int product scale * i wraps to zero and every
// iteration loads the four bytes at base + 0 .. base + 3 and stores each of
// them back incremented by one.
// NOTE(review): presumably this targets the case where abs(scale) itself
// overflows in the compiler's pointer analysis - confirm against bug 8328938.
static void test1a(byte[] a) {
127+
int scale = 1 << 31;
128+
for (int i = 0; i < RANGE; i+=2) {
129+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
130+
// i is a multiple of 2
131+
// 2 * (1 << 31) = 2^32 -> overflows the int range and wraps to zero
132+
int j = scale * i; // always zero
133+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
134+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
135+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
136+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
137+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
138+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
139+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
140+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
141+
}
142+
}
143+
144+
// Kernel 1b: same body as the forward scale = 1 << 31 kernel, but the loop
// counts down (i -= 2) from RANGE-2 to 0, i.e. the stride is negative.
// i stays even as long as RANGE is even (it is initialized to 100_000), so
// scale * i wraps to zero and each iteration increments the four bytes at
// base + 0 .. base + 3.
static void test1b(byte[] a) {
145+
int scale = 1 << 31;
146+
for (int i = RANGE-2; i >= 0; i-=2) {
147+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
148+
// i is a multiple of 2
149+
// 2 * (1 << 31) = 2^32 -> overflows the int range and wraps to zero
150+
int j = scale * i; // always zero
151+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
152+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
153+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
154+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
155+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
156+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
157+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
158+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
159+
}
160+
}
161+
162+
// Kernel 2a: forward loop (i += 4) with scale = 1 << 30.
// i is always a multiple of 4, so the int product scale * i is a multiple of
// 2^32 and wraps to zero; every iteration loads the four bytes at
// base + 0 .. base + 3 and stores each of them back incremented by one.
static void test2a(byte[] a) {
163+
int scale = 1 << 30;
164+
for (int i = 0; i < RANGE; i+=4) {
165+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
166+
// i is a multiple of 4
167+
// 4 * (1 << 30) = 2^32 -> overflows the int range and wraps to zero
168+
int j = scale * i; // always zero
169+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
170+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
171+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
172+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
173+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
174+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
175+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
176+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
177+
}
178+
}
179+
180+
181+
// Kernel 2b: scale = 1 << 30 with a loop that counts down (i -= 4) from
// RANGE-4 to 0, i.e. the stride is negative.
// i stays a multiple of 4 as long as RANGE is one (it is initialized to
// 100_000), so scale * i wraps to zero and each iteration increments the four
// bytes at base + 0 .. base + 3.
static void test2b(byte[] a) {
182+
int scale = 1 << 30;
183+
for (int i = RANGE-4; i >= 0; i-=4) {
184+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
185+
// i is a multiple of 4
186+
// 4 * (1 << 30) = 2^32 -> overflows the int range and wraps to zero
187+
int j = scale * i; // always zero
188+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
189+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
190+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
191+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
192+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
193+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
194+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
195+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
196+
}
197+
}
198+
199+
// Kernel 2c: forward loop (i += 4) with a negative scale of -(1 << 30).
// i is always a multiple of 4, so the int product scale * i is a multiple of
// 2^32 and wraps to zero; every iteration loads the four bytes at
// base + 0 .. base + 3 and stores each of them back incremented by one.
static void test2c(byte[] a) {
200+
int scale = -(1 << 30);
201+
for (int i = 0; i < RANGE; i+=4) {
202+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
203+
// i is a multiple of 4
204+
// 4 * -(1 << 30) = -(2^32) -> overflows the int range and wraps to zero
205+
int j = scale * i; // always zero
206+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
207+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
208+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
209+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
210+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
211+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
212+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
213+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
214+
}
215+
}
216+
217+
// Kernel 2d: negative scale of -(1 << 30) combined with a loop that counts
// down (i -= 4) from RANGE-4 to 0, i.e. both scale and stride are negative.
// i stays a multiple of 4 as long as RANGE is one (it is initialized to
// 100_000), so scale * i wraps to zero and each iteration increments the four
// bytes at base + 0 .. base + 3.
static void test2d(byte[] a) {
218+
int scale = -(1 << 30);
219+
for (int i = RANGE-4; i >= 0; i-=4) {
220+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
221+
// i is a multiple of 4
222+
// 4 * -(1 << 30) = -(2^32) -> overflows the int range and wraps to zero
223+
int j = scale * i; // always zero
224+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
225+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
226+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
227+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
228+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
229+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
230+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
231+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
232+
}
233+
}
234+
235+
// Kernel 3: scale (1 << 28) and stride (1 << 4) are each moderate, but their
// product is 2^32. The loop runs i from -(1 << 30) up to (1 << 30) in steps
// of 16, so i is always a multiple of 16 and scale * i wraps to zero; each
// iteration increments the four bytes at base + 0 .. base + 3.
// NOTE(review): the loop executes 2^31 / 16 = 2^27 iterations, which is a
// multiple of 256, so each byte wraps back to its starting value and the
// expected result equals the input - confirm this is intended.
static void test3(byte[] a) {
236+
int scale = 1 << 28;
237+
int stride = 1 << 4;
238+
int start = -(1 << 30);
239+
int end = 1 << 30;
240+
for (int i = start; i < end; i+=stride) {
241+
long base = UNSAFE.ARRAY_BYTE_BASE_OFFSET;
242+
// i is a multiple of 16: (1 << 28) * 16 = 2^32 -> wraps to zero
int j = scale * i; // always zero
243+
byte v0 = UNSAFE.getByte(a, base + (int)(j + 0));
244+
byte v1 = UNSAFE.getByte(a, base + (int)(j + 1));
245+
byte v2 = UNSAFE.getByte(a, base + (int)(j + 2));
246+
byte v3 = UNSAFE.getByte(a, base + (int)(j + 3));
247+
UNSAFE.putByte(a, base + (int)(j + 0), (byte)(v0 + 1));
248+
UNSAFE.putByte(a, base + (int)(j + 1), (byte)(v1 + 1));
249+
UNSAFE.putByte(a, base + (int)(j + 2), (byte)(v2 + 1));
250+
UNSAFE.putByte(a, base + (int)(j + 3), (byte)(v3 + 1));
251+
}
252+
}
253+
}

0 commit comments

Comments
 (0)