Skip to content

Commit dc55a7f

Browse files
committed
8302202: Incorrect desugaring of null-allowed nested patterns
Reviewed-by: vromero
1 parent c4ffe4b commit dc55a7f

File tree

2 files changed

+180
-8
lines changed

2 files changed

+180
-8
lines changed

src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java

+17-8
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ public void visitTypeTest(JCInstanceOf tree) {
232232
}
233233

234234
Type principalType = types.erasure(TreeInfo.primaryPatternType((pattern)));
235-
JCExpression resultExpression= (JCExpression) this.<JCTree>translate(pattern);
236-
if (!tree.allowNull || !types.isSubtype(currentValue.type, principalType)) {
235+
JCExpression resultExpression = (JCExpression) this.<JCTree>translate(pattern);
236+
if (!tree.allowNull && !principalType.isPrimitive()) {
237237
resultExpression =
238238
makeBinary(Tag.AND,
239239
makeTypeTest(make.Ident(currentValue), make.Type(principalType)),
@@ -330,7 +330,8 @@ private UnrolledRecordPattern unrollRecordPattern(JCRecordPattern recordPattern)
330330
allowNull = false;
331331
} else {
332332
nestedBinding = (JCBindingPattern) nestedPattern;
333-
allowNull = true;
333+
allowNull = types.isSubtype(componentType,
334+
types.boxedTypeOrType(types.erasure(nestedBinding.type)));
334335
}
335336
JCMethodInvocation componentAccessor =
336337
make.App(make.Select(convert(make.Ident(recordBinding), recordBinding.type), //TODO - cast needed????
@@ -710,6 +711,7 @@ public void resolve(VarSymbol commonBinding,
710711
"commonNestedExpression: " + commonNestedExpression +
711712
"commonNestedBinding: " + commonNestedBinding);
712713
ListBuffer<JCCase> nestedCases = new ListBuffer<>();
714+
JCExpression lastGuard = null;
713715

714716
for(List<JCCase> accList = accummulator.toList(); accList.nonEmpty(); accList = accList.tail) {
715717
var accummulated = accList.head;
@@ -734,8 +736,6 @@ public void resolve(VarSymbol commonBinding,
734736
JCBindingPattern binding = (JCBindingPattern) instanceofCheck.pattern;
735737
hasUnconditional =
736738
instanceofCheck.allowNull &&
737-
types.isSubtype(commonNestedExpression.type,
738-
types.boxedTypeOrType(types.erasure(binding.type))) &&
739739
accList.tail.isEmpty();
740740
List<JCCaseLabel> newLabel;
741741
if (hasUnconditional) {
@@ -746,13 +746,16 @@ public void resolve(VarSymbol commonBinding,
746746
}
747747
appendBreakIfNeeded(currentSwitch, accummulated);
748748
nestedCases.add(make.Case(CaseKind.STATEMENT, newLabel, accummulated.stats, null));
749+
lastGuard = newGuard;
749750
}
750-
if (!hasUnconditional) {
751+
if (lastGuard != null || !hasUnconditional) {
751752
JCContinue continueSwitch = make.Continue(null);
752753
continueSwitch.target = currentSwitch;
753754
nestedCases.add(make.Case(CaseKind.STATEMENT,
754-
List.of(make.ConstantCaseLabel(makeNull()),
755-
make.DefaultCaseLabel()),
755+
hasUnconditional
756+
? List.of(make.DefaultCaseLabel())
757+
: List.of(make.ConstantCaseLabel(makeNull()),
758+
make.DefaultCaseLabel()),
756759
List.of(continueSwitch),
757760
null));
758761
}
@@ -774,9 +777,11 @@ public void resolve(VarSymbol commonBinding,
774777
VarSymbol commonBinding = null;
775778
JCExpression commonNestedExpression = null;
776779
VarSymbol commonNestedBinding = null;
780+
boolean previousNullable = false;
777781

778782
for (List<JCCase> c = inputCases; c.nonEmpty(); c = c.tail) {
779783
VarSymbol currentBinding = null;
784+
boolean currentNullable = false;
780785
JCExpression currentNestedExpression = null;
781786
VarSymbol currentNestedBinding = null;
782787

@@ -786,11 +791,13 @@ public void resolve(VarSymbol commonBinding,
786791
binOp.lhs instanceof JCInstanceOf instanceofCheck &&
787792
instanceofCheck.pattern instanceof JCBindingPattern binding) {
788793
currentBinding = ((JCBindingPattern) patternLabel.pat).var.sym;
794+
currentNullable = instanceofCheck.allowNull;
789795
currentNestedExpression = instanceofCheck.expr;
790796
currentNestedBinding = binding.var.sym;
791797
} else if (patternLabel.guard instanceof JCInstanceOf instanceofCheck &&
792798
instanceofCheck.pattern instanceof JCBindingPattern binding) {
793799
currentBinding = ((JCBindingPattern) patternLabel.pat).var.sym;
800+
currentNullable = instanceofCheck.allowNull;
794801
currentNestedExpression = instanceofCheck.expr;
795802
currentNestedBinding = binding.var.sym;
796803
}
@@ -806,6 +813,7 @@ public void resolve(VarSymbol commonBinding,
806813
}
807814
} else if (currentBinding != null &&
808815
commonBinding.type.tsym == currentBinding.type.tsym &&
816+
!previousNullable &&
809817
new TreeDiffer(List.of(commonBinding), List.of(currentBinding))
810818
.scan(commonNestedExpression, currentNestedExpression)) {
811819
accummulator.add(c.head);
@@ -820,6 +828,7 @@ public void resolve(VarSymbol commonBinding,
820828
commonNestedExpression = currentNestedExpression;
821829
commonNestedBinding = currentNestedBinding;
822830
}
831+
previousNullable = currentNullable;
823832
}
824833
resolveAccummulator.resolve(commonBinding, commonNestedExpression, commonNestedBinding);
825834
return result.toList();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test
26+
* @bug 8302202
27+
* @summary Testing record patterns with null components
28+
* @enablePreview
29+
* @compile NullsInDeconstructionPatterns2.java
30+
* @run main NullsInDeconstructionPatterns2
31+
*/
32+
33+
import java.util.Objects;
34+
import java.util.function.Function;
35+
36+
public class NullsInDeconstructionPatterns2 {
37+
38+
public static void main(String[] args) {
39+
new NullsInDeconstructionPatterns2().run();
40+
}
41+
42+
private void run() {
43+
run1(this::test1a);
44+
run1(this::test1b);
45+
run2(this::test2a);
46+
run2(this::test2b);
47+
run3(this::test3a);
48+
run3(this::test3b);
49+
run4();
50+
}
51+
52+
private void run1(Function<Object, String> method) {
53+
assertEquals("R1(null)", method.apply(new R1(null)));
54+
assertEquals("R1(!null)", method.apply(new R1("")));
55+
}
56+
57+
private void run2(Function<Object, String> method) {
58+
assertEquals("R2(null, null)", method.apply(new R2(null, null)));
59+
assertEquals("R2(!null, null)", method.apply(new R2("", null)));
60+
assertEquals("R2(null, !null)", method.apply(new R2(null, "")));
61+
assertEquals("R2(!null, !null)", method.apply(new R2("", "")));
62+
}
63+
64+
private void run3(Function<Object, String> method) {
65+
assertEquals("R3(null, null, null)", method.apply(new R3(null, null, null)));
66+
assertEquals("R3(!null, null, null)", method.apply(new R3("", null, null)));
67+
assertEquals("R3(null, !null, null)", method.apply(new R3(null, "", null)));
68+
assertEquals("R3(!null, !null, null)", method.apply(new R3("", "", null)));
69+
assertEquals("R3(null, null, !null)", method.apply(new R3(null, null, "")));
70+
assertEquals("R3(!null, null, !null)", method.apply(new R3("", null, "")));
71+
assertEquals("R3(null, !null, !null)", method.apply(new R3(null, "", "")));
72+
assertEquals("R3(!null, !null, !null)", method.apply(new R3("", "", "")));
73+
}
74+
75+
private void run4() {
76+
assertEquals("integer", test4(new R1(0)));
77+
assertEquals("empty", test4(new R1("")));
78+
assertEquals("default", test4(new R1("a")));
79+
}
80+
private String test1a(Object i) {
81+
return switch (i) {
82+
case R1(Object o) when o == null -> "R1(null)";
83+
case R1(Object o) when o != null -> "R1(!null)";
84+
default -> "default";
85+
};
86+
}
87+
88+
private String test1b(Object i) {
89+
return switch (i) {
90+
case R1(Object o) when o == null -> "R1(null)";
91+
case R1(Object o) -> "R1(!null)";
92+
default -> "default";
93+
};
94+
}
95+
96+
private String test2a(Object i) {
97+
return switch (i) {
98+
case R2(Object o1, Object o2) when o1 == null && o2 == null -> "R2(null, null)";
99+
case R2(Object o1, Object o2) when o1 != null && o2 == null -> "R2(!null, null)";
100+
case R2(Object o1, Object o2) when o1 == null && o2 != null -> "R2(null, !null)";
101+
case R2(Object o1, Object o2) when o1 != null && o2 != null -> "R2(!null, !null)";
102+
default -> "default";
103+
};
104+
}
105+
106+
private String test2b(Object i) {
107+
return switch (i) {
108+
case R2(Object o1, Object o2) when o1 == null && o2 == null -> "R2(null, null)";
109+
case R2(Object o1, Object o2) when o1 != null && o2 == null -> "R2(!null, null)";
110+
case R2(Object o1, Object o2) when o1 == null && o2 != null -> "R2(null, !null)";
111+
case R2(Object o1, Object o2) -> "R2(!null, !null)";
112+
default -> "default";
113+
};
114+
}
115+
116+
private String test3a(Object i) {
117+
return switch (i) {
118+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 == null -> "R3(null, null, null)";
119+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 == null -> "R3(!null, null, null)";
120+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 == null -> "R3(null, !null, null)";
121+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 == null -> "R3(!null, !null, null)";
122+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 != null -> "R3(null, null, !null)";
123+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 != null -> "R3(!null, null, !null)";
124+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 != null -> "R3(null, !null, !null)";
125+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 != null -> "R3(!null, !null, !null)";
126+
default -> "default";
127+
};
128+
}
129+
130+
private String test3b(Object i) {
131+
return switch (i) {
132+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 == null -> "R3(null, null, null)";
133+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 == null -> "R3(!null, null, null)";
134+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 == null -> "R3(null, !null, null)";
135+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 == null -> "R3(!null, !null, null)";
136+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 != null -> "R3(null, null, !null)";
137+
case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 != null -> "R3(!null, null, !null)";
138+
case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 != null -> "R3(null, !null, !null)";
139+
case R3(Object o1, Object o2, Object o3) -> "R3(!null, !null, !null)";
140+
default -> "default";
141+
};
142+
}
143+
144+
private String test4(Object i) {
145+
return switch (i) {
146+
case R1(Integer o) -> "integer";
147+
case R1(Object o) when o.toString().isEmpty() -> "empty";
148+
default -> "default";
149+
};
150+
}
151+
152+
private static void assertEquals(String expected, String actual) {
153+
if (!Objects.equals(expected, actual)) {
154+
throw new AssertionError("Unexpected result, expected: " + expected + "," +
155+
" actual: " + actual);
156+
}
157+
}
158+
159+
record R1(Object o) {}
160+
record R2(Object o1, Object o2) {}
161+
record R3(Object o1, Object o2, Object o3) {}
162+
163+
}

0 commit comments

Comments
 (0)