Skip to content

Commit

Permalink
Merge pull request #42708 from nipunayf/fix-42331
Browse files Browse the repository at this point in the history
Provide code actions to fix the invalid access of mutable storage in an isolated function
  • Loading branch information
nipunayf committed May 17, 2024
2 parents 5ddcc9f + 98b8c03 commit 24efc8e
Show file tree
Hide file tree
Showing 28 changed files with 589 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.Token;
import io.ballerina.compiler.syntax.tree.TypedBindingPatternNode;
import io.ballerina.projects.Project;
import io.ballerina.tools.diagnostics.Diagnostic;
import io.ballerina.tools.text.LineRange;
import org.ballerinalang.annotation.JavaSPIService;
import org.ballerinalang.langserver.codeaction.CodeActionNodeValidator;
import org.ballerinalang.langserver.codeaction.CodeActionUtil;
Expand Down Expand Up @@ -59,7 +60,7 @@ public class AddIsolatedQualifierCodeAction implements DiagnosticBasedCodeAction
private static final String DIAGNOSTIC_CODE_3961 = "BCE3961";
private static final String ANONYMOUS_FUNCTION_EXPRESSION = "Anonymous function expression";
private static final Set<String> DIAGNOSTIC_CODES =
Set.of("BCE3946", "BCE3947", "BCE3950", DIAGNOSTIC_CODE_3961);
Set.of("BCE3943", "BCE3946", "BCE3947", "BCE3950", DIAGNOSTIC_CODE_3961);

@Override
public boolean validate(Diagnostic diagnostic,
Expand All @@ -81,7 +82,7 @@ public List<CodeAction> getCodeActions(Diagnostic diagnostic,
if (nonTerminalNode.kind() == SyntaxKind.EXPLICIT_ANONYMOUS_FUNCTION_EXPRESSION) {
ExplicitAnonymousFunctionExpressionNode functionExpressionNode =
(ExplicitAnonymousFunctionExpressionNode) nonTerminalNode;
return getCodeAction(functionExpressionNode.functionKeyword(), ANONYMOUS_FUNCTION_EXPRESSION,
return getCodeAction(functionExpressionNode.functionKeyword().lineRange(), ANONYMOUS_FUNCTION_EXPRESSION,
context.fileUri());
}

Expand All @@ -91,9 +92,13 @@ public List<CodeAction> getCodeActions(Diagnostic diagnostic,
return Collections.emptyList();
}

// Obtain the symbol of the referred function
Optional<Symbol> symbol = getReferredSymbol(context, nonTerminalNode);
if (symbol.isEmpty() || symbol.get().getModule().isEmpty()) {
// Obtain the symbol of the referred symbol
Optional<Symbol> optSymbol = getReferredSymbol(context, nonTerminalNode);
if (optSymbol.isEmpty()) {
return Collections.emptyList();
}
Symbol symbol = optSymbol.get();
if (symbol.getModule().isEmpty()) {
return Collections.emptyList();
}

Expand All @@ -104,21 +109,39 @@ public List<CodeAction> getCodeActions(Diagnostic diagnostic,
}

// Obtain the file path of the referred symbol
Optional<Path> filePath = PathUtil.getFilePathForSymbol(symbol.get(), project.get(), context);
if (filePath.isEmpty() || context.workspace().syntaxTree(filePath.get()).isEmpty()) {
Optional<Path> optFilePath = PathUtil.getFilePathForSymbol(symbol, project.get(), context);
if (optFilePath.isEmpty()) {
return Collections.emptyList();
}
Path filePath = optFilePath.get();
if (context.workspace().syntaxTree(filePath).isEmpty()) {
return Collections.emptyList();
}

// Obtain the node of the referred symbol
Optional<NonTerminalNode> node = CommonUtil.findNode(symbol.get(),
context.workspace().syntaxTree(filePath.get()).get());
if (node.isEmpty() || isUnsupportedSyntaxKind(node.get().kind())) {
Optional<NonTerminalNode> optNode = CommonUtil.findNode(symbol, context.workspace().syntaxTree(filePath).get());
if (optNode.isEmpty()) {
return Collections.emptyList();
}
FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) node.get();

return getCodeAction(functionDefinitionNode.functionKeyword(), symbol.get().getName().orElse(""),
filePath.get().toUri().toString());
NonTerminalNode node = optNode.get();
String symbolName = symbol.getName().orElse("");
String filePathString = filePath.toUri().toString();

return switch (node.kind()) {
case FUNCTION_DEFINITION, OBJECT_METHOD_DEFINITION -> {
FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) node;
yield getCodeAction(functionDefinitionNode.functionKeyword().lineRange(), symbolName, filePathString);
}
case CAPTURE_BINDING_PATTERN -> {
NonTerminalNode parentNode = node.parent();
if (parentNode.kind() != SyntaxKind.TYPED_BINDING_PATTERN) {
yield Collections.emptyList();
}
TypedBindingPatternNode typeNode = (TypedBindingPatternNode) parentNode;
yield getCodeAction(typeNode.typeDescriptor().lineRange(), symbolName, filePathString);
}
default -> Collections.emptyList();
};
}

private static Optional<Symbol> getReferredSymbol(CodeActionContext context, NonTerminalNode node) {
Expand All @@ -138,19 +161,15 @@ private static Optional<Symbol> getReferredSymbol(CodeActionContext context, Non
return context.currentSemanticModel().flatMap(semanticModel -> semanticModel.symbol(node));
}

private static List<CodeAction> getCodeAction(Token functionKeyword, String expressionName, String filePath) {
Position position = PositionUtil.toPosition(functionKeyword.lineRange().startLine());
private static List<CodeAction> getCodeAction(LineRange lineRange, String expressionName, String filePath) {
Position position = PositionUtil.toPosition(lineRange.startLine());
String editText = SyntaxKind.ISOLATED_KEYWORD.stringValue() + " ";
TextEdit textEdit = new TextEdit(new Range(position, position), editText);
String commandTitle = String.format(CommandConstants.MAKE_FUNCTION_ISOLATE, expressionName);
return Collections.singletonList(
CodeActionUtil.createCodeAction(commandTitle, List.of(textEdit), filePath, CodeActionKind.QuickFix));
}

private static boolean isUnsupportedSyntaxKind(SyntaxKind kind) {
return kind != SyntaxKind.FUNCTION_DEFINITION && kind != SyntaxKind.OBJECT_METHOD_DEFINITION;
}

private static boolean hasMultipleDiagnostics(NonTerminalNode node, Diagnostic currentDiagnostic,
List<Diagnostic> diagnostics) {
return diagnostics.stream().anyMatch(diagnostic -> !currentDiagnostic.equals(diagnostic) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,27 @@

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ClassFieldSymbol;
import io.ballerina.compiler.api.symbols.ClassSymbol;
import io.ballerina.compiler.api.symbols.Qualifiable;
import io.ballerina.compiler.api.symbols.Qualifier;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.SymbolKind;
import io.ballerina.compiler.api.symbols.TypeDescKind;
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.VariableSymbol;
import io.ballerina.compiler.syntax.tree.ModulePartNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.ObjectFieldNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.TypedBindingPatternNode;
import io.ballerina.tools.diagnostics.Diagnostic;
import io.ballerina.tools.text.LinePosition;
import org.ballerinalang.annotation.JavaSPIService;
import org.ballerinalang.langserver.codeaction.CodeActionNodeValidator;
import org.ballerinalang.langserver.codeaction.CodeActionUtil;
import org.ballerinalang.langserver.common.constants.CommandConstants;
import org.ballerinalang.langserver.common.utils.CommonUtil;
import org.ballerinalang.langserver.common.utils.PositionUtil;
import org.ballerinalang.langserver.commons.CodeActionContext;
import org.ballerinalang.langserver.commons.codeaction.spi.DiagBasedPositionDetails;
Expand All @@ -48,6 +55,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
* Code Action for making a variable immutable. This will ensure that the given variable is both final and readonly.
Expand All @@ -58,68 +66,85 @@
public class MakeVariableImmutableCodeAction implements DiagnosticBasedCodeActionProvider {

private static final String NAME = "Make variable immutable";
private static final String DIAGNOSTIC_CODE = "BCE3956";
private static final Set<String> DIAGNOSTIC_CODES = Set.of("BCE3943", "BCE3956");

@Override
public boolean validate(Diagnostic diagnostic, DiagBasedPositionDetails positionDetails,
CodeActionContext context) {
return DIAGNOSTIC_CODE.equals(diagnostic.diagnosticInfo().code())
return DIAGNOSTIC_CODES.contains(diagnostic.diagnosticInfo().code())
&& CodeActionNodeValidator.validate(context.nodeAtRange());
}

@Override
public List<CodeAction> getCodeActions(Diagnostic diagnostic, DiagBasedPositionDetails positionDetails,
CodeActionContext context) {
NonTerminalNode cursorNode = positionDetails.matchedNode();

// The current implementation of the CA only supports object fields
if (cursorNode.kind() != SyntaxKind.OBJECT_FIELD) {
assert false : "This line is unreachable as the diagnostic is only generated for an object field.";
return Collections.emptyList();
}

ObjectFieldNode objectFieldNode = (ObjectFieldNode) cursorNode;
Node typeNode = objectFieldNode.typeName();
List<TextEdit> textEdits = new ArrayList<>();

// Check if the type is final
boolean isFinal = objectFieldNode.qualifierList().stream()
.anyMatch(token -> token.kind().equals(SyntaxKind.FINAL_KEYWORD));
if (!isFinal) {
textEdits.add(getFinalTextEdit(typeNode));
}

// Check if the type is readonly
TypeSymbol typeSymbol, readonlyType;
TypeSymbol readonlyType;
SymbolInfo symbolInfo;
try {
SemanticModel semanticModel = context.currentSemanticModel().orElseThrow();
readonlyType = semanticModel.types().READONLY;
Symbol symbol = semanticModel.symbol(cursorNode).orElseThrow();
typeSymbol = getTypeSymbol(symbol).orElseThrow();
ModulePartNode rootNode = context.currentSyntaxTree().orElseThrow().rootNode();
symbolInfo = getSymbolInfo(positionDetails.matchedNode(), semanticModel, rootNode).orElseThrow();
} catch (RuntimeException e) {
assert false : "This line is unreachable because the semantic model cannot be empty, and the type " +
"symbol does not contain errors.";
return Collections.emptyList();
}
boolean isReadonly = typeSymbol.subtypeOf(readonlyType);
if (!isReadonly) {
textEdits.addAll(getReadonlyTextEdits(typeNode, typeSymbol.typeKind() == TypeDescKind.UNION));

// Check if the type is readonly
boolean generateReadonly = !symbolInfo.skipReadonly() && !symbolInfo.typeSymbol().subtypeOf(readonlyType);
boolean generateFinal = !symbolInfo.isFinal();
if (generateFinal) {
textEdits.add(getFinalTextEdit(symbolInfo.typeNode()));
}
if (generateReadonly) {
textEdits.addAll(getReadonlyTextEdits(symbolInfo.typeNode(),
symbolInfo.typeSymbol().typeKind() == TypeDescKind.UNION));
}

// Generate and return the code action
return Collections.singletonList(CodeActionUtil.createCodeAction(
String.format(CommandConstants.MAKE_VARIABLE_IMMUTABLE, getTitleText(isFinal, isReadonly)),
String.format(CommandConstants.MAKE_VARIABLE_IMMUTABLE, getTitleText(generateFinal, generateReadonly)),
textEdits,
context.fileUri(),
CodeActionKind.QuickFix));
}

private static Optional<TypeSymbol> getTypeSymbol(Symbol symbol) {
if (symbol.kind() == SymbolKind.CLASS_FIELD) {
return Optional.of(((ClassFieldSymbol) symbol).typeDescriptor());
private static Optional<SymbolInfo> getSymbolInfo(Node cursorNode, SemanticModel semanticModel,
ModulePartNode rootNode) {
try {
Symbol symbol = semanticModel.symbol(cursorNode).orElseThrow();
if (symbol.kind() == SymbolKind.CLASS_FIELD) {
ClassFieldSymbol classFieldSymbol = (ClassFieldSymbol) symbol;
return Optional.of(new SymbolInfo(classFieldSymbol.typeDescriptor(),
((ObjectFieldNode) cursorNode).typeName(), classFieldSymbol, false));
}
VariableSymbol variableSymbol = (VariableSymbol) symbol;
TypedBindingPatternNode typeNode =
(TypedBindingPatternNode) CommonUtil.findNode(symbol, rootNode.syntaxTree()).orElseThrow().parent();

// Skip the readonly consideration for isolated and readonly classes.
boolean skipReadonly = false;
TypeSymbol typeSymbol = variableSymbol.typeDescriptor();
if (typeSymbol.typeKind() == TypeDescKind.TYPE_REFERENCE) {
Symbol definition = ((TypeReferenceTypeSymbol) typeSymbol).definition();
if (definition.kind() == SymbolKind.CLASS) {
boolean isSupportedClassType = ((ClassSymbol) definition).qualifiers().stream()
.anyMatch(qualifier -> qualifier == Qualifier.READONLY ||
qualifier == Qualifier.ISOLATED);
if (!isSupportedClassType) {
return Optional.empty();
}
skipReadonly = true;
}
}
return Optional.of(new SymbolInfo(typeSymbol, typeNode.typeDescriptor(), variableSymbol, skipReadonly));
} catch (RuntimeException e) {
assert false : "Unconsidered symbol type found";
return Optional.empty();
}
assert false : "Unconsidered symbol type found: " + symbol.kind();
return Optional.empty();
}

private static TextEdit getFinalTextEdit(Node typeNode) {
Expand Down Expand Up @@ -150,14 +175,14 @@ private static List<TextEdit> getReadonlyTextEdits(Node typeNode, boolean isUnio
return textEdits;
}

private static String getTitleText(boolean isFinal, boolean isReadonly) {
private static String getTitleText(boolean generateFinal, boolean generateReadonly) {
StringBuilder result = new StringBuilder();

if (!isFinal) {
if (generateFinal) {
result.append("'").append(SyntaxKind.FINAL_KEYWORD.stringValue()).append("'");
}

if (!isReadonly) {
if (generateReadonly) {
if (result.length() > 0) {
result.append(" and ");
}
Expand All @@ -171,4 +196,13 @@ private static String getTitleText(boolean isFinal, boolean isReadonly) {
public String getName() {
return NAME;
}

private record SymbolInfo(TypeSymbol typeSymbol, Node typeNode, boolean isFinal, boolean skipReadonly) {

public SymbolInfo(TypeSymbol typeSymbol, Node typeNode, Qualifiable qualifiable, boolean skipReadonly) {
this(typeSymbol, typeNode,
qualifiable.qualifiers().stream().anyMatch(qualifier -> qualifier.equals(Qualifier.FINAL)),
skipReadonly);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public Object[][] dataProvider() {
{"add_isolated_qualifier_config11.json"},
{"add_isolated_qualifier_config12.json"},
{"add_isolated_qualifier_config13.json"},
{"add_isolated_qualifier_config14.json"},
{"add_isolated_qualifier_config15.json"},
{"add_isolated_qualifier_config16.json"},
{"add_isolated_qualifier_config17.json"},
{"add_isolated_qualifier_config18.json"},
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ public Object[][] dataProvider() {
{"make_variable_immutable12.json"},
{"make_variable_immutable13.json"},
{"make_variable_immutable14.json"},
{"make_variable_immutable15.json"}
{"make_variable_immutable15.json"},
{"make_variable_immutable16.json"},
{"make_variable_immutable17.json"},
{"make_variable_immutable18.json"},
{"make_variable_immutable19.json"},
{"make_variable_immutable20.json"},
{"make_variable_immutable21.json"},
{"make_variable_immutable22.json"},
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
{
"range": {
"start": {
"line": 11,
"line": 15,
"character": 4
},
"end": {
"line": 11,
"line": 15,
"character": 4
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
{
"range": {
"start": {
"line": 11,
"line": 15,
"character": 4
},
"end": {
"line": 11,
"line": 15,
"character": 4
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
{
"range": {
"start": {
"line": 15,
"line": 19,
"character": 4
},
"end": {
"line": 15,
"line": 19,
"character": 4
}
},
Expand Down
Loading

0 comments on commit 24efc8e

Please sign in to comment.