Skip to content

Commit

Permalink
HHH-17967
Browse files Browse the repository at this point in the history
  - Add test for issue (already fixed on main)
  - Backport the minimal necessary bits of HHH-16931 and
    #7883 to fix the NPE

Signed-off-by: Jan Schatteman <jschatte@redhat.com>
  • Loading branch information
jrenaat authored and beikov committed Apr 25, 2024
1 parent c843573 commit 03e589e
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ protected static <T> RowTransformer<T> determineRowTransformer(
if ( queryOptions.getTupleTransformer() != null ) {
return makeRowTransformerTupleTransformerAdapter( sqm, queryOptions );
}
else if ( resultType == null ) {
else if ( resultType == null || resultType == Object.class ) {
return RowTransformerStandardImpl.instance();
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class AbstractSqmAttributeJoin<O,T>
extends AbstractSqmQualifiedJoin<O,T>
implements SqmAttributeJoin<O,T> {

private final boolean fetched;
private boolean fetched;

public AbstractSqmAttributeJoin(
SqmFrom<?,O> lhs,
Expand Down Expand Up @@ -88,6 +88,10 @@ public boolean isFetched() {
return fetched;
}

public void clearFetched() {
fetched = false;
}

@Override
public <X> X accept(SemanticQueryWalker<X> walker) {
return walker.visitQualifiedAttributeJoin( this );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.hibernate.Internal;
import org.hibernate.metamodel.mapping.ModelPartContainer;
import org.hibernate.metamodel.model.domain.BagPersistentAttribute;
import org.hibernate.metamodel.model.domain.EntityDomainType;
Expand Down Expand Up @@ -254,6 +255,29 @@ public void addSqmJoin(SqmJoin<T, ?> join) {
findRoot().addOrderedJoin( join );
}

@Internal
public void removeLeftFetchJoins() {
if ( joins != null ) {
for ( SqmJoin<T, ?> join : new ArrayList<>(joins) ) {
if ( join instanceof AbstractSqmAttributeJoin ) {
final AbstractSqmAttributeJoin<T, ?> attributeJoin = (AbstractSqmAttributeJoin<T, ?>) join;
if ( attributeJoin.isFetched() ) {
if ( join.getSqmJoinType() == SqmJoinType.LEFT ) {
joins.remove( join );
final List<SqmJoin<?, ?>> orderedJoins = findRoot().getOrderedJoins();
if (orderedJoins != null) {
orderedJoins.remove( join );
}
}
else {
attributeJoin.clearFetched();
}
}
}
}
}
}

@Override
public void visitSqmJoins(Consumer<SqmJoin<T, ?>> consumer) {
if ( joins != null ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public abstract class AbstractSqmSelectQuery<T>
implements SqmSelectQuery<T> {
private final Map<String, SqmCteStatement<?>> cteStatements;
private SqmQueryPart<T> sqmQueryPart;
private Class<T> resultType;
private final Class<T> resultType;

public AbstractSqmSelectQuery(Class<T> resultType, NodeBuilder builder) {
this( new SqmQuerySpec<>( builder ), resultType, builder );
Expand Down Expand Up @@ -202,8 +202,11 @@ public Class<T> getResultType() {
return resultType;
}

/**
* Don't use this method. It has no effect.
*/
protected void setResultType(Class<T> resultType) {
this.resultType = resultType;
// No-op
}

@Override
Expand Down Expand Up @@ -410,7 +413,6 @@ protected Selection<? extends T> getResultSelection(Selection<?>[] selections) {
break;
}
default: {
setResultType( (Class<T>) Object[].class );
resultSelection = ( Selection<? extends T> ) nodeBuilder().array( selections );
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import org.hibernate.query.sqm.tree.expression.ValueBindJpaCriteriaParameter;
import org.hibernate.query.sqm.tree.expression.SqmParameter;
import org.hibernate.query.sqm.tree.from.SqmFromClause;
import org.hibernate.query.sqm.tree.from.SqmRoot;

import static org.hibernate.query.sqm.tree.SqmCopyContext.noParamCopyContext;
import static org.hibernate.query.sqm.tree.jpa.ParameterCollector.collectParameters;

/**
Expand Down Expand Up @@ -119,6 +121,10 @@ public SqmSelectStatement<T> copy(SqmCopyContext context) {
if ( existing != null ) {
return existing;
}
return createCopy( context, getResultType() );
}

private <X> SqmSelectStatement<X> createCopy(SqmCopyContext context, Class<X> resultType) {
final Set<SqmParameter<?>> parameters;
if ( this.parameters == null ) {
parameters = null;
Expand All @@ -129,17 +135,19 @@ public SqmSelectStatement<T> copy(SqmCopyContext context) {
parameters.add( parameter.copy( context ) );
}
}
final SqmSelectStatement<T> statement = context.registerCopy(
//noinspection unchecked
final SqmSelectStatement<X> statement = (SqmSelectStatement<X>) context.registerCopy(
this,
new SqmSelectStatement<>(
nodeBuilder(),
copyCteStatements( context ),
getResultType(),
resultType,
getQuerySource(),
parameters
)
);
statement.setQueryPart( getQueryPart().copy( context ) );
//noinspection unchecked
statement.setQueryPart( (SqmQueryPart<X>) getQueryPart().copy( context ) );
return statement;
}

Expand Down Expand Up @@ -266,9 +274,6 @@ public SqmSelectStatement<T> select(Selection<? extends T> selection) {
checkSelectionIsJpaCompliant( selection );
}
getQuerySpec().setSelection( (JpaSelection<T>) selection );
if ( getResultType() == Object.class ) {
setResultType( (Class<T>) selection.getJavaType() );
}
return this;
}

Expand Down Expand Up @@ -309,7 +314,6 @@ public SqmSelectStatement<T> multiselect(List<Selection<?>> selectionList) {
break;
}
default: {
setResultType( (Class<T>) Object[].class );
resultSelection = ( Selection<? extends T> ) nodeBuilder().array( selections );
}
}
Expand Down Expand Up @@ -460,49 +464,43 @@ private void validateComplianceFetchOffset() {
}

@Override
public JpaCriteriaQuery<Long> createCountQuery() {
final SqmCopyContext context = new NoParamSqmCopyContext() {
@Override
public boolean copyFetchedFlag() {
return false;
public SqmSelectStatement<Long> createCountQuery() {
final SqmSelectStatement<?> copy = createCopy( noParamCopyContext(), Object.class );
final SqmQueryPart<?> queryPart = copy.getQueryPart();
final SqmQuerySpec<?> querySpec;
//TODO: detect queries with no 'group by', but aggregate functions
// in 'select' list (we don't even need to hit the database to
// know they return exactly one row)
if ( queryPart.isSimpleQueryPart()
&& !( querySpec = (SqmQuerySpec<?>) queryPart ).isDistinct()
&& querySpec.getGroupingExpressions().isEmpty() ) {
for ( SqmRoot<?> root : querySpec.getRootList() ) {
root.removeLeftFetchJoins();
}
};
final NodeBuilder nodeBuilder = nodeBuilder();
final Set<SqmParameter<?>> parameters;
if ( this.parameters == null ) {
parameters = null;
querySpec.getSelectClause().setSelection( nodeBuilder().count( new SqmStar( nodeBuilder() )) );
if ( querySpec.getFetch() == null && querySpec.getOffset() == null ) {
querySpec.setOrderByClause( null );
}

return (SqmSelectStatement<Long>) copy;
}
else {
parameters = new LinkedHashSet<>( this.parameters.size() );
for ( SqmParameter<?> parameter : this.parameters ) {
parameters.add( parameter.copy( context ) );
final JpaSelection<?> selection = queryPart.getFirstQuerySpec().getSelection();
if ( selection.isCompoundSelection() ) {
char c = 'a';
for ( JpaSelection<?> item : selection.getSelectionItems() ) {
item.alias( Character.toString( ++c ) + '_' );
}
}
else {
selection.alias( "a_" );
}
final SqmSubQuery<?> subquery = new SqmSubQuery<>( copy, queryPart, null, nodeBuilder() );
final SqmSelectStatement<Long> query = nodeBuilder().createQuery( Long.class );
query.from( subquery );
query.select( nodeBuilder().count( new SqmStar(nodeBuilder())) );
return query;
}
final SqmSelectStatement<Long> selectStatement = new SqmSelectStatement<>(
nodeBuilder,
copyCteStatements( context ),
Long.class,
SqmQuerySource.CRITERIA,
parameters
);
final SqmQuerySpec<Long> querySpec = new SqmQuerySpec<>( nodeBuilder );

final SqmSubQuery<Tuple> subquery = new SqmSubQuery<>( selectStatement, Tuple.class, nodeBuilder );
final SqmQueryPart<T> queryPart = getQueryPart().copy( context );
resetSelections( queryPart );
// Reset the
if ( queryPart.getFetch() == null && queryPart.getOffset() == null ) {
queryPart.setOrderByClause( null );
}
//noinspection unchecked
subquery.setQueryPart( (SqmQueryPart<Tuple>) queryPart );

querySpec.setFromClause( new SqmFromClause( 1 ) );
querySpec.setSelectClause( new SqmSelectClause( false, 1, nodeBuilder ) );
selectStatement.setQueryPart( querySpec );
selectStatement.select( nodeBuilder.count( new SqmStar( nodeBuilder ) ) );
selectStatement.from( subquery );
return selectStatement;
}

private void resetSelections(SqmQueryPart<?> queryPart) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ public SqmSubQuery<T> multiselect(List<Selection<?>> selectionList) {
break;
}
default: {
setResultType( (Class<T>) Object[].class );
resultSelection = ( Selection<? extends T> ) nodeBuilder().array( selections );
}
}
Expand Down Expand Up @@ -609,9 +608,6 @@ public SqmExpressible<T> getNodeType() {
public void applyInferableType(SqmExpressible<?> type) {
//noinspection unchecked
expressibleType = (SqmExpressible<T>) type;
if ( expressibleType != null && expressibleType.getExpressibleJavaType() != null ) {
setResultType( expressibleType.getExpressibleJavaType().getJavaTypeClass() );
}
}

private void applyInferableType(Class<T> type) {
Expand Down

0 comments on commit 03e589e

Please sign in to comment.