Skip to content

Commit 29d648c

Browse files
committed
8341781: Improve Min/Max node identities
Reviewed-by: chagedorn
1 parent 4c39e9f commit 29d648c

File tree

5 files changed

+295
-8
lines changed

5 files changed

+295
-8
lines changed

src/hotspot/share/opto/addnode.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,20 @@ Node* MaxINode::Ideal(PhaseGVN* phase, bool can_reshape) {
14131413
return IdealI(phase, can_reshape);
14141414
}
14151415

1416+
Node* MaxINode::Identity(PhaseGVN* phase) {
1417+
const TypeInt* t1 = phase->type(in(1))->is_int();
1418+
const TypeInt* t2 = phase->type(in(2))->is_int();
1419+
1420+
// Can we determine the maximum statically?
1421+
if (t1->_lo >= t2->_hi) {
1422+
return in(1);
1423+
} else if (t2->_lo >= t1->_hi) {
1424+
return in(2);
1425+
}
1426+
1427+
return MaxNode::Identity(phase);
1428+
}
1429+
14161430
//=============================================================================
14171431
//------------------------------add_ring---------------------------------------
14181432
// Supplied function returns the sum of the inputs.
@@ -1432,6 +1446,20 @@ Node* MinINode::Ideal(PhaseGVN* phase, bool can_reshape) {
14321446
return IdealI(phase, can_reshape);
14331447
}
14341448

1449+
Node* MinINode::Identity(PhaseGVN* phase) {
1450+
const TypeInt* t1 = phase->type(in(1))->is_int();
1451+
const TypeInt* t2 = phase->type(in(2))->is_int();
1452+
1453+
// Can we determine the minimum statically?
1454+
if (t1->_lo >= t2->_hi) {
1455+
return in(2);
1456+
} else if (t2->_lo >= t1->_hi) {
1457+
return in(1);
1458+
}
1459+
1460+
return MaxNode::Identity(phase);
1461+
}
1462+
14351463
//------------------------------add_ring---------------------------------------
14361464
// Supplied function returns the sum of the inputs.
14371465
const Type *MinINode::add_ring( const Type *t0, const Type *t1 ) const {
@@ -1574,11 +1602,56 @@ Node* MinLNode::Ideal(PhaseGVN* phase, bool can_reshape) {
15741602
return nullptr;
15751603
}
15761604

1605+
int MaxNode::opposite_opcode() const {
1606+
if (Opcode() == max_opcode()) {
1607+
return min_opcode();
1608+
} else {
1609+
assert(Opcode() == min_opcode(), "Caller should be either %s or %s, but is %s", NodeClassNames[max_opcode()], NodeClassNames[min_opcode()], NodeClassNames[Opcode()]);
1610+
return max_opcode();
1611+
}
1612+
}
1613+
1614+
// Given a redundant structure such as Max/Min(A, Max/Min(B, C)) where A == B or A == C, return the useful part of the structure.
1615+
// 'operation' is the node expected to be the inner 'Max/Min(B, C)', and 'operand' is the node expected to be the 'A' operand of the outer node.
1616+
Node* MaxNode::find_identity_operation(Node* operation, Node* operand) {
1617+
if (operation->Opcode() == Opcode() || operation->Opcode() == opposite_opcode()) {
1618+
Node* n1 = operation->in(1);
1619+
Node* n2 = operation->in(2);
1620+
1621+
// Given Op(A, Op(B, C)), see if either A == B or A == C is true.
1622+
if (n1 == operand || n2 == operand) {
1623+
// If the operations are the same return the inner operation, as Max(A, Max(A, B)) == Max(A, B).
1624+
if (operation->Opcode() == Opcode()) {
1625+
return operation;
1626+
}
1627+
1628+
// If the operations are different return the operand 'A', as Max(A, Min(A, B)) == A if the value isn't floating point.
1629+
// With floating point values, the identity doesn't hold if B == NaN.
1630+
const Type* type = bottom_type();
1631+
if (type->isa_int() || type->isa_long()) {
1632+
return operand;
1633+
}
1634+
}
1635+
}
1636+
1637+
return nullptr;
1638+
}
1639+
15771640
Node* MaxNode::Identity(PhaseGVN* phase) {
15781641
if (in(1) == in(2)) {
15791642
return in(1);
15801643
}
15811644

1645+
Node* identity_1 = MaxNode::find_identity_operation(in(2), in(1));
1646+
if (identity_1 != nullptr) {
1647+
return identity_1;
1648+
}
1649+
1650+
Node* identity_2 = MaxNode::find_identity_operation(in(1), in(2));
1651+
if (identity_2 != nullptr) {
1652+
return identity_2;
1653+
}
1654+
15821655
return AddNode::Identity(phase);
15831656
}
15841657

src/hotspot/share/opto/addnode.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ class XorLNode : public AddNode {
262262

263263
//------------------------------MaxNode----------------------------------------
264264
// Max (or min) of 2 values. Included with the ADD nodes because it inherits
265-
// all the behavior of addition on a ring. Only new thing is that we allow
266-
// 2 equal inputs to be equal.
265+
// all the behavior of addition on a ring.
267266
class MaxNode : public AddNode {
268267
private:
269268
static Node* build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn);
@@ -277,6 +276,8 @@ class MaxNode : public AddNode {
277276
virtual int min_opcode() const = 0;
278277
Node* IdealI(PhaseGVN* phase, bool can_reshape);
279278
virtual Node* Identity(PhaseGVN* phase);
279+
Node* find_identity_operation(Node* operation, Node* operand);
280+
int opposite_opcode() const;
280281

281282
static Node* unsigned_max(Node* a, Node* b, const Type* t, PhaseGVN& gvn) {
282283
return build_min_max(a, b, true, true, t, gvn);
@@ -321,6 +322,7 @@ class MaxINode : public MaxNode {
321322
virtual uint ideal_reg() const { return Op_RegI; }
322323
int max_opcode() const { return Op_MaxI; }
323324
int min_opcode() const { return Op_MinI; }
325+
virtual Node* Identity(PhaseGVN* phase);
324326
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
325327
};
326328

@@ -337,7 +339,8 @@ class MinINode : public MaxNode {
337339
virtual uint ideal_reg() const { return Op_RegI; }
338340
int max_opcode() const { return Op_MaxI; }
339341
int min_opcode() const { return Op_MinI; }
340-
virtual Node *Ideal(PhaseGVN *phase, bool can_reshape);
342+
virtual Node* Identity(PhaseGVN* phase);
343+
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
341344
};
342345

343346
//------------------------------MaxLNode---------------------------------------

test/hotspot/jtreg/compiler/c2/irTests/MaxMinINodeIdealizationTests.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
33
* Copyright (c) 2022, Arm Limited. All rights reserved.
44
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
55
*
@@ -28,7 +28,7 @@
2828

2929
/*
3030
* @test
31-
* @bug 8290248 8312547
31+
* @bug 8290248 8312547 8341781
3232
* @summary Test that Ideal transformations of MaxINode and MinINode are
3333
* being performed as expected.
3434
* @library /test/lib /
@@ -46,10 +46,12 @@ public static void main(String[] args) {
4646
"testMax2LNoLeftAdd",
4747
"testMax3",
4848
"testMax4",
49+
"testMax5",
4950
"testMin1",
5051
"testMin2",
5152
"testMin3",
52-
"testMin4"})
53+
"testMin4",
54+
"testMin5"})
5355
public void runPositiveTests() {
5456
int a = RunInfo.getRandom().nextInt();
5557
int min = Integer.MIN_VALUE;
@@ -76,11 +78,13 @@ public void assertPositiveResult(int a) {
7678
Asserts.assertEQ(Math.max(a >> 1, ((a >> 1) + 11)) , testMax2LNoLeftAdd(a));
7779
Asserts.assertEQ(Math.max(a, a) , testMax3(a));
7880
Asserts.assertEQ(0 , testMax4(a));
81+
Asserts.assertEQ(8 , testMax5(a));
7982

8083
Asserts.assertEQ(Math.min(((a >> 1) + 100), Math.min(((a >> 1) + 150), 200)), testMin1(a));
8184
Asserts.assertEQ(Math.min(((a >> 1) + 10), ((a >> 1) + 11)) , testMin2(a));
8285
Asserts.assertEQ(Math.min(a, a) , testMin3(a));
8386
Asserts.assertEQ(0 , testMin4(a));
87+
Asserts.assertEQ(a & 7 , testMin5(a));
8488
}
8589

8690
// The transformations in test*1 and test*2 can happen only if the compiler has enough information
@@ -219,6 +223,18 @@ public int testMin4(int i) {
219223
return Math.min(i, 0) > 0 ? 1 : 0;
220224
}
221225

226+
@Test
227+
@IR(failOn = {IRNode.MAX_I})
228+
public int testMax5(int i) {
229+
return Math.max(i & 7, 8);
230+
}
231+
232+
@Test
233+
@IR(failOn = {IRNode.MIN_I})
234+
public int testMin5(int i) {
235+
return Math.min(i & 7, 8);
236+
}
237+
222238
@Run(test = {"testTwoLevelsDifferentXY",
223239
"testTwoLevelsNoLeftConstant",
224240
"testTwoLevelsNoRightConstant",
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
package compiler.c2.irTests;
24+
25+
import jdk.test.lib.Asserts;
26+
import compiler.lib.ir_framework.*;
27+
import java.util.Random;
28+
import jdk.test.lib.Utils;
29+
30+
/*
31+
* @test
32+
* @bug 8341781
33+
* @summary Test identities of MinNodes and MaxNodes.
34+
* @key randomness
35+
* @library /test/lib /
36+
* @run driver compiler.c2.irTests.TestMinMaxIdentities
37+
*/
38+
39+
public class TestMinMaxIdentities {
40+
private static final Random RANDOM = Utils.getRandomInstance();
41+
42+
public static void main(String[] args) {
43+
TestFramework.run();
44+
}
45+
46+
@Run(test = { "intMinMin", "intMinMax", "intMaxMin", "intMaxMax",
47+
"longMinMin", "longMinMax", "longMaxMin", "longMaxMax",
48+
"floatMinMin", "floatMaxMax", "doubleMinMin", "doubleMaxMax",
49+
"floatMinMax", "floatMaxMin", "doubleMinMax", "doubleMaxMin" })
50+
public void runMethod() {
51+
assertResult(10, 20, 10L, 20L, 10.f, 20.f, 10.0, 20.0);
52+
assertResult(20, 10, 20L, 10L, 20.f, 10.f, 20.0, 10.0);
53+
54+
assertResult(RANDOM.nextInt(), RANDOM.nextInt(), RANDOM.nextLong(), RANDOM.nextLong(), RANDOM.nextFloat(), RANDOM.nextFloat(), RANDOM.nextDouble(), RANDOM.nextDouble());
55+
assertResult(RANDOM.nextInt(), RANDOM.nextInt(), RANDOM.nextLong(), RANDOM.nextLong(), RANDOM.nextFloat(), RANDOM.nextFloat(), RANDOM.nextDouble(), RANDOM.nextDouble());
56+
57+
assertResult(Integer.MAX_VALUE, Integer.MIN_VALUE, Long.MAX_VALUE, Long.MIN_VALUE, Float.POSITIVE_INFINITY, Float.NaN, Double.POSITIVE_INFINITY, Double.NaN);
58+
assertResult(Integer.MIN_VALUE, Integer.MAX_VALUE, Long.MIN_VALUE, Long.MAX_VALUE, Float.NaN, Float.POSITIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY);
59+
}
60+
61+
@DontCompile
62+
public void assertResult(int iA, int iB, long lA, long lB, float fA, float fB, double dA, double dB) {
63+
Asserts.assertEQ(Math.min(iA, Math.min(iA, iB)), intMinMin(iA, iB));
64+
Asserts.assertEQ(Math.min(iA, Math.max(iA, iB)), intMinMax(iA, iB));
65+
Asserts.assertEQ(Math.max(iA, Math.min(iA, iB)), intMaxMin(iA, iB));
66+
Asserts.assertEQ(Math.max(iA, Math.max(iA, iB)), intMaxMax(iA, iB));
67+
68+
Asserts.assertEQ(Math.min(lA, Math.min(lA, lB)), longMinMin(lA, lB));
69+
Asserts.assertEQ(Math.min(lA, Math.max(lA, lB)), longMinMax(lA, lB));
70+
Asserts.assertEQ(Math.max(lA, Math.min(lA, lB)), longMaxMin(lA, lB));
71+
Asserts.assertEQ(Math.max(lA, Math.max(lA, lB)), longMaxMax(lA, lB));
72+
73+
Asserts.assertEQ(Math.min(fA, Math.min(fA, fB)), floatMinMin(fA, fB));
74+
Asserts.assertEQ(Math.max(fA, Math.max(fA, fB)), floatMaxMax(fA, fB));
75+
76+
Asserts.assertEQ(Math.min(dA, Math.min(dA, dB)), doubleMinMin(dA, dB));
77+
Asserts.assertEQ(Math.max(dA, Math.max(dA, dB)), doubleMaxMax(dA, dB));
78+
79+
// Due to NaN, these identities cannot be simplified.
80+
81+
Asserts.assertEQ(Math.min(fA, Math.max(fA, fB)), floatMinMax(fA, fB));
82+
Asserts.assertEQ(Math.max(fA, Math.min(fA, fB)), floatMaxMin(fA, fB));
83+
Asserts.assertEQ(Math.min(dA, Math.max(dA, dB)), doubleMinMax(dA, dB));
84+
Asserts.assertEQ(Math.max(dA, Math.min(dA, dB)), doubleMaxMin(dA, dB));
85+
}
86+
87+
// Integers
88+
89+
@Test
90+
@IR(counts = { IRNode.MIN_I, "1" })
91+
public int intMinMin(int a, int b) {
92+
return Math.min(a, Math.min(a, b));
93+
}
94+
95+
@Test
96+
@IR(failOn = { IRNode.MIN_I, IRNode.MAX_I })
97+
public int intMinMax(int a, int b) {
98+
return Math.min(a, Math.max(a, b));
99+
}
100+
101+
@Test
102+
@IR(failOn = { IRNode.MIN_I, IRNode.MAX_I })
103+
public int intMaxMin(int a, int b) {
104+
return Math.max(a, Math.min(a, b));
105+
}
106+
107+
@Test
108+
@IR(counts = { IRNode.MAX_I, "1" })
109+
public int intMaxMax(int a, int b) {
110+
return Math.max(a, Math.max(a, b));
111+
}
112+
113+
// Longs
114+
115+
// As Math.min/max(LL) is not intrinsified, it first needs to be transformed into CMoveL and then MinL/MaxL before
116+
// the identity can be matched. However, the outer min/max is not transformed into CMove because of the CMove cost model.
117+
// As JDK-8307513 adds intrinsics for the methods, the tests will be updated then.
118+
119+
@Test
120+
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MIN_L, "1" })
121+
public long longMinMin(long a, long b) {
122+
return Math.min(a, Math.min(a, b));
123+
}
124+
125+
@Test
126+
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MIN_L, "1" })
127+
public long longMinMax(long a, long b) {
128+
return Math.min(a, Math.max(a, b));
129+
}
130+
131+
@Test
132+
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MAX_L, "1" })
133+
public long longMaxMin(long a, long b) {
134+
return Math.max(a, Math.min(a, b));
135+
}
136+
137+
@Test
138+
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MAX_L, "1" })
139+
public long longMaxMax(long a, long b) {
140+
return Math.max(a, Math.max(a, b));
141+
}
142+
143+
// Floats
144+
145+
@Test
146+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1" })
147+
public float floatMinMin(float a, float b) {
148+
return Math.min(a, Math.min(a, b));
149+
}
150+
151+
@Test
152+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MAX_F, "1" })
153+
public float floatMaxMax(float a, float b) {
154+
return Math.max(a, Math.max(a, b));
155+
}
156+
157+
// Doubles
158+
159+
@Test
160+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1" })
161+
public double doubleMinMin(double a, double b) {
162+
return Math.min(a, Math.min(a, b));
163+
}
164+
165+
@Test
166+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MAX_D, "1" })
167+
public double doubleMaxMax(double a, double b) {
168+
return Math.max(a, Math.max(a, b));
169+
}
170+
171+
// Float and double identities that cannot be simplified due to NaN
172+
173+
@Test
174+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1", IRNode.MAX_F, "1" })
175+
public float floatMinMax(float a, float b) {
176+
return Math.min(a, Math.max(a, b));
177+
}
178+
179+
@Test
180+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1", IRNode.MAX_F, "1" })
181+
public float floatMaxMin(float a, float b) {
182+
return Math.max(a, Math.min(a, b));
183+
}
184+
185+
@Test
186+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1", IRNode.MAX_D, "1" })
187+
public double doubleMinMax(double a, double b) {
188+
return Math.min(a, Math.max(a, b));
189+
}
190+
191+
@Test
192+
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1", IRNode.MAX_D, "1" })
193+
public double doubleMaxMin(double a, double b) {
194+
return Math.max(a, Math.min(a, b));
195+
}
196+
}

0 commit comments

Comments
 (0)