Skip to content

Commit 8e41bf2

Browse files
jaskarthcl4es
authored andcommitted
8303238: Create generalizations for existing LShift ideal transforms
Reviewed-by: redestad, thartmann
1 parent 805a4e6 commit 8e41bf2

File tree

4 files changed

+425
-29
lines changed

4 files changed

+425
-29
lines changed

src/hotspot/share/opto/mulnode.cpp

+132-26
Original file line numberDiff line numberDiff line change
@@ -847,21 +847,74 @@ Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
847847
}
848848
}
849849

850-
// Check for "(x>>c0)<<c0" which just masks off low bits
851-
if( (add1_op == Op_RShiftI || add1_op == Op_URShiftI ) &&
852-
add1->in(2) == in(2) )
853-
// Convert to "(x & -(1<<c0))"
854-
return new AndINode(add1->in(1),phase->intcon( -(1<<con)));
855-
856-
// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
857-
if( add1_op == Op_AndI ) {
850+
// Check for "(x >> C1) << C2"
851+
if (add1_op == Op_RShiftI || add1_op == Op_URShiftI) {
852+
// Special case C1 == C2, which just masks off low bits
853+
if (add1->in(2) == in(2)) {
854+
// Convert to "(x & -(1 << C2))"
855+
return new AndINode(add1->in(1), phase->intcon(-(1 << con)));
856+
} else {
857+
int add1Con = 0;
858+
const_shift_count(phase, add1, &add1Con);
859+
860+
// Wait until the right shift has been sharpened to the correct count
861+
if (add1Con > 0 && add1Con < BitsPerJavaInteger) {
862+
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
863+
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
864+
if (phase->is_IterGVN()) {
865+
if (con > add1Con) {
866+
// Creates "(x << (C2 - C1)) & -(1 << C2)"
867+
Node* lshift = phase->transform(new LShiftINode(add1->in(1), phase->intcon(con - add1Con)));
868+
return new AndINode(lshift, phase->intcon(-(1 << con)));
869+
} else {
870+
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
871+
// Creates "(x >> (C1 - C2)) & -(1 << C2)"
872+
873+
// Handle logical and arithmetic shifts
874+
Node* rshift;
875+
if (add1_op == Op_RShiftI) {
876+
rshift = phase->transform(new RShiftINode(add1->in(1), phase->intcon(add1Con - con)));
877+
} else {
878+
rshift = phase->transform(new URShiftINode(add1->in(1), phase->intcon(add1Con - con)));
879+
}
880+
881+
return new AndINode(rshift, phase->intcon(-(1 << con)));
882+
}
883+
} else {
884+
phase->record_for_igvn(this);
885+
}
886+
}
887+
}
888+
}
889+
890+
// Check for "((x >> C1) & Y) << C2"
891+
if (add1_op == Op_AndI) {
858892
Node *add2 = add1->in(1);
859893
int add2_op = add2->Opcode();
860-
if( (add2_op == Op_RShiftI || add2_op == Op_URShiftI ) &&
861-
add2->in(2) == in(2) ) {
862-
// Convert to "(x & (Y<<c0))"
863-
Node *y_sh = phase->transform( new LShiftINode( add1->in(2), in(2) ) );
864-
return new AndINode( add2->in(1), y_sh );
894+
if (add2_op == Op_RShiftI || add2_op == Op_URShiftI) {
895+
// Special case C1 == C2, which just masks off low bits
896+
if (add2->in(2) == in(2)) {
897+
// Convert to "(x & (Y << C2))"
898+
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));
899+
return new AndINode(add2->in(1), y_sh);
900+
}
901+
902+
int add2Con = 0;
903+
const_shift_count(phase, add2, &add2Con);
904+
if (add2Con > 0 && add2Con < BitsPerJavaInteger) {
905+
if (phase->is_IterGVN()) {
906+
// Convert to "((x >> C1) << C2) & (Y << C2)"
907+
908+
// Make "(x >> C1) << C2", which will get folded away by the rule above
909+
Node* x_sh = phase->transform(new LShiftINode(add2, phase->intcon(con)));
910+
// Make "Y << C2", which will simplify when Y is a constant
911+
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));
912+
913+
return new AndINode(x_sh, y_sh);
914+
} else {
915+
phase->record_for_igvn(this);
916+
}
917+
}
865918
}
866919
}
867920

@@ -970,21 +1023,74 @@ Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
9701023
}
9711024
}
9721025

973-
// Check for "(x>>c0)<<c0" which just masks off low bits
974-
if( (add1_op == Op_RShiftL || add1_op == Op_URShiftL ) &&
975-
add1->in(2) == in(2) )
976-
// Convert to "(x & -(1<<c0))"
977-
return new AndLNode(add1->in(1),phase->longcon( -(CONST64(1)<<con)));
1026+
// Check for "(x >> C1) << C2"
1027+
if (add1_op == Op_RShiftL || add1_op == Op_URShiftL) {
1028+
// Special case C1 == C2, which just masks off low bits
1029+
if (add1->in(2) == in(2)) {
1030+
// Convert to "(x & -(1 << C2))"
1031+
return new AndLNode(add1->in(1), phase->longcon(-(CONST64(1) << con)));
1032+
} else {
1033+
int add1Con = 0;
1034+
const_shift_count(phase, add1, &add1Con);
1035+
1036+
// Wait until the right shift has been sharpened to the correct count
1037+
if (add1Con > 0 && add1Con < BitsPerJavaLong) {
1038+
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
1039+
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
1040+
if (phase->is_IterGVN()) {
1041+
if (con > add1Con) {
1042+
// Creates "(x << (C2 - C1)) & -(1 << C2)"
1043+
Node* lshift = phase->transform(new LShiftLNode(add1->in(1), phase->intcon(con - add1Con)));
1044+
return new AndLNode(lshift, phase->longcon(-(CONST64(1) << con)));
1045+
} else {
1046+
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
1047+
// Creates "(x >> (C1 - C2)) & -(1 << C2)"
1048+
1049+
// Handle logical and arithmetic shifts
1050+
Node* rshift;
1051+
if (add1_op == Op_RShiftL) {
1052+
rshift = phase->transform(new RShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
1053+
} else {
1054+
rshift = phase->transform(new URShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
1055+
}
1056+
1057+
return new AndLNode(rshift, phase->longcon(-(CONST64(1) << con)));
1058+
}
1059+
} else {
1060+
phase->record_for_igvn(this);
1061+
}
1062+
}
1063+
}
1064+
}
9781065

979-
// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
980-
if( add1_op == Op_AndL ) {
981-
Node *add2 = add1->in(1);
1066+
// Check for "((x >> C1) & Y) << C2"
1067+
if (add1_op == Op_AndL) {
1068+
Node* add2 = add1->in(1);
9821069
int add2_op = add2->Opcode();
983-
if( (add2_op == Op_RShiftL || add2_op == Op_URShiftL ) &&
984-
add2->in(2) == in(2) ) {
985-
// Convert to "(x & (Y<<c0))"
986-
Node *y_sh = phase->transform( new LShiftLNode( add1->in(2), in(2) ) );
987-
return new AndLNode( add2->in(1), y_sh );
1070+
if (add2_op == Op_RShiftL || add2_op == Op_URShiftL) {
1071+
// Special case C1 == C2, which just masks off low bits
1072+
if (add2->in(2) == in(2)) {
1073+
// Convert to "(x & (Y << C2))"
1074+
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));
1075+
return new AndLNode(add2->in(1), y_sh);
1076+
}
1077+
1078+
int add2Con = 0;
1079+
const_shift_count(phase, add2, &add2Con);
1080+
if (add2Con > 0 && add2Con < BitsPerJavaLong) {
1081+
if (phase->is_IterGVN()) {
1082+
// Convert to "((x >> C1) << C2) & (Y << C2)"
1083+
1084+
// Make "(x >> C1) << C2", which will get folded away by the rule above
1085+
Node* x_sh = phase->transform(new LShiftLNode(add2, phase->intcon(con)));
1086+
// Make "Y << C2", which will simplify when Y is a constant
1087+
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));
1088+
1089+
return new AndLNode(x_sh, y_sh);
1090+
} else {
1091+
phase->record_for_igvn(this);
1092+
}
1093+
}
9881094
}
9891095
}
9901096

test/hotspot/jtreg/compiler/c2/irTests/LShiftINodeIdealizationTests.java

+63-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -27,7 +27,7 @@
2727

2828
/*
2929
* @test
30-
* @bug 8297384
30+
* @bug 8297384 8303238
3131
* @summary Test that Ideal transformations of LShiftINode* are being performed as expected.
3232
* @library /test/lib /
3333
* @run driver compiler.c2.irTests.LShiftINodeIdealizationTests
@@ -37,15 +37,21 @@ public static void main(String[] args) {
3737
TestFramework.run();
3838
}
3939

40-
@Run(test = { "test1", "test2" })
40+
@Run(test = { "test1", "test2", "test3", "test4", "test5", "test6", "test7", "test8" })
4141
public void runMethod() {
4242
int a = RunInfo.getRandom().nextInt();
43+
int b = RunInfo.getRandom().nextInt();
44+
int c = RunInfo.getRandom().nextInt();
45+
int d = RunInfo.getRandom().nextInt();
4346

4447
int min = Integer.MIN_VALUE;
4548
int max = Integer.MAX_VALUE;
4649

4750
assertResult(0);
4851
assertResult(a);
52+
assertResult(b);
53+
assertResult(c);
54+
assertResult(d);
4955
assertResult(min);
5056
assertResult(max);
5157
}
@@ -54,6 +60,12 @@ public void runMethod() {
5460
public void assertResult(int a) {
5561
Asserts.assertEQ((a >> 2022) << 2022, test1(a));
5662
Asserts.assertEQ((a >>> 2022) << 2022, test2(a));
63+
Asserts.assertEQ((a >> 4) << 8, test3(a));
64+
Asserts.assertEQ((a >>> 4) << 8, test4(a));
65+
Asserts.assertEQ((a >> 8) << 4, test5(a));
66+
Asserts.assertEQ((a >>> 8) << 4, test6(a));
67+
Asserts.assertEQ(((a >> 4) & 0xFF) << 8, test7(a));
68+
Asserts.assertEQ(((a >>> 4) & 0xFF) << 8, test8(a));
5769
}
5870

5971
@Test
@@ -71,4 +83,52 @@ public int test1(int x) {
7183
public int test2(int x) {
7284
return (x >>> 2022) << 2022;
7385
}
86+
87+
@Test
88+
@IR(failOn = { IRNode.RSHIFT })
89+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
90+
// Checks (x >> 4) << 8 => (x << 4) & -16
91+
public int test3(int x) {
92+
return (x >> 4) << 8;
93+
}
94+
95+
@Test
96+
@IR(failOn = { IRNode.URSHIFT })
97+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
98+
// Checks (x >>> 4) << 8 => (x << 4) & -16
99+
public int test4(int x) {
100+
return (x >>> 4) << 8;
101+
}
102+
103+
@Test
104+
@IR(failOn = { IRNode.LSHIFT })
105+
@IR(counts = { IRNode.AND, "1", IRNode.RSHIFT, "1" })
106+
// Checks (x >> 8) << 4 => (x >> 4) & -16
107+
public int test5(int x) {
108+
return (x >> 8) << 4;
109+
}
110+
111+
@Test
112+
@IR(failOn = { IRNode.LSHIFT })
113+
@IR(counts = { IRNode.AND, "1", IRNode.URSHIFT, "1" })
114+
// Checks (x >>> 8) << 4 => (x >>> 4) & -16
115+
public int test6(int x) {
116+
return (x >>> 8) << 4;
117+
}
118+
119+
@Test
120+
@IR(failOn = { IRNode.RSHIFT })
121+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
122+
// Checks ((x >> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
123+
public int test7(int x) {
124+
return ((x >> 4) & 0xFF) << 8;
125+
}
126+
127+
@Test
128+
@IR(failOn = { IRNode.URSHIFT })
129+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
130+
// Checks ((x >>> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
131+
public int test8(int x) {
132+
return ((x >>> 4) & 0xFF) << 8;
133+
}
74134
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
package compiler.c2.irTests;
24+
25+
import jdk.test.lib.Asserts;
26+
import compiler.lib.ir_framework.*;
27+
28+
/*
29+
* @test
30+
* @bug 8303238
31+
* @summary Test that Ideal transformations of LShiftLNode* are being performed as expected.
32+
* @library /test/lib /
33+
* @run driver compiler.c2.irTests.LShiftLNodeIdealizationTests
34+
*/
35+
public class LShiftLNodeIdealizationTests {
36+
public static void main(String[] args) {
37+
TestFramework.run();
38+
}
39+
40+
@Run(test = { "test3", "test4", "test5", "test6", "test7", "test8" })
41+
public void runMethod() {
42+
long a = RunInfo.getRandom().nextLong();
43+
long b = RunInfo.getRandom().nextLong();
44+
long c = RunInfo.getRandom().nextLong();
45+
long d = RunInfo.getRandom().nextLong();
46+
47+
long min = Long.MIN_VALUE;
48+
long max = Long.MAX_VALUE;
49+
50+
assertResult(0);
51+
assertResult(a);
52+
assertResult(b);
53+
assertResult(c);
54+
assertResult(d);
55+
assertResult(min);
56+
assertResult(max);
57+
}
58+
59+
@DontCompile
60+
public void assertResult(long a) {
61+
Asserts.assertEQ((a >> 4L) << 8L, test3(a));
62+
Asserts.assertEQ((a >>> 4L) << 8L, test4(a));
63+
Asserts.assertEQ((a >> 8L) << 4L, test5(a));
64+
Asserts.assertEQ((a >>> 8L) << 4L, test6(a));
65+
Asserts.assertEQ(((a >> 4L) & 0xFFL) << 8L, test7(a));
66+
Asserts.assertEQ(((a >>> 4L) & 0xFFL) << 8L, test8(a));
67+
}
68+
69+
@Test
70+
@IR(failOn = { IRNode.RSHIFT })
71+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
72+
// Checks (x >> 4) << 8 => (x << 4) & -16
73+
public long test3(long x) {
74+
return (x >> 4L) << 8L;
75+
}
76+
77+
@Test
78+
@IR(failOn = { IRNode.URSHIFT })
79+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
80+
// Checks (x >>> 4) << 8 => (x << 4) & -16
81+
public long test4(long x) {
82+
return (x >>> 4L) << 8L;
83+
}
84+
85+
@Test
86+
@IR(failOn = { IRNode.LSHIFT })
87+
@IR(counts = { IRNode.AND, "1", IRNode.RSHIFT, "1" })
88+
// Checks (x >> 8) << 4 => (x >> 4) & -16
89+
public long test5(long x) {
90+
return (x >> 8L) << 4L;
91+
}
92+
93+
@Test
94+
@IR(failOn = { IRNode.LSHIFT })
95+
@IR(counts = { IRNode.AND, "1", IRNode.URSHIFT, "1" })
96+
// Checks (x >>> 8) << 4 => (x >>> 4) & -16
97+
public long test6(long x) {
98+
return (x >>> 8L) << 4L;
99+
}
100+
101+
@Test
102+
@IR(failOn = { IRNode.RSHIFT })
103+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
104+
// Checks ((x >> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
105+
public long test7(long x) {
106+
return ((x >> 4L) & 0xFFL) << 8L;
107+
}
108+
109+
@Test
110+
@IR(failOn = { IRNode.URSHIFT })
111+
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
112+
// Checks ((x >>> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
113+
public long test8(long x) {
114+
return ((x >>> 4L) & 0xFFL) << 8L;
115+
}
116+
}

0 commit comments

Comments
 (0)