Skip to content

Commit

Permalink
Merge pull request #41917 from heshanpadmasiri/fix/rec-store-optional
Browse files Browse the repository at this point in the history
Remove unnecessary casting for record field set
  • Loading branch information
LakshanWeerasinghe authored Apr 5, 2024
2 parents 3b1191c + 6551fb8 commit 4bbbae3
Show file tree
Hide file tree
Showing 11 changed files with 413 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@
import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariable;
import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariableRef;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.getMajorVersion;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField;
import static org.wso2.ballerinalang.compiler.util.Names.GENERATED_INIT_SUFFIX;
import static org.wso2.ballerinalang.compiler.util.Names.GEN_VAR_PREFIX;
import static org.wso2.ballerinalang.compiler.util.Names.IGNORE;
Expand Down Expand Up @@ -2476,17 +2477,31 @@ private void createSimpleVarDefStmt(BLangSimpleVariable simpleVariable, BLangBlo

@Override
public void visit(BLangAssignment assignNode) {
boolean fieldAccessLVExpr = assignNode.varRef.getKind() == NodeKind.FIELD_BASED_ACCESS_EXPR;
// We rewrite the varRef of the BLangAssignment to a IndexBasedAssignment if it is a FieldBasedAssignment.
// Therefore we must do the shouldWidenExpressionTypeWithNil check before that.
boolean addNilToCastingType = shouldWidenExpressionTypeWithNil(assignNode);
assignNode.varRef = rewriteExpr(assignNode.varRef);
assignNode.expr = rewriteExpr(assignNode.expr);
BType castingType = assignNode.varRef.getBType();
if (fieldAccessLVExpr) {
if (addNilToCastingType) {
castingType = types.addNilForNillableAccessType(castingType);
}
assignNode.expr = types.addConversionExprIfRequired(rewriteExpr(assignNode.expr), castingType);
result = assignNode;
}

private static boolean shouldWidenExpressionTypeWithNil(BLangAssignment assignNode) {
if (!assignNode.expr.getBType().isNullable() || !isAssignmentToOptionalField(assignNode)) {
return false;
}
// If we are assigning to an optional field we have a field based access on a record
BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) assignNode.varRef;
BRecordType recordType = (BRecordType) Types.getImpliedType(fieldAccessNode.expr.getBType());
BField field = recordType.fields.get(fieldAccessNode.field.value);
BType fieldType = Types.getImpliedType(field.getType());
return TypeTags.isSimpleBasicType(fieldType.tag);
}

@Override
public void visit(BLangTupleDestructure tupleDestructure) {
// case 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@
import org.wso2.ballerinalang.compiler.tree.expressions.BLangConstant;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangErrorVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangInvocation;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLambdaFunction;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLetExpression;
Expand Down Expand Up @@ -246,6 +245,7 @@
import static org.ballerinalang.model.tree.NodeKind.RECORD_LITERAL_EXPR;
import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CAPTURING_GROUP;
import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CHARACTER_CLASS;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField;

/**
* @since 0.94
Expand Down Expand Up @@ -2318,15 +2318,20 @@ public void visit(BLangAssignment assignNode, AnalyzerData data) {
validateFunctionVarRef(varRef, data);

checkInvalidTypeDef(varRef);
if (varRef.getKind() == NodeKind.FIELD_BASED_ACCESS_EXPR && data.expType.tag != TypeTags.SEMANTIC_ERROR) {
BLangFieldBasedAccess fieldBasedAccessVarRef = (BLangFieldBasedAccess) varRef;
int varRefTypeTag = Types.getImpliedType(fieldBasedAccessVarRef.expr.getBType()).tag;
if (varRefTypeTag == TypeTags.RECORD && Symbols.isOptional(fieldBasedAccessVarRef.symbol)) {
data.expType = types.addNilForNillableAccessType(data.expType);
}
BType actualExpectedType = null;
// For optional field assignments we add nil to the expected type before doing type checking in order to get
// the type in error messages correct. But we don't need an implicit conversion since desugar will add a
// cast if needed.
if (data.expType != symTable.semanticError && isAssignmentToOptionalField(assignNode)) {
actualExpectedType = data.expType;
data.expType = types.addNilForNillableAccessType(actualExpectedType);
}

data.typeChecker.checkExpr(assignNode.expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData);
BLangExpression expr = assignNode.expr;
data.typeChecker.checkExpr(expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData);
if (actualExpectedType != null && expr.impConversionExpr != null) {
data.typeChecker.resetImpConversionExpr(expr, expr.getBType(), actualExpectedType);
}

validateWorkerAnnAttachments(assignNode.expr, data);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6506,7 +6506,7 @@ protected void visitCheckAndCheckPanicExpr(BLangCheckedExpr checkedExpr, Analyze
data.resultType = types.checkType(checkedExpr, actualType, data.expType);
}

private void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) {
protected void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) {
expr.impConversionExpr = null;
types.setImplicitCastExpr(expr, actualType, targetType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@

import org.ballerinalang.compiler.CompilerOptionName;
import org.ballerinalang.model.elements.PackageID;
import org.ballerinalang.model.tree.NodeKind;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols;
import org.wso2.ballerinalang.compiler.semantics.model.types.BField;
import org.wso2.ballerinalang.compiler.semantics.model.types.BRecordType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.tree.BLangFunction;
import org.wso2.ballerinalang.compiler.tree.BLangNode;
import org.wso2.ballerinalang.compiler.tree.BLangSimpleVariable;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess;
import org.wso2.ballerinalang.compiler.tree.statements.BLangAssignment;

import java.util.List;

Expand Down Expand Up @@ -82,4 +90,17 @@ public static boolean isInParameterList(BSymbol symbol, List<BLangSimpleVariable
return false;
}

public static boolean isAssignmentToOptionalField(BLangAssignment assignNode) {
BLangNode varRef = assignNode.varRef;
if (varRef.getKind() != NodeKind.FIELD_BASED_ACCESS_EXPR) {
return false;
}
BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) varRef;
BType targetType = Types.getImpliedType(fieldAccessNode.expr.getBType());
if (targetType.tag != TypeTags.RECORD) {
return false;
}
BField field = ((BRecordType) targetType).fields.get(fieldAccessNode.field.value);
return field != null && Symbols.isOptional(field.symbol);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2023, WSO2 LLC. (http://www.wso2.org).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.ballerinalang.test.bir;

import org.ballerinalang.test.BCompileUtil;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import org.wso2.ballerinalang.compiler.bir.emit.BIREmitter;
import org.wso2.ballerinalang.compiler.bir.model.BIRNode;
import org.wso2.ballerinalang.compiler.util.CompilerContext;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

/**
* This class contains unit tests to validate various desugaring/optimizations on records produce the expected BIR.
*
* @since 2201.9.0
*/
public class RecordDesugarTest {

private BIREmitter birEmitter;
private BCompileUtil.BIRCompileResult result;

@BeforeClass
public void setup() {
birEmitter = BIREmitter.getInstance(new CompilerContext());
result = BCompileUtil.generateBIR("test-src/bir/record_desugar.bal");
}

@Test(description = "Test record field set")
public void testRecordFieldSet() {
List<String> functions = Arrays.asList("setRequiredField", "setNillableField", "setOptionalField");
result.getExpectedBIR().functions.stream().filter(function -> functions.contains(function.name.value))
.forEach(this::assertFunctions);
}

private void assertFunctions(BIRNode.BIRFunction function) {
String actual = BIREmitter.emitFunction(function, 0);
String expected = null;
try {
expected = readFile(function.name.value);
} catch (IOException e) {
Assert.fail("Failed to read the expected BIR file for function: " + function.name.value, e);
}
Assert.assertEquals(actual, expected);
}

private String readFile(String name) throws IOException {
// The files in the bir-dump folder are named with the function name and contain the expected bir dump for
// the function
Path filePath = Paths.get("src", "test", "resources", "test-src", "bir", "bir-dump", name).toAbsolutePath();
if (Files.exists(filePath)) {
StringBuilder contentBuilder = new StringBuilder();

Stream<String> stream = Files.lines(filePath, StandardCharsets.UTF_8);
stream.forEach(s -> contentBuilder.append(s).append("\n"));

return contentBuilder.toString().trim();
}
Assert.fail("Expected BIR file not found for function: " + name);
return null;
}

@AfterClass
public void tearDown() {
result = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ public Object[][] recordOptionalFieldAccessFunctions2() {
{ "testUnavailableFinalAccessInNestedAccess" },
{ "testAvailableFinalAccessInNestedAccess" },
{ "testUnavailableIntermediateAccessInNestedAccess" },
{ "testNilValuedFinalAccessInNestedAccess" }
{ "testNilValuedFinalAccessInNestedAccess" },
{ "testSubtypeAssignment" },
{ "testUnionAssignment" },
{ "testNullableAssignment" }
};
}

Expand Down Expand Up @@ -167,6 +170,20 @@ public void testOptionalFieldAccessOnMethodCall() {
BRunUtil.invoke(result, "testOptionalFieldAccessOnMethodCall");
}

@Test(dataProvider = "optionalFieldRemovalFunctions")
public void testOptionalFieldRemoval(String function) {
BRunUtil.invoke(result, function);
}

@DataProvider(name = "optionalFieldRemovalFunctions")
public Object[][] optionalFieldRemovalFunctions() {
return new Object[][]{
{"testOptionalFieldRemovalBasicType"},
{"testOptionalFieldRemovalIndirect"},
{"testOptionalFieldRemovalComplex"}
};
}

@Test
public void testNestedOptionalFieldAccessOnIntersectionTypes() {
BRunUtil.invoke(result, "testNestedOptionalFieldAccessOnIntersectionTypes");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
setNillableField function() -> () {
%0(RETURN) ();
%1(LOCAL) R2;
%2(TEMP) typeDesc<any|error>;
%4(TEMP) string;
%5(TEMP) int|();
%6(TEMP) int;
%12(TEMP) ();

bb0 {
%2 = newType R2;
%4 = ConstLoad x;
%6 = ConstLoad 1;
%5 = <int|()> %6;
%1 = NewMap %2{%4:%5};
%6 = ConstLoad 2;
%5 = <int|()> %6;
%4 = ConstLoad x;
%1[%4] = %5;
%12 = ConstLoad 0;
%5 = <int|()> %12;
%4 = ConstLoad x;
%1[%4] = %5;
%0 = ConstLoad 0;
GOTO bb1;
}
bb1 {
return;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
setOptionalField function() -> () {
%0(RETURN) ();
%1(LOCAL) R3;
%2(TEMP) typeDesc<any|error>;
%4(TEMP) int;
%6(TEMP) string;
%7(TEMP) int|();
%8(TEMP) ();

bb0 {
%2 = newType R3;
%1 = NewMap %2{};
%4 = ConstLoad 2;
%6 = ConstLoad x;
%1[%6] = %4;
%8 = ConstLoad 0;
%7 = <int|()> %8;
%6 = ConstLoad x;
%1[%6] = %7;
%0 = ConstLoad 0;
GOTO bb1;
}
bb1 {
return;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
setRequiredField function() -> () {
%0(RETURN) ();
%1(LOCAL) R1;
%2(TEMP) typeDesc<any|error{map<ballerina/lang.value:0.0.0:Cloneable>}>;
%4(TEMP) string;
%5(TEMP) int;

bb0 {
%2 = newType R1;
%4 = ConstLoad x;
%5 = ConstLoad 1;
%1 = NewMap %2{%4:%5};
%5 = ConstLoad 2;
%4 = ConstLoad x;
%1[%4] = %5;
%0 = ConstLoad 0;
GOTO bb1;
}
bb1 {
return;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
type R1 record {|
int x;
|};

type R2 record {|
int? x;
|};

type R3 record {|
int x?;
|};

function setRequiredField() {
R1 r1 = {x: 1};
r1.x = 2;
}

function setNillableField() {
R2 r2 = {x: 1};
r2.x = 2;
r2.x = ();
}

function setOptionalField() {
R3 r3 = {};
r3.x = 2;
r3.x = ();
}
Loading

0 comments on commit 4bbbae3

Please sign in to comment.