Skip to content

Commit

Permalink
Merge pull request #712 from kalaiyarasiganeshalingam/main
Browse files Browse the repository at this point in the history
Fix spread-field error in connection pool config
  • Loading branch information
kalaiyarasiganeshalingam committed Jun 21, 2023
2 parents 3f44efd + df7b952 commit c6904a1
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 64 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- [Add compiler plugin validation to validate spread-field config initialization](https://github.com/ballerina-platform/ballerina-standard-library/issues/4594)

### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,22 @@ public void testOptionsWithVariables() {
Assert.assertEquals(availableErrors, 0);
}

@Test
public void negativeTestConnectionPoolWithSpreadField() {
Package currentPackage = loadPackage("sample7");
PackageCompilation compilation = currentPackage.getCompilation();
DiagnosticResult diagnosticResult = compilation.diagnosticResult();
List<Diagnostic> diagnosticErrorStream = diagnosticResult.diagnostics().stream()
.filter(r -> r.diagnosticInfo().severity().equals(DiagnosticSeverity.ERROR))
.collect(Collectors.toList());
long availableErrors = diagnosticErrorStream.size();
Assert.assertEquals(availableErrors, 3);
Assert.assertEquals(diagnosticErrorStream.get(0).diagnosticInfo().messageFormat(),
"invalid value: expected value is greater than one");
Assert.assertEquals(diagnosticErrorStream.get(1).diagnosticInfo().messageFormat(),
"invalid value: expected value is greater than or equal to 30");
Assert.assertEquals(diagnosticErrorStream.get(2).diagnosticInfo().messageFormat(),
"invalid value: expected value is greater than zero");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[package]
org = "oracledb_test"
name = "sample7"
version = "0.1.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2023, WSO2 LLC. (http://www.wso2.com). All Rights Reserved.
//
// This software is the property of WSO2 LLC. and its suppliers, if any.
// Dissemination of any information or reproduction of any material contained
// herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
// You may not alter or remove any copyright or other notice from copies of this content.

import ballerinax/oracledb;
import ballerinax/oracledb.driver as _;

# sql:ConnectionPool parameter record with default optimized values
#
# + maxOpenConnections - The maximum open connections
# + maxConnectionLifeTime - The maximum lifetime of a connection
# + minIdleConnections - The minimum idle time of a connection
type SqlConnectionPoolConfig record {|
int maxOpenConnections = -10;
decimal maxConnectionLifeTime = -180;
int minIdleConnections = -5;
|};

# mysql:Options parameter record with default optimized values
#
# + connectTimeout - Timeout to be used when establishing a connection
type MysqlOptionsConfig record {|
decimal connectTimeout = 10;
|};

# [Configurable] Allocation MySQL Database
#
# + hostname - database hostname
# + username - database username
# + password - database password
# + database - database name
# + port - database port
# + connectionPool - sql:ConnectionPool configurations, type: SqlConnectionPoolConfig
# + mysqlOptions - mysql:Options configurations, type: MysqlOptionsConfig
type AllocationDatabase record {|
string hostname;
string username;
string password;
string database;
int port = 3306;
SqlConnectionPoolConfig connectionPool;
MysqlOptionsConfig mysqlOptions;
|};

configurable AllocationDatabase allocationDatabase = ?;

final oracledb:Client allocationDbClient = check new (
host = allocationDatabase.hostname,
user = allocationDatabase.username,
password = allocationDatabase.password,
port = allocationDatabase.port,
database = allocationDatabase.database,
connectionPool = {
...allocationDatabase.connectionPool
}
);
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,38 @@
package io.ballerina.stdlib.oracledb.compiler;

import io.ballerina.compiler.api.symbols.ModuleSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.SymbolKind;
import io.ballerina.compiler.api.symbols.TypeDescKind;
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.syntax.tree.BasicLiteralNode;
import io.ballerina.compiler.syntax.tree.ChildNodeEntry;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingFieldNode;
import io.ballerina.compiler.syntax.tree.ModulePartNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.RecordFieldNode;
import io.ballerina.compiler.syntax.tree.RecordFieldWithDefaultValueNode;
import io.ballerina.compiler.syntax.tree.RecordTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.SpreadFieldNode;
import io.ballerina.compiler.syntax.tree.TypeDefinitionNode;
import io.ballerina.compiler.syntax.tree.TypedBindingPatternNode;
import io.ballerina.compiler.syntax.tree.UnaryExpressionNode;
import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext;
import io.ballerina.tools.diagnostics.Diagnostic;
import io.ballerina.tools.diagnostics.DiagnosticFactory;
import io.ballerina.tools.diagnostics.DiagnosticInfo;
import io.ballerina.tools.diagnostics.DiagnosticSeverity;
import io.ballerina.tools.diagnostics.Location;

import java.util.List;
import java.util.Optional;

import static io.ballerina.stdlib.oracledb.compiler.Constants.UNNECESSARY_CHARS_REGEX;
Expand Down Expand Up @@ -98,31 +112,45 @@ public static boolean isOracleDBObject(TypeReferenceTypeSymbol typeReference, St
}
}

public static void validateOptions(SyntaxNodeAnalysisContext ctx, MappingConstructorExpressionNode options) {
SeparatedNodeList<MappingFieldNode> fields = options.fields();
for (MappingFieldNode field : fields) {
String name = ((SpecificFieldNode) field).fieldName().toString()
.trim().replaceAll(UNNECESSARY_CHARS_REGEX, "");
ExpressionNode valueNode = ((SpecificFieldNode) field).valueExpr().get();
switch (name) {
case Constants.Options.CONNECT_TIMEOUT:
case Constants.Options.LOGIN_TIMEOUT:
case Constants.Options.SOCKET_TIMEOUT:
float timeoutVal = Float.parseFloat(getTerminalNodeValue(valueNode, "0"));
if (timeoutVal < 0) {
DiagnosticInfo diagnosticInfo = new DiagnosticInfo(ORACLEDB_101.getCode(),
ORACLEDB_101.getMessage(), ORACLEDB_101.getSeverity());
ctx.reportDiagnostic(
DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location()));
public static void validateOptionConfig(SyntaxNodeAnalysisContext ctx, MappingConstructorExpressionNode options) {
for (MappingFieldNode field: options.fields()) {
if (field instanceof SpecificFieldNode) {
SpecificFieldNode specificFieldNode = ((SpecificFieldNode) field);
validateOptions(ctx, specificFieldNode.fieldName().toString().trim().
replaceAll(UNNECESSARY_CHARS_REGEX, ""), specificFieldNode.valueExpr().get());
} else if (field instanceof SpreadFieldNode) {
NodeList<Node> recordFields = Utils.getSpreadFieldType(ctx, ((SpreadFieldNode) field));
for (Node recordField : recordFields) {
if (recordField instanceof RecordFieldWithDefaultValueNode) {
RecordFieldWithDefaultValueNode fieldWithDefaultValueNode =
(RecordFieldWithDefaultValueNode) recordField;
validateOptions(ctx, fieldWithDefaultValueNode.fieldName().toString().
trim().replaceAll(UNNECESSARY_CHARS_REGEX, ""),
fieldWithDefaultValueNode.expression());
}
break;
default:
// Can ignore all the other fields
continue;
}
}
}
}

public static void validateOptions(SyntaxNodeAnalysisContext ctx, String name, ExpressionNode valueNode) {
switch (name) {
case Constants.Options.CONNECT_TIMEOUT:
case Constants.Options.LOGIN_TIMEOUT:
case Constants.Options.SOCKET_TIMEOUT:
float timeoutVal = Float.parseFloat(getTerminalNodeValue(valueNode, "0"));
if (timeoutVal < 0) {
DiagnosticInfo diagnosticInfo = new DiagnosticInfo(ORACLEDB_101.getCode(),
ORACLEDB_101.getMessage(), ORACLEDB_101.getSeverity());
ctx.reportDiagnostic(
DiagnosticFactory.createDiagnostic(diagnosticInfo, valueNode.location()));
}
break;
default:
// Can ignore all the other fields
}
}

public static String getTerminalNodeValue(Node valueNode, String defaultValue) {
String value = defaultValue;
if (valueNode instanceof BasicLiteralNode) {
Expand Down Expand Up @@ -157,4 +185,117 @@ public static DiagnosticInfo addDiagnosticsForInvalidTypes(String objectName, Ty
return null;
}
}

public static NodeList<Node> getSpreadFieldType(SyntaxNodeAnalysisContext ctx, SpreadFieldNode spreadFieldNode) {
List<Symbol> symbols = ctx.semanticModel().moduleSymbols();
Object[] entries = spreadFieldNode.valueExpr().childEntries().toArray();
ModulePartNode modulePartNode = ctx.syntaxTree().rootNode();
ChildNodeEntry type = Utils.getVariableType(symbols, entries, modulePartNode);
RecordTypeDescriptorNode typeDescriptor = Utils.getFirstSpreadFieldRecordTypeDescriptorNode(symbols,
type, modulePartNode);
typeDescriptor = Utils.getEndSpreadFieldRecordType(symbols, entries, modulePartNode,
typeDescriptor);
return typeDescriptor.fields();
}

public static ChildNodeEntry getVariableType(List<Symbol> symbols, Object[] entries,
ModulePartNode modulePartNode) {
for (Symbol symbol : symbols) {
if (!symbol.kind().equals(SymbolKind.VARIABLE)) {
continue;
}
Optional<String> symbolName = symbol.getName();
Optional<Node> childNodeEntry = ((ChildNodeEntry) entries[0]).node();
if (symbolName.isPresent() && childNodeEntry.isPresent() &&
symbolName.get().equals(childNodeEntry.get().toString())) {
Optional<Location> location = symbol.getLocation();
if (location.isPresent()) {
Location loc = location.get();
NonTerminalNode node = modulePartNode.findNode(loc.textRange());
if (node instanceof TypedBindingPatternNode) {
TypedBindingPatternNode typedBindingPatternNode = (TypedBindingPatternNode) node;
return (ChildNodeEntry) typedBindingPatternNode.childEntries().toArray()[0];
}
}
}
}
return null;
}

public static RecordTypeDescriptorNode getFirstSpreadFieldRecordTypeDescriptorNode(List<Symbol> symbols,
ChildNodeEntry type,
ModulePartNode modulePartNode) {
if (type != null && type.node().isPresent()) {
for (Symbol symbol : symbols) {
if (!symbol.kind().equals(SymbolKind.TYPE_DEFINITION)) {
continue;
}
if (symbol.getName().isPresent() &&
symbol.getName().get().equals(type.node().get().toString().trim())) {
Optional<Location> loc = symbol.getLocation();
if (loc.isPresent()) {
Location location = loc.get();
Node node = modulePartNode.findNode(location.textRange());
if (node instanceof TypeDefinitionNode) {
TypeDefinitionNode typeDefinitionNode = (TypeDefinitionNode) node;
return (RecordTypeDescriptorNode) typeDefinitionNode.typeDescriptor();
}
}
}
}
}
return null;
}

public static RecordTypeDescriptorNode getEndSpreadFieldRecordType(List<Symbol> symbols, Object[] entries,
ModulePartNode modulePartNode,
RecordTypeDescriptorNode typeDescriptor) {
if (typeDescriptor != null) {
for (int i = 1; i < entries.length; i++) {
String childNodeEntry = ((ChildNodeEntry) entries[i]).node().get().toString();
NodeList<Node> recordFields = typeDescriptor.fields();
if (childNodeEntry.equals(".")) {
continue;
}
for (Node recordField : recordFields) {
String fieldName;
Node fieldType;
if (recordField instanceof RecordFieldWithDefaultValueNode) {
RecordFieldWithDefaultValueNode fieldWithDefaultValueNode =
(RecordFieldWithDefaultValueNode) recordField;
fieldName = fieldWithDefaultValueNode.fieldName().text().trim();
fieldType = fieldWithDefaultValueNode.typeName();
} else {
RecordFieldNode fieldNode = (RecordFieldNode) recordField;
fieldName = fieldNode.fieldName().text().trim();
fieldType = fieldNode.typeName();
}
if (fieldName.equals(childNodeEntry.trim())) {
if (fieldType instanceof SimpleNameReferenceNode) {
SimpleNameReferenceNode nameReferenceNode = (SimpleNameReferenceNode) fieldType;
for (Symbol symbol : symbols) {
if (!symbol.kind().equals(SymbolKind.TYPE_DEFINITION)) {
continue;
}
if (symbol.getName().isPresent() &&
symbol.getName().get().equals(nameReferenceNode.name().text().trim())) {
Optional<Location> loc = symbol.getLocation();
if (loc.isPresent()) {
Location location = loc.get();
Node node = modulePartNode.findNode(location.textRange());
if (node instanceof TypeDefinitionNode) {
TypeDefinitionNode typeDefinitionNode = (TypeDefinitionNode) node;
typeDescriptor = (RecordTypeDescriptorNode) typeDefinitionNode.
typeDescriptor();
}
}
}
}
}
}
}
}
}
return typeDescriptor;
}
}

0 comments on commit c6904a1

Please sign in to comment.