Skip to content

Commit

Permalink
8303238: Create generalizations for existing LShift ideal transforms
Browse files Browse the repository at this point in the history
Reviewed-by: redestad, thartmann
  • Loading branch information
jaskarth authored and cl4es committed Mar 13, 2023
1 parent 805a4e6 commit 8e41bf2
Show file tree
Hide file tree
Showing 4 changed files with 425 additions and 29 deletions.
158 changes: 132 additions & 26 deletions src/hotspot/share/opto/mulnode.cpp
Expand Up @@ -847,21 +847,74 @@ Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
}
}

// Check for "(x>>c0)<<c0" which just masks off low bits
if( (add1_op == Op_RShiftI || add1_op == Op_URShiftI ) &&
add1->in(2) == in(2) )
// Convert to "(x & -(1<<c0))"
return new AndINode(add1->in(1),phase->intcon( -(1<<con)));

// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
if( add1_op == Op_AndI ) {
// Check for "(x >> C1) << C2"
if (add1_op == Op_RShiftI || add1_op == Op_URShiftI) {
// Special case C1 == C2, which just masks off low bits
if (add1->in(2) == in(2)) {
// Convert to "(x & -(1 << C2))"
return new AndINode(add1->in(1), phase->intcon(-(1 << con)));
} else {
int add1Con = 0;
const_shift_count(phase, add1, &add1Con);

// Wait until the right shift has been sharpened to the correct count
if (add1Con > 0 && add1Con < BitsPerJavaInteger) {
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
if (phase->is_IterGVN()) {
if (con > add1Con) {
// Creates "(x << (C2 - C1)) & -(1 << C2)"
Node* lshift = phase->transform(new LShiftINode(add1->in(1), phase->intcon(con - add1Con)));
return new AndINode(lshift, phase->intcon(-(1 << con)));
} else {
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
// Creates "(x >> (C1 - C2)) & -(1 << C2)"

// Handle logical and arithmetic shifts
Node* rshift;
if (add1_op == Op_RShiftI) {
rshift = phase->transform(new RShiftINode(add1->in(1), phase->intcon(add1Con - con)));
} else {
rshift = phase->transform(new URShiftINode(add1->in(1), phase->intcon(add1Con - con)));
}

return new AndINode(rshift, phase->intcon(-(1 << con)));
}
} else {
phase->record_for_igvn(this);
}
}
}
}

// Check for "((x >> C1) & Y) << C2"
if (add1_op == Op_AndI) {
Node *add2 = add1->in(1);
int add2_op = add2->Opcode();
if( (add2_op == Op_RShiftI || add2_op == Op_URShiftI ) &&
add2->in(2) == in(2) ) {
// Convert to "(x & (Y<<c0))"
Node *y_sh = phase->transform( new LShiftINode( add1->in(2), in(2) ) );
return new AndINode( add2->in(1), y_sh );
if (add2_op == Op_RShiftI || add2_op == Op_URShiftI) {
// Special case C1 == C2, which just masks off low bits
if (add2->in(2) == in(2)) {
// Convert to "(x & (Y << C2))"
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));
return new AndINode(add2->in(1), y_sh);
}

int add2Con = 0;
const_shift_count(phase, add2, &add2Con);
if (add2Con > 0 && add2Con < BitsPerJavaInteger) {
if (phase->is_IterGVN()) {
// Convert to "((x >> C1) << C2) & (Y << C2)"

// Make "(x >> C1) << C2", which will get folded away by the rule above
Node* x_sh = phase->transform(new LShiftINode(add2, phase->intcon(con)));
// Make "Y << C2", which will simplify when Y is a constant
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));

return new AndINode(x_sh, y_sh);
} else {
phase->record_for_igvn(this);
}
}
}
}

Expand Down Expand Up @@ -970,21 +1023,74 @@ Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
}
}

// Check for "(x>>c0)<<c0" which just masks off low bits
if( (add1_op == Op_RShiftL || add1_op == Op_URShiftL ) &&
add1->in(2) == in(2) )
// Convert to "(x & -(1<<c0))"
return new AndLNode(add1->in(1),phase->longcon( -(CONST64(1)<<con)));
// Check for "(x >> C1) << C2"
if (add1_op == Op_RShiftL || add1_op == Op_URShiftL) {
// Special case C1 == C2, which just masks off low bits
if (add1->in(2) == in(2)) {
// Convert to "(x & -(1 << C2))"
return new AndLNode(add1->in(1), phase->longcon(-(CONST64(1) << con)));
} else {
int add1Con = 0;
const_shift_count(phase, add1, &add1Con);

// Wait until the right shift has been sharpened to the correct count
if (add1Con > 0 && add1Con < BitsPerJavaLong) {
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
if (phase->is_IterGVN()) {
if (con > add1Con) {
// Creates "(x << (C2 - C1)) & -(1 << C2)"
Node* lshift = phase->transform(new LShiftLNode(add1->in(1), phase->intcon(con - add1Con)));
return new AndLNode(lshift, phase->longcon(-(CONST64(1) << con)));
} else {
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
// Creates "(x >> (C1 - C2)) & -(1 << C2)"

// Handle logical and arithmetic shifts
Node* rshift;
if (add1_op == Op_RShiftL) {
rshift = phase->transform(new RShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
} else {
rshift = phase->transform(new URShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
}

return new AndLNode(rshift, phase->longcon(-(CONST64(1) << con)));
}
} else {
phase->record_for_igvn(this);
}
}
}
}

// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
if( add1_op == Op_AndL ) {
Node *add2 = add1->in(1);
// Check for "((x >> C1) & Y) << C2"
if (add1_op == Op_AndL) {
Node* add2 = add1->in(1);
int add2_op = add2->Opcode();
if( (add2_op == Op_RShiftL || add2_op == Op_URShiftL ) &&
add2->in(2) == in(2) ) {
// Convert to "(x & (Y<<c0))"
Node *y_sh = phase->transform( new LShiftLNode( add1->in(2), in(2) ) );
return new AndLNode( add2->in(1), y_sh );
if (add2_op == Op_RShiftL || add2_op == Op_URShiftL) {
// Special case C1 == C2, which just masks off low bits
if (add2->in(2) == in(2)) {
// Convert to "(x & (Y << C2))"
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));
return new AndLNode(add2->in(1), y_sh);
}

int add2Con = 0;
const_shift_count(phase, add2, &add2Con);
if (add2Con > 0 && add2Con < BitsPerJavaLong) {
if (phase->is_IterGVN()) {
// Convert to "((x >> C1) << C2) & (Y << C2)"

// Make "(x >> C1) << C2", which will get folded away by the rule above
Node* x_sh = phase->transform(new LShiftLNode(add2, phase->intcon(con)));
// Make "Y << C2", which will simplify when Y is a constant
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));

return new AndLNode(x_sh, y_sh);
} else {
phase->record_for_igvn(this);
}
}
}
}

Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -27,7 +27,7 @@

/*
* @test
* @bug 8297384
* @bug 8297384 8303238
* @summary Test that Ideal transformations of LShiftINode* are being performed as expected.
* @library /test/lib /
* @run driver compiler.c2.irTests.LShiftINodeIdealizationTests
Expand All @@ -37,15 +37,21 @@ public static void main(String[] args) {
TestFramework.run();
}

@Run(test = { "test1", "test2" })
@Run(test = { "test1", "test2", "test3", "test4", "test5", "test6", "test7", "test8" })
public void runMethod() {
int a = RunInfo.getRandom().nextInt();
int b = RunInfo.getRandom().nextInt();
int c = RunInfo.getRandom().nextInt();
int d = RunInfo.getRandom().nextInt();

int min = Integer.MIN_VALUE;
int max = Integer.MAX_VALUE;

assertResult(0);
assertResult(a);
assertResult(b);
assertResult(c);
assertResult(d);
assertResult(min);
assertResult(max);
}
Expand All @@ -54,6 +60,12 @@ public void runMethod() {
public void assertResult(int a) {
Asserts.assertEQ((a >> 2022) << 2022, test1(a));
Asserts.assertEQ((a >>> 2022) << 2022, test2(a));
Asserts.assertEQ((a >> 4) << 8, test3(a));
Asserts.assertEQ((a >>> 4) << 8, test4(a));
Asserts.assertEQ((a >> 8) << 4, test5(a));
Asserts.assertEQ((a >>> 8) << 4, test6(a));
Asserts.assertEQ(((a >> 4) & 0xFF) << 8, test7(a));
Asserts.assertEQ(((a >>> 4) & 0xFF) << 8, test8(a));
}

@Test
Expand All @@ -71,4 +83,52 @@ public int test1(int x) {
public int test2(int x) {
return (x >>> 2022) << 2022;
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >> 4) << 8 => (x << 4) & -16
public int test3(int x) {
return (x >> 4) << 8;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >>> 4) << 8 => (x << 4) & -16
public int test4(int x) {
return (x >>> 4) << 8;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.RSHIFT, "1" })
// Checks (x >> 8) << 4 => (x >> 4) & -16
public int test5(int x) {
return (x >> 8) << 4;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.URSHIFT, "1" })
// Checks (x >>> 8) << 4 => (x >>> 4) & -16
public int test6(int x) {
return (x >>> 8) << 4;
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public int test7(int x) {
return ((x >> 4) & 0xFF) << 8;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >>> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public int test8(int x) {
return ((x >>> 4) & 0xFF) << 8;
}
}
@@ -0,0 +1,116 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.c2.irTests;

import jdk.test.lib.Asserts;
import compiler.lib.ir_framework.*;

/*
* @test
* @bug 8303238
* @summary Test that Ideal transformations of LShiftLNode* are being performed as expected.
* @library /test/lib /
* @run driver compiler.c2.irTests.LShiftLNodeIdealizationTests
*/
public class LShiftLNodeIdealizationTests {
public static void main(String[] args) {
TestFramework.run();
}

@Run(test = { "test3", "test4", "test5", "test6", "test7", "test8" })
public void runMethod() {
long a = RunInfo.getRandom().nextLong();
long b = RunInfo.getRandom().nextLong();
long c = RunInfo.getRandom().nextLong();
long d = RunInfo.getRandom().nextLong();

long min = Long.MIN_VALUE;
long max = Long.MAX_VALUE;

assertResult(0);
assertResult(a);
assertResult(b);
assertResult(c);
assertResult(d);
assertResult(min);
assertResult(max);
}

@DontCompile
public void assertResult(long a) {
Asserts.assertEQ((a >> 4L) << 8L, test3(a));
Asserts.assertEQ((a >>> 4L) << 8L, test4(a));
Asserts.assertEQ((a >> 8L) << 4L, test5(a));
Asserts.assertEQ((a >>> 8L) << 4L, test6(a));
Asserts.assertEQ(((a >> 4L) & 0xFFL) << 8L, test7(a));
Asserts.assertEQ(((a >>> 4L) & 0xFFL) << 8L, test8(a));
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >> 4) << 8 => (x << 4) & -16
public long test3(long x) {
return (x >> 4L) << 8L;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >>> 4) << 8 => (x << 4) & -16
public long test4(long x) {
return (x >>> 4L) << 8L;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.RSHIFT, "1" })
// Checks (x >> 8) << 4 => (x >> 4) & -16
public long test5(long x) {
return (x >> 8L) << 4L;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.URSHIFT, "1" })
// Checks (x >>> 8) << 4 => (x >>> 4) & -16
public long test6(long x) {
return (x >>> 8L) << 4L;
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public long test7(long x) {
return ((x >> 4L) & 0xFFL) << 8L;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >>> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public long test8(long x) {
return ((x >>> 4L) & 0xFFL) << 8L;
}
}

1 comment on commit 8e41bf2

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.