Skip to content

Commit

Permalink
GH-3954 fix for query regression with group by (#3955)
Browse files Browse the repository at this point in the history
* GH-3954 add tests

* GH-3954 recursively check if projected variable is effectively an aggregate, a constant og a variable declared in the group by expression

* GH-3954 removed tests for infinite recursion
  • Loading branch information
hmottestad committed Jun 5, 2022
1 parent 42e4da8 commit e9bba90
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor;
import org.eclipse.rdf4j.query.algebra.helpers.StatementPatternCollector;
import org.eclipse.rdf4j.query.algebra.helpers.TupleExprs;
import org.eclipse.rdf4j.query.algebra.helpers.VarNameCollector;
import org.eclipse.rdf4j.query.impl.ListBindingSet;
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAbs;
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAnd;
Expand Down Expand Up @@ -227,7 +228,6 @@
* A SPARQL AST visitor implementation that creates a query algebra representation of the query.
*
* @author Arjohn Kampman
*
* @apiNote This feature is for internal use only: its existence, signature or behavior may change without warning from
* one release to the next.
*/
Expand Down Expand Up @@ -274,8 +274,7 @@ protected Var mapValueExprToVar(Object valueExpr) {
if (valueExpr instanceof Var) {
return (Var) valueExpr;
} else if (valueExpr instanceof ValueConstant) {
Var v = TupleExprs.createConstVar(((ValueConstant) valueExpr).getValue());
return v;
return TupleExprs.createConstVar(((ValueConstant) valueExpr).getValue());
} else if (valueExpr instanceof TripleRef) {
return ((TripleRef) valueExpr).getExprVar();
} else if (valueExpr == null) {
Expand Down Expand Up @@ -348,7 +347,7 @@ public TupleExpr visit(ASTQueryContainer node, Object data) throws VisitorExcept

@Override
public TupleExpr visit(ASTSelectQuery node, Object data) throws VisitorException {
GraphPattern parentGP = graphPattern;
final GraphPattern parentGP = graphPattern;

// Start with building the graph pattern
graphPattern = new GraphPattern(parentGP);
Expand Down Expand Up @@ -413,11 +412,8 @@ public TupleExpr visit(ASTSelectQuery node, Object data) throws VisitorException
tupleExpr = new Slice(tupleExpr, offset, limit);
}

if (parentGP != null) {

parentGP.addRequiredTE(tupleExpr);
graphPattern = parentGP;
}
parentGP.addRequiredTE(tupleExpr);
graphPattern = parentGP;
return tupleExpr;
}

Expand Down Expand Up @@ -554,12 +550,11 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {
throw new VisitorException("Either TripleRef or Expression expected in projection.");
}

String targetName = alias;
String sourceName = alias;
if (child instanceof ASTVar) {
sourceName = ((ASTVar) child).getName();
}
ProjectionElem elem = new ProjectionElem(sourceName, targetName);
ProjectionElem elem = new ProjectionElem(sourceName, alias);
projElemList.addElement(elem);

AggregateCollector collector = new AggregateCollector();
Expand Down Expand Up @@ -633,36 +628,39 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {

result = new Projection(result, projElemList);
if (group != null) {
for (ProjectionElem elem : projElemList.getElements()) {
Set<String> groupNames = group.getBindingNames();
List<ProjectionElem> elements = projElemList.getElements();

for (ProjectionElem elem : elements) {
if (!elem.hasAggregateOperatorInExpression()) {
// non-aggregate projection elem is only allowed to be a constant or a simple expression (see
// https://www.w3.org/TR/sparql11-query/#aggregateRestrictions)
ExtensionElem extElem = elem.getSourceExpression();
if (extElem != null) {
ValueExpr expr = extElem.getExpr();
if (!(expr instanceof ValueConstant)) {
throw new VisitorException(
"non-aggregate expression '" + expr
+ "' not allowed in projection when using GROUP BY.");
if (isIllegalCombinedWithGroupByExpression(expr, elements, groupNames)) {
throw new VisitorException("non-aggregate expression '" + expr
+ "' not allowed in projection when using GROUP BY.");
}

} else {
Set<String> groupNames = group.getBindingNames();

if (!elem.getSourceName().equals(elem.getTargetName())) {
// projection element is a SELECT expression using a simple var (e.g. (?a AS ?b)).
// Projection element is a SELECT expression using a simple var (e.g. (?a AS ?b)).
// Source var must be present in GROUP BY.
if (!groupNames.contains(elem.getSourceName())) {
throw new VisitorException(
"variable '" + elem.getSourceName()
+ "' in projection not present in GROUP BY.");
if (isIllegalCombinedWithGroupByExpression(elem.getSourceName(), elements,
groupNames)) {
throw new VisitorException("variable '" + elem.getSourceName()
+ "' in projection not present in GROUP BY.");
}

}
} else {
// projection element is simple var. Must be present in GROUP BY.
if (!groupNames.contains(elem.getTargetName())) {
throw new VisitorException(
"variable '" + elem.getTargetName()
+ "' in projection not present in GROUP BY.");
throw new VisitorException("variable '" + elem.getTargetName()
+ "' in projection not present in GROUP BY.");
}
}
}
Expand All @@ -686,7 +684,59 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {
return result;
}

private class GroupFinder extends AbstractQueryModelVisitor<VisitorException> {
private static boolean isIllegalCombinedWithGroupByExpression(ValueExpr expr, List<ProjectionElem> elements,
Set<String> groupNames) {
if (expr instanceof ValueConstant)
return false;

VarNameCollector varNameCollector = new VarNameCollector();
expr.visit(varNameCollector);
Set<String> varNames = varNameCollector.getVarNames();

for (String varName : varNames) {
if (isIllegalCombinedWithGroupByExpression(varName, elements, groupNames)) {
return true;
}
}

return false;
}

private static boolean isIllegalCombinedWithGroupByExpression(String varName, List<ProjectionElem> elements,
Set<String> groupNames) {
do {
String prev = varName;

for (ProjectionElem element : elements) {
if (element.getTargetName().equals(varName)) {
if (element.hasAggregateOperatorInExpression()) {
return false;
} else {
ExtensionElem sourceExpression = element.getSourceExpression();
if (sourceExpression != null) {
if (sourceExpression.getExpr() != null) {
return isIllegalCombinedWithGroupByExpression(sourceExpression.getExpr(), elements,
groupNames);
}
}

varName = element.getSourceName();
break;
}
}
}

// check if we didn't find a new element
if (prev.equals(varName)) {
return true;
}

} while (!groupNames.contains(varName));

return false;
}

private static class GroupFinder extends AbstractQueryModelVisitor<VisitorException> {

private Group group;

Expand Down Expand Up @@ -721,7 +771,7 @@ public TupleExpr visit(ASTConstructQuery node, Object data) throws VisitorExcept
tupleExpr = (TupleExpr) groupNode.jjtAccept(this, tupleExpr);
}

Group group = null;
Group group;
if (tupleExpr instanceof Group) {
group = (Group) tupleExpr;
} else {
Expand Down Expand Up @@ -1083,14 +1133,14 @@ public String visit(ASTGroupCondition node, Object data) throws VisitorException
Group group = (Group) data;
TupleExpr arg = group.getArg();

Extension extension = null;
Extension extension;
if (arg instanceof Extension) {
extension = (Extension) arg;
} else {
extension = new Extension();
}

String name = null;
String name;
ValueExpr ve = castToValueExpr(node.jjtGetChild(0).jjtAccept(this, data));

boolean aliased = false;
Expand Down Expand Up @@ -1535,7 +1585,7 @@ private TupleExpr createTupleExprForNegatedPropertySets(List<PropertySetElem> np
}

private TupleExpr handlePathModifiers(Scope scope, Var subjVar, TupleExpr te, Var endVar, Var contextVar,
long lowerBound, long upperBound) throws VisitorException {
long lowerBound, long upperBound) {

// * and + modifiers
if (upperBound == Long.MAX_VALUE) {
Expand Down Expand Up @@ -1599,11 +1649,11 @@ public Set<SameTerm> getCollectedSameTerms() {

}

private class VarReplacer extends AbstractQueryModelVisitor<VisitorException> {
private static class VarReplacer extends AbstractQueryModelVisitor<VisitorException> {

private Var toBeReplaced;
private final Var toBeReplaced;

private Var replacement;
private final Var replacement;

public VarReplacer(Var toBeReplaced, Var replacement) {
this.toBeReplaced = toBeReplaced;
Expand Down Expand Up @@ -1631,15 +1681,14 @@ public void meet(ProjectionElem node) throws VisitorException {

@Override
public Object visit(ASTPropertyListPath propListNode, Object data) throws VisitorException {
Object subject = data;
Object verbPath = propListNode.getVerb().jjtAccept(this, data);

if (verbPath instanceof Var) {

@SuppressWarnings("unchecked")
List<ValueExpr> objectList = (List<ValueExpr>) propListNode.getObjectList().jjtAccept(this, null);

Var subjVar = mapValueExprToVar(subject);
Var subjVar = mapValueExprToVar(data);

Var predVar = mapValueExprToVar(verbPath);
for (ValueExpr object : objectList) {
Expand All @@ -1652,7 +1701,7 @@ public Object visit(ASTPropertyListPath propListNode, Object data) throws Visito

ASTPropertyListPath nextPropList = propListNode.getNextPropertyList();
if (nextPropList != null) {
nextPropList.jjtAccept(this, subject);
nextPropList.jjtAccept(this, data);
}

return null;
Expand Down Expand Up @@ -2092,9 +2141,7 @@ public BindingSet visit(ASTBindingSet node, Object data) throws VisitorException
}
}

BindingSet result = new ListBindingSet(names, values);

return result;
return new ListBindingSet(names, values);
}

@Override
Expand Down Expand Up @@ -2202,7 +2249,7 @@ public Not visit(ASTNotExistsFunc node, Object data) throws VisitorException {

@Override
public If visit(ASTIf node, Object data) throws VisitorException {
If result = null;
If result;

if (node.jjtGetNumChildren() < 3) {
throw new VisitorException("IF construction missing required number of arguments");
Expand All @@ -2227,7 +2274,7 @@ public ValueExpr visit(ASTInfix node, Object data) throws VisitorException {

@Override
public ValueExpr visit(ASTIn node, Object data) throws VisitorException {
ValueExpr result = null;
ValueExpr result;
ValueExpr leftArg = (ValueExpr) data;
int listItemCount = node.jjtGetNumChildren();

Expand All @@ -2252,7 +2299,7 @@ public ValueExpr visit(ASTIn node, Object data) throws VisitorException {

@Override
public ValueExpr visit(ASTNotIn node, Object data) throws VisitorException {
ValueExpr result = null;
ValueExpr result;
ValueExpr leftArg = (ValueExpr) data;

int listItemCount = node.jjtGetNumChildren();
Expand Down Expand Up @@ -2330,7 +2377,7 @@ public Object visit(ASTBind node, Object data) throws VisitorException {
Extension extension = new Extension();
extension.addElement(new ExtensionElem(ve, alias));

TupleExpr result = null;
TupleExpr result;
TupleExpr arg = graphPattern.buildTupleExpr();

// check if alias is not previously used.
Expand Down Expand Up @@ -2477,7 +2524,7 @@ public Object visit(ASTAvg node, Object data) throws VisitorException {

static class AggregateCollector extends AbstractQueryModelVisitor<VisitorException> {

private Collection<AggregateOperator> operators = new ArrayList<>();
private final Collection<AggregateOperator> operators = new ArrayList<>();

public Collection<AggregateOperator> getOperators() {
return operators;
Expand Down Expand Up @@ -2533,9 +2580,9 @@ private void meetAggregate(AggregateOperator node) {

static class AggregateOperatorReplacer extends AbstractQueryModelVisitor<VisitorException> {

private Var replacement;
private final Var replacement;

private AggregateOperator operator;
private final AggregateOperator operator;

public AggregateOperatorReplacer(AggregateOperator operator, Var replacement) {
this.operator = operator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,11 @@ public class SPARQLParserTest {

private SPARQLParser parser;

/**
* @throws java.lang.Exception
*/
@BeforeEach
public void setUp() throws Exception {
parser = new SPARQLParser();
}

/**
* @throws java.lang.Exception
*/
@AfterEach
public void tearDown() throws Exception {
parser = null;
Expand Down Expand Up @@ -447,7 +441,6 @@ public void testGroupByProjectionHandling_Aggregate_SimpleExpr2() {

// should parse without error
parser.parseQuery(query, null);

}

@Test
Expand All @@ -459,6 +452,47 @@ public void testGroupByProjectionHandling_Aggregate_Constant() {

// should parse without error
parser.parseQuery(query, null);
}

@Test
public void testGroupByProjectionHandling_variableEffectivelyAggregationResult() {
String query = "SELECT (COUNT (*) AS ?count) (?count / ?count AS ?result) (?result AS ?temp) (?temp / 2 AS ?temp2) {\n"
+
" ?s a ?o .\n" +
"}";

// should parse without error
parser.parseQuery(query, null);
}

@Test
public void testGroupByProjectionHandling_effectivelyConstant() {
String query = "SELECT (2 AS ?constant1) (?constant1 AS ?constant2) (?constant2/2 AS ?constant3){\n" +
" ?o ?p ?o .\n" +
"} GROUP BY ?o";

// should parse without error
parser.parseQuery(query, null);
}

@Test
public void testGroupByProjectionHandling_renameVariable() {
String query = "SELECT ?o (?o AS ?o2) (?o2 AS ?o3) (?o3/2 AS ?o4){\n" +
" ?o ?p ?o .\n" +
"} GROUP BY ?o";

// should parse without error
parser.parseQuery(query, null);
}

@Test
public void testGroupByProjectionHandling_renameVariableWithAggregation() {
String query = "SELECT ?o (?o AS ?o2) (COUNT (*) AS ?count) (?o2/?count AS ?newCount){\n" +
" ?o ?p ?o .\n" +
"} GROUP BY ?o";

// should parse without error
parser.parseQuery(query, null);
}

}

0 comments on commit e9bba90

Please sign in to comment.