Skip to content

Commit

Permalink
Support for multiple patterns in a case statement (JEP 456)
Browse files Browse the repository at this point in the history
Co-authored-by: David Thompson <davthomp@redhat.com>
Signed-off-by: David Thompson <davthomp@redhat.com>
  • Loading branch information
srikanth-sankaran and datho7561 committed Mar 5, 2024
1 parent 9e5215a commit cb6774f
Show file tree
Hide file tree
Showing 29 changed files with 1,701 additions and 360 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,11 @@ public interface IProblem {
*/
int IllegalRecordPattern = TypeRelated + 1941;

/**
* @since 3.37
*/
int NamedPatternVariablesDisallowedHere = Internal + 1942;


/**
* @since 3.35
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package org.eclipse.jdt.internal.compiler.ast;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

import org.eclipse.jdt.internal.compiler.ASTVisitor;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFileConstants;
Expand All @@ -37,67 +39,57 @@

public class CaseStatement extends Statement {

static final int CASE_CONSTANT = 1;
static final int CASE_PATTERN = 2;

public BranchLabel targetLabel;

// labels for guarded patterns
public BranchLabel falseLabel;
public BranchLabel trueLabel;

public Expression[] constantExpressions; // case with multiple expressions
public Expression[] constantExpressions; // case with multiple expressions - if you want a under-the-hood view, use peeledLabelExpressions()
public BranchLabel[] targetLabels; // for multiple expressions
public boolean isExpr = false;
/* package */ int patternIndex = -1; // points to first pattern var index [only one pattern variable allowed now - should be 0]

public CaseStatement(Expression constantExpression, int sourceEnd, int sourceStart) {
this(sourceEnd, sourceStart, constantExpression != null ? new Expression[] {constantExpression} : null);
}
public SwitchStatement swich; // owning switch
public int typeSwitchIndex; // for the first pattern among this.constantExpressions

public CaseStatement(int sourceEnd, int sourceStart, Expression[] constantExpressions) {
public CaseStatement(Expression[] constantExpressions, int sourceStart, int sourceEnd) {
this.constantExpressions = constantExpressions;
this.sourceEnd = sourceEnd;
this.sourceStart = sourceStart;
initPatterns();
this.sourceEnd = sourceEnd;
}

private void initPatterns() {
int l = this.constantExpressions == null ? 0 : this.constantExpressions.length;
for (int i = 0; i < l; ++i) {
Expression e = this.constantExpressions[i];
if (e instanceof Pattern) {
this.patternIndex = i;
break;
/** Provide an under-the-hood view of label expressions, peeling away any abstractions that package many expressions as one
* @return flattened array of label expressions
*/
public Expression [] peeledLabelExpressions() {
Expression [] constants = Expression.NO_EXPRESSIONS;
for (Expression e : this.constantExpressions) {
if (e instanceof Pattern p1) {
constants = Stream.concat(Arrays.stream(constants), Arrays.stream(p1.getAlternatives())).toArray(Expression[]::new);
} else {
constants = Stream.concat(Arrays.stream(constants), Stream.of(e)).toArray(Expression[]::new);
}
}
return constants;
}

@Override
public FlowInfo analyseCode(
BlockScope currentScope,
FlowContext flowContext,
FlowInfo flowInfo) {
if (this.constantExpressions != null) {
int nullPatternCount = 0;
for(int i=0; i < this.constantExpressions.length; i++) {
Expression e = this.constantExpressions[i];
for (LocalVariableBinding local : e.bindingsWhenTrue()) {
local.useFlag = LocalVariableBinding.USED; // these are structurally required even if not touched
}
nullPatternCount += e instanceof NullLiteral ? 1 : 0;
if (i > 0 && (e instanceof Pattern)) {
if (!(i == nullPatternCount && e instanceof TypePattern))
currentScope.problemReporter().IllegalFallThroughToPattern(e);
}
flowInfo = analyseConstantExpression(currentScope, flowContext, flowInfo, e);
if (nullPatternCount > 0 && e instanceof TypePattern) {
LocalVariableBinding binding = ((TypePattern) e).local.binding;
if (binding != null)
flowInfo.markNullStatus(binding, FlowInfo.POTENTIALLY_NULL);
}
public FlowInfo analyseCode(BlockScope currentScope, FlowContext flowContext, FlowInfo flowInfo) {

int nullPatternCount = 0;
for (int i = 0, length = this.constantExpressions.length; i < length; i++) {
Expression e = this.constantExpressions[i];
for (LocalVariableBinding local : e.bindingsWhenTrue()) {
local.useFlag = LocalVariableBinding.USED; // these are structurally required even if not touched
}
nullPatternCount += e instanceof NullLiteral ? 1 : 0;
if (i > 0 && (e instanceof Pattern) && !JavaFeature.UNNAMMED_PATTERNS_AND_VARS.isSupported(currentScope.compilerOptions().sourceLevel, currentScope.compilerOptions().enablePreviewFeatures)) {
if (!(i == nullPatternCount && e instanceof TypePattern))
currentScope.problemReporter().IllegalFallThroughToPattern(e);
}
flowInfo = analyseConstantExpression(currentScope, flowContext, flowInfo, e);
if (nullPatternCount > 0 && e instanceof TypePattern) {
LocalVariableBinding binding = ((TypePattern) e).local.binding;
if (binding != null)
flowInfo.markNullStatus(binding, FlowInfo.POTENTIALLY_NULL);
}
}

return flowInfo;
}
private FlowInfo analyseConstantExpression(
Expand Down Expand Up @@ -125,7 +117,7 @@ private FlowInfo analyseConstantExpression(
@Override
public StringBuilder printStatement(int tab, StringBuilder output) {
printIndent(tab, output);
if (this.constantExpressions == null) {
if (this.constantExpressions == Expression.NO_EXPRESSIONS) {
output.append("default "); //$NON-NLS-1$
output.append(this.isExpr ? "->" : ":"); //$NON-NLS-1$ //$NON-NLS-2$
} else {
Expand Down Expand Up @@ -155,24 +147,38 @@ public void generateCode(BlockScope currentScope, CodeStream codeStream) {
}
if (this.targetLabel != null)
this.targetLabel.place();
casePatternExpressionGenerateCode(currentScope, codeStream);
codeStream.recordPositionsFrom(pc, this.sourceStart);
}

private void casePatternExpressionGenerateCode(BlockScope currentScope, CodeStream codeStream) {
if (this.patternIndex != -1) {
Pattern pattern = ((Pattern) this.constantExpressions[this.patternIndex]);
if (containsPatternVariable()) {
this.trueLabel = new BranchLabel(codeStream);
this.falseLabel = new BranchLabel(codeStream);
LocalVariableBinding local = currentScope.findVariable(SwitchStatement.SecretPatternVariableName, null);
codeStream.load(local);
pattern.generateCode(currentScope, codeStream, this.trueLabel, this.falseLabel);
// Srikanth, check this goto.
if (!(pattern instanceof GuardedPattern))
codeStream.goto_(this.trueLabel);
if (containsPatternVariable(true)) {

BranchLabel patternMatchLabel = new BranchLabel(codeStream);
BranchLabel matchFailLabel = new BranchLabel(codeStream);

Pattern pattern = (Pattern) this.constantExpressions[0];
codeStream.load(this.swich.dispatchPatternCopy);
pattern.generateCode(currentScope, codeStream, patternMatchLabel, matchFailLabel);
codeStream.goto_(patternMatchLabel);
matchFailLabel.place();

if (pattern.matchFailurePossible()) {
/* We are generating a "thunk"/"trampoline" of sorts now, that flow analysis has no clue about.
We need to manage the live variables manually. Pattern bindings are not definitely
assigned here as we are in the else region.
*/
final LocalVariableBinding[] bindingsWhenTrue = pattern.bindingsWhenTrue();
Stream.of(bindingsWhenTrue).forEach(v->v.recordInitializationEndPC(codeStream.position));
int caseIndex = this.typeSwitchIndex + pattern.getAlternatives().length;
codeStream.loadInt(this.swich.nullProcessed ? caseIndex - 1 : caseIndex);
codeStream.store(this.swich.restartIndexLocal, false);
codeStream.goto_(this.swich.switchPatternRestartTarget);
Stream.of(bindingsWhenTrue).forEach(v->v.recordInitializationStartPC(codeStream.position));
}
patternMatchLabel.place();
} else {
if (this.swich.containsNull) {
this.swich.nullProcessed |= true;
}
}
codeStream.recordPositionsFrom(pc, this.sourceStart);
}

/**
Expand Down Expand Up @@ -249,10 +255,6 @@ private Expression getFirstValidExpression(BlockScope scope, SwitchStatement swi
if (e instanceof Pattern) {
scope.problemReporter().validateJavaFeatureSupport(JavaFeature.PATTERN_MATCHING_IN_SWITCH,
e.sourceStart, e.sourceEnd);
if (this.constantExpressions.length > 1) {
scope.problemReporter().illegalCaseConstantCombination(e);
return e;
}
} else if (e instanceof NullLiteral) {
scope.problemReporter().validateJavaFeatureSupport(JavaFeature.PATTERN_MATCHING_IN_SWITCH,
e.sourceStart, e.sourceEnd);
Expand Down Expand Up @@ -288,9 +290,9 @@ private Expression getFirstValidExpression(BlockScope scope, SwitchStatement swi
* Returns the constant intValue or ordinal for enum constants. If constant is NotAConstant, then answers Float.MIN_VALUE
*/
public ResolvedCase[] resolveCase(BlockScope scope, TypeBinding switchExpressionType, SwitchStatement switchStatement) {
// switchExpressionType maybe null in error case
this.swich = switchStatement;
scope.enclosingCase = this; // record entering in a switch case block
if (this.constantExpressions == null) {
if (this.constantExpressions == Expression.NO_EXPRESSIONS) {
flagDuplicateDefault(scope, switchStatement, this);
return ResolvedCase.UnresolvedCase;
}
Expand All @@ -315,22 +317,27 @@ public ResolvedCase[] resolveCase(BlockScope scope, TypeBinding switchExpression

if (caseType == null || switchExpressionType == null)
return ResolvedCase.UnresolvedCase;
// Avoid further resolution and secondary errors

if (caseType.isValidBinding()) {
Constant con = resolveConstantExpression(scope, caseType, switchExpressionType, switchStatement, e, cases);
if (con != Constant.NotAConstant) {
int index = this == switchStatement.nullCase && e instanceof NullLiteral ?
-1 : switchStatement.constantIndex++;
cases.add(new ResolvedCase(con, e, caseType, index, false));
if (e instanceof Pattern) {
for (Pattern p : ((Pattern) e).getAlternatives()) {
Constant con = resolveConstantExpression(scope, p.resolvedType, switchExpressionType, switchStatement, p);
if (con != Constant.NotAConstant) {
int index = switchStatement.constantIndex++;
cases.add(new ResolvedCase(con, p, p.resolvedType, index, false));
}
}
} else {
Constant con = resolveConstantExpression(scope, caseType, switchExpressionType, switchStatement, e, cases);
if (con != Constant.NotAConstant) {
int index = this == switchStatement.nullCase && e instanceof NullLiteral ?
-1 : switchStatement.constantIndex++;
cases.add(new ResolvedCase(con, e, caseType, index, false));
}
}
}
}
this.resolveWithBindings(this.bindingsWhenTrue(), scope);
if (cases.size() > 0) {
return cases.toArray(new ResolvedCase[cases.size()]);
}

return ResolvedCase.UnresolvedCase;
return cases.toArray(new ResolvedCase[cases.size()]);
}

private void flagDuplicateDefault(BlockScope scope, SwitchStatement switchStatement, ASTNode node) {
Expand All @@ -344,16 +351,16 @@ private void flagDuplicateDefault(BlockScope scope, SwitchStatement switchStatem
scope.problemReporter().illegalTotalPatternWithDefault(this);
}
}

@Override
public LocalVariableBinding[] bindingsWhenTrue() {
LocalVariableBinding [] variables = NO_VARIABLES;
if (this.constantExpressions != null) {
for (Expression e : this.constantExpressions) {
variables = LocalVariableBinding.merge(variables, e.bindingsWhenTrue());
}
for (Expression e : this.constantExpressions) {
variables = LocalVariableBinding.merge(variables, e.bindingsWhenTrue());
}
return variables;
}

public Constant resolveConstantExpression(BlockScope scope,
TypeBinding caseType,
TypeBinding switchType,
Expand Down Expand Up @@ -382,7 +389,8 @@ public Constant resolveConstantExpression(BlockScope scope,
return Constant.NotAConstant;
}
}
}

}
}
boolean boxing = !patternSwitchAllowed ||
switchStatement.isAllowedType(switchType);
Expand Down Expand Up @@ -466,7 +474,7 @@ private Constant resolveConstantExpression(BlockScope scope,
return IntConstant.fromValue(-1);
}
switchStatement.switchBits |= SwitchStatement.Exhaustive;
if (e.isAlwaysTrue()) {
if (e.isUnconditional(expressionType)) {
switchStatement.switchBits |= SwitchStatement.TotalPattern;
if (switchStatement.defaultCase != null && !(e instanceof RecordPattern))
scope.problemReporter().illegalTotalPatternWithDefault(this);
Expand All @@ -483,13 +491,10 @@ private Constant resolveConstantExpression(BlockScope scope,
@Override
public void traverse(ASTVisitor visitor, BlockScope blockScope) {
if (visitor.visit(this, blockScope)) {
if (this.constantExpressions != null) {
for (Expression e : this.constantExpressions) {
e.traverse(visitor, blockScope);
}
for (Expression e : this.constantExpressions) {
e.traverse(visitor, blockScope);
}

}
visitor.endVisit(this, blockScope);
}
}
}

0 comments on commit cb6774f

Please sign in to comment.