Skip to content

Commit

Permalink
HHH-14704 Fix set operations support when fetches are involved
Browse files Browse the repository at this point in the history
  • Loading branch information
beikov committed Jul 2, 2021
1 parent 174b230 commit 16db356
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 15 deletions.
Expand Up @@ -301,7 +301,9 @@ public SqmStatement<R> visitStatement(HqlParser.StatementContext ctx) {

try {
if ( ctx.selectStatement() != null ) {
return visitSelectStatement( ctx.selectStatement() );
final SqmSelectStatement<R> selectStatement = visitSelectStatement( ctx.selectStatement() );
selectStatement.getQueryPart().validateQueryGroupFetchStructure();
return selectStatement;
}
else if ( ctx.insertStatement() != null ) {
return visitInsertStatement( ctx.insertStatement() );
Expand Down Expand Up @@ -607,18 +609,6 @@ public SqmQueryGroup<Object> visitSetQueryGroup(HqlParser.SetQueryGroupContext c
finally {
processingStateStack.pop();
}
final List<SqmSelection> selections = queryPart.getFirstQuerySpec().getSelectClause().getSelections();
if ( firstSelectionSize != selections.size() ) {
throw new SemanticException( "All query parts must have the same arity!" );
}
for ( int j = 0; j < firstSelectionSize; j++ ) {
final JavaTypeDescriptor<?> firstJavaTypeDescriptor = firstSelections.get( j ).getNodeJavaTypeDescriptor();
if ( firstJavaTypeDescriptor != selections.get( j ).getNodeJavaTypeDescriptor() ) {
throw new SemanticException(
"Select items of the same index must have the same java type across all query parts!"
);
}
}
}
processingStateStack.push( firstProcessingState );

Expand Down
Expand Up @@ -200,8 +200,11 @@ public QuerySqmImpl(

if ( resultType != null ) {
SqmUtil.verifyIsSelectStatement( sqmStatement );
final SqmQueryPart<R> queryPart = ( (SqmSelectStatement<R>) sqmStatement ).getQueryPart();
// For criteria queries, we have to validate the fetch structure here
queryPart.validateQueryGroupFetchStructure();
visitQueryReturnType(
( (SqmSelectStatement<R>) sqmStatement ).getQueryPart(),
queryPart,
resultType,
producer.getFactory()
);
Expand Down
Expand Up @@ -1577,7 +1577,7 @@ public Void visitSelection(SqmSelection<?> sqmSelection) {
final Stack<SqlAstProcessingState> processingStateStack = getProcessingStateStack();
final boolean needsDomainResults = domainResults != null && currentClauseContributesToTopLevelSelectClause();
final boolean collectDomainResults;
if ( processingStateStack.depth() == 1) {
if ( processingStateStack.depth() == 1 ) {
collectDomainResults = needsDomainResults;
}
else {
Expand Down Expand Up @@ -1606,6 +1606,11 @@ public Void visitSelection(SqmSelection<?> sqmSelection) {
if ( collectDomainResults ) {
resultProducers.forEach( (alias, r) -> domainResults.add( r.createDomainResult( alias, this ) ) );
}
else if ( needsDomainResults ) {
// We just create domain results for the purpose of creating selections
// This is necessary for top-level query specs within query groups to avoid cycles
resultProducers.forEach( (alias, r) -> r.createDomainResult( alias, this ) );
}
else {
resultProducers.forEach( (alias, r) -> r.applySqlSelections( this ) );
}
Expand Down
Expand Up @@ -7,16 +7,22 @@
package org.hibernate.query.sqm.tree.select;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.hibernate.query.FetchClauseType;
import org.hibernate.query.SemanticException;
import org.hibernate.query.SetOperator;
import org.hibernate.internal.util.collections.CollectionHelper;
import org.hibernate.query.criteria.JpaExpression;
import org.hibernate.query.criteria.JpaOrder;
import org.hibernate.query.criteria.JpaQueryGroup;
import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SemanticQueryWalker;
import org.hibernate.query.sqm.tree.from.SqmAttributeJoin;
import org.hibernate.query.sqm.tree.from.SqmFrom;
import org.hibernate.query.sqm.tree.from.SqmJoin;
import org.hibernate.type.descriptor.java.JavaTypeDescriptor;

/**
* A grouped list of queries connected through a certain set operator.
Expand Down Expand Up @@ -104,6 +110,85 @@ public SqmQueryGroup<T> setFetch(JpaExpression<?> fetch, FetchClauseType fetchCl
return (SqmQueryGroup<T>) super.setFetch( fetch, fetchClauseType );
}

@Override
public void validateQueryGroupFetchStructure() {
validateQueryGroupFetchStructure( getFirstQuerySpec() );
}

private void validateQueryGroupFetchStructure(SqmQuerySpec<?> firstQuerySpec) {
final List<SqmSelection> firstSelections = firstQuerySpec.getSelectClause().getSelections();
final int firstSelectionSize = firstSelections.size();
for ( int i = 0; i < queryParts.size(); i++ ) {
final SqmQueryPart<T> queryPart = queryParts.get( i );
if ( queryPart instanceof SqmQueryGroup<?> ) {
( (SqmQueryGroup<Object>) queryPart ).validateQueryGroupFetchStructure( firstQuerySpec );
}
else {
final SqmQuerySpec<?> querySpec = (SqmQuerySpec<?>) queryPart;
final List<SqmSelection> selections = querySpec.getSelectClause().getSelections();
if ( firstSelectionSize != selections.size() ) {
throw new SemanticException( "All query parts in a query group must have the same arity!" );
}
for ( int j = 0; j < firstSelectionSize; j++ ) {
final SqmSelection firstSqmSelection = firstSelections.get( j );
final JavaTypeDescriptor<?> firstJavaTypeDescriptor = firstSqmSelection.getNodeJavaTypeDescriptor();
if ( firstJavaTypeDescriptor != selections.get( j ).getNodeJavaTypeDescriptor() ) {
throw new SemanticException(
"Select items of the same index must have the same java type across all query parts!"
);
}
if ( firstSqmSelection.getSelectableNode() instanceof SqmFrom<?, ?> ) {
final SqmFrom<?, ?> firstFrom = (SqmFrom<?, ?>) firstSqmSelection.getSelectableNode();
final SqmFrom<?, ?> from = (SqmFrom<?, ?>) selections.get( j ).getSelectableNode();
validateFetchesMatch( firstFrom, from );
}
}
}
}
}

private void validateFetchesMatch(SqmFrom<?, ?> firstFrom, SqmFrom<?, ?> from) {
final Iterator<? extends SqmJoin<?, ?>> firstJoinIter = firstFrom.getSqmJoins().iterator();
final Iterator<? extends SqmJoin<?, ?>> joinIter = from.getSqmJoins().iterator();
while ( firstJoinIter.hasNext() ) {
final SqmJoin<?, ?> firstSqmJoin = firstJoinIter.next();
if ( firstSqmJoin instanceof SqmAttributeJoin<?, ?> ) {
final SqmAttributeJoin<?, ?> firstAttrJoin = (SqmAttributeJoin<?, ?>) firstSqmJoin;
if ( firstAttrJoin.isFetched() ) {
SqmAttributeJoin<?, ?> matchingAttrJoin = null;
while ( joinIter.hasNext() ) {
final SqmJoin<?, ?> sqmJoin = joinIter.next();
if ( sqmJoin instanceof SqmAttributeJoin<?, ?> ) {
final SqmAttributeJoin<?, ?> attrJoin = (SqmAttributeJoin<?, ?>) sqmJoin;
if ( attrJoin.isFetched() ) {
matchingAttrJoin = attrJoin;
break;
}
}
}
if ( matchingAttrJoin == null || firstAttrJoin.getModel() != matchingAttrJoin.getModel() ) {
throw new SemanticException(
"All query parts in a query group must have the same join fetches in the same order!"
);
}
validateFetchesMatch( firstAttrJoin, matchingAttrJoin );
}
}
}
// At this point, the other iterator should only contain non-fetch joins
while ( joinIter.hasNext() ) {
final SqmJoin<?, ?> sqmJoin = joinIter.next();
if ( sqmJoin instanceof SqmAttributeJoin<?, ?> ) {
final SqmAttributeJoin<?, ?> attrJoin = (SqmAttributeJoin<?, ?>) sqmJoin;
if ( attrJoin.isFetched() ) {
throw new SemanticException(
"All query parts in a query group must have the same join fetches in the same order!"
);
}
}
}
}

@Override
public void appendHqlString(StringBuilder sb) {
appendQueryPart( queryParts.get( 0 ), sb );
Expand Down
Expand Up @@ -148,6 +148,8 @@ public JpaQueryPart<T> setFetch(JpaExpression<?> fetch, FetchClauseType fetchCla
return this;
}

public abstract void validateQueryGroupFetchStructure();

public void appendHqlString(StringBuilder sb) {
if ( orderByClause == null ) {
return;
Expand Down
Expand Up @@ -328,6 +328,11 @@ public SqmQuerySpec<T> setFetch(JpaExpression<?> fetch, FetchClauseType fetchCla
return this;
}

@Override
public void validateQueryGroupFetchStructure() {
// No-op
}

@Override
public void appendHqlString(StringBuilder sb) {
if ( selectClause != null ) {
Expand Down
Expand Up @@ -10,15 +10,20 @@

import javax.persistence.Tuple;

import org.hibernate.query.SemanticException;

import org.hibernate.testing.TestForIssue;
import org.hibernate.testing.orm.domain.StandardDomainModel;
import org.hibernate.testing.orm.domain.gambit.EntityOfLists;
import org.hibernate.testing.orm.domain.gambit.EntityWithManyToOneSelfReference;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.ServiceRegistry;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -40,6 +45,12 @@ public void createTestData(SessionFactoryScope scope) {
session.save( new EntityOfLists( 1, "first" ) );
session.save( new EntityOfLists( 2, "second" ) );
session.save( new EntityOfLists( 3, "third" ) );
EntityWithManyToOneSelfReference first = new EntityWithManyToOneSelfReference( 1, "first", 123 );
EntityWithManyToOneSelfReference second = new EntityWithManyToOneSelfReference( 2, "second", 123 );
session.save( first );
first.setOther( first );
session.save( second );
second.setOther( second );
}
);
}
Expand All @@ -48,12 +59,113 @@ public void createTestData(SessionFactoryScope scope) {
public void dropTestData(SessionFactoryScope scope) {
scope.inTransaction(
session -> {
session.createQuery( "delete from EntityWithManyToOneSelfReference" ).executeUpdate();
session.createQuery( "delete from EntityOfLists" ).executeUpdate();
session.createQuery( "delete from SimpleEntity" ).executeUpdate();
}
);
}

@Test
@TestForIssue( jiraKey = "HHH-14704")
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAllWithManyToOne(SessionFactoryScope scope) {
scope.inSession(
session -> {
List<EntityWithManyToOneSelfReference> list = session.createQuery(
"from EntityWithManyToOneSelfReference e where e.id = 1 " +
"union all " +
"from EntityWithManyToOneSelfReference e where e.id = 2",
EntityWithManyToOneSelfReference.class
).list();
assertThat( list.size(), is( 2 ) );
}
);
}

@Test
@TestForIssue( jiraKey = "HHH-14704")
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAllWithManyToOneFetch(SessionFactoryScope scope) {
scope.inSession(
session -> {
List<EntityWithManyToOneSelfReference> list = session.createQuery(
"from EntityWithManyToOneSelfReference e join fetch e.other where e.id = 1 " +
"union all " +
"from EntityWithManyToOneSelfReference e join fetch e.other where e.id = 2",
EntityWithManyToOneSelfReference.class
).list();
assertThat( list.size(), is( 2 ) );
}
);
}

@Test
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAllWithManyToOneFetchJustOne(SessionFactoryScope scope) {
scope.inSession(
session -> {
try {
session.createQuery(
"from EntityWithManyToOneSelfReference e join fetch e.other where e.id = 1 " +
"union all " +
"from EntityWithManyToOneSelfReference e where e.id = 2",
EntityWithManyToOneSelfReference.class
);
Assertions.fail( "Expected exception to be thrown!" );
}
catch (Exception e) {
Assertions.assertEquals( IllegalArgumentException.class, e.getClass() );
Assertions.assertEquals( SemanticException.class, e.getCause().getClass() );
}
}
);
}

@Test
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAllWithManyToOneFetchDifferentAttributes(SessionFactoryScope scope) {
scope.inSession(
session -> {
try {
session.createQuery(
"from EntityOfLists e join fetch e.listOfOneToMany where e.id = 1 " +
"union all " +
"from EntityOfLists e join fetch e.listOfManyToMany where e.id = 2",
EntityOfLists.class
);
Assertions.fail( "Expected exception to be thrown!" );
}
catch (Exception e) {
Assertions.assertEquals( IllegalArgumentException.class, e.getClass() );
Assertions.assertEquals( SemanticException.class, e.getCause().getClass() );
}
}
);
}

@Test
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAllWithManyToOneFetchDifferentOrder(SessionFactoryScope scope) {
scope.inSession(
session -> {
try {
session.createQuery(
"from EntityOfLists e join fetch e.listOfOneToMany join fetch e.listOfManyToMany where e.id = 1 " +
"union all " +
"from EntityOfLists e join fetch e.listOfManyToMany join fetch e.listOfOneToMany where e.id = 2",
EntityOfLists.class
);
Assertions.fail( "Expected exception to be thrown!" );
}
catch (Exception e) {
Assertions.assertEquals( IllegalArgumentException.class, e.getClass() );
Assertions.assertEquals( SemanticException.class, e.getCause().getClass() );
}
}
);
}

@Test
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsUnion.class)
public void testUnionAll(SessionFactoryScope scope) {
Expand Down

0 comments on commit 16db356

Please sign in to comment.