diff --git a/ci/quarkus.Jenkinsfile b/ci/quarkus.Jenkinsfile index 2ee8dd1091c6..96bdcbf36b72 100644 --- a/ci/quarkus.Jenkinsfile +++ b/ci/quarkus.Jenkinsfile @@ -46,7 +46,7 @@ pipeline { } } dir('quarkus') { - sh "git clone -b 3.15 --single-branch https://github.com/quarkusio/quarkus.git . || git reset --hard && git clean -fx && git pull" + sh "git clone -b remove-dead-code-3.20 --single-branch https://github.com/yrodiere/quarkus.git . || git reset --hard && git clean -fx && git pull" script { def sedStatus = sh (script: "sed -i 's@.*@${env.HIBERNATE_VERSION}@' pom.xml", returnStatus: true) if ( sedStatus != 0 ) { diff --git a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 index b9a113092a4d..9b65eff78379 100644 --- a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 +++ b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 @@ -159,8 +159,7 @@ cycleClause * A toplevel query of subquery, which may be a union or intersection of subqueries */ queryExpression - : withClause? orderedQuery # SimpleQueryGroup - | withClause? orderedQuery (setOperator orderedQuery)+ # SetQueryGroup + : withClause? orderedQuery (setOperator orderedQuery)* ; /** @@ -430,8 +429,6 @@ pathContinuation * * VALUE( path ) * * KEY( path ) * * path[ selector ] - * * ARRAY_GET( embeddableArrayPath, index ).path - * * COALESCE( array1, array2 )[ selector ].path */ syntacticDomainPath : treatedNavigablePath @@ -439,10 +436,6 @@ syntacticDomainPath | mapKeyNavigablePath | simplePath indexedPathAccessFragment | simplePath slicedPathAccessFragment - | toOneFkReference - | function pathContinuation - | function indexedPathAccessFragment pathContinuation? - | function slicedPathAccessFragment ; /** @@ -664,19 +657,21 @@ whereClause predicate //highest to lowest precedence : LEFT_PAREN predicate RIGHT_PAREN # GroupedPredicate - | expression IS NOT? NULL # IsNullPredicate - | expression IS NOT? EMPTY # IsEmptyPredicate - | expression IS NOT? TRUE # IsTruePredicate - | expression IS NOT? FALSE # IsFalsePredicate - | expression IS NOT? DISTINCT FROM expression # IsDistinctFromPredicate + | expression IS NOT? (NULL|EMPTY|TRUE|FALSE) # UnaryIsPredicate | expression NOT? MEMBER OF? path # MemberOfPredicate | expression NOT? IN inList # InPredicate | expression NOT? BETWEEN expression AND expression # BetweenPredicate | expression NOT? (LIKE | ILIKE) expression likeEscape? # LikePredicate - | expression NOT? CONTAINS expression # ContainsPredicate - | expression NOT? INCLUDES expression # IncludesPredicate - | expression NOT? INTERSECTS expression # IntersectsPredicate - | expression comparisonOperator expression # ComparisonPredicate + | expression + ( NOT? (CONTAINS | INCLUDES | INTERSECTS) + | IS NOT? DISTINCT FROM + | EQUAL + | NOT_EQUAL + | GREATER + | GREATER_EQUAL + | LESS + | LESS_EQUAL + ) expression # BinaryExpressionPredicate | EXISTS collectionQuantifier LEFT_PAREN simplePath RIGHT_PAREN # ExistsCollectionPartPredicate | EXISTS expression # ExistsPredicate | NOT predicate # NegatedPredicate @@ -685,18 +680,6 @@ predicate | expression # BooleanExpressionPredicate ; -/** - * An operator which compares values for equality or order - */ -comparisonOperator - : EQUAL - | NOT_EQUAL - | GREATER - | GREATER_EQUAL - | LESS - | LESS_EQUAL - ; - /** * Any right operand of the 'in' operator * @@ -751,7 +734,14 @@ primaryExpression | entityVersionReference # EntityVersionExpression | entityNaturalIdReference # EntityNaturalIdExpression | syntacticDomainPath pathContinuation? # SyntacticPathExpression - | function # FunctionExpression + // ARRAY_GET( embeddableArrayPath, index ).path + // COALESCE( array1, array2 )[ selector ].path + // COALESCE( array1, array2 )[ start : end ] + | function ( + pathContinuation + | slicedPathAccessFragment + | indexedPathAccessFragment pathContinuation? + )? # FunctionExpression | generalPathFragment # GeneralPathExpression ; @@ -1109,6 +1099,7 @@ function | collectionFunctionMisuse | jpaNonstandardFunction | columnFunction + | toOneFkReference | genericFunction ; diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java index 1f33aadcfbbb..a90a177b03d6 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java @@ -75,8 +75,13 @@ private static OrderingParser.OrderByFragmentContext buildParseTree(TranslationC return parser.orderByFragment(); } catch (ParseCancellationException e) { + // When resetting the parser, its CommonTokenStream will seek(0) i.e. restart emitting buffered tokens. + // This is enough when reusing the lexer and parser, and it would be wrong to also reset the lexer. + // Resetting the lexer causes it to hand out tokens again from the start, which will then append to the + // CommonTokenStream and cause a wrong parse + // lexer.reset(); + // reset the input token stream and parser state - lexer.reset(); parser.reset(); // fall back to LL(k)-based parsing diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index 17b152e5642f..38328228ecf3 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -33,6 +33,7 @@ import java.util.Map; import java.util.Set; +import org.antlr.v4.runtime.Token; import org.hibernate.boot.registry.classloading.spi.ClassLoaderService; import org.hibernate.boot.registry.classloading.spi.ClassLoadingException; import org.hibernate.dialect.function.SqlColumn; @@ -807,87 +808,14 @@ public Object visitCte(HqlParser.CteContext ctx) { final JpaCteCriteria oldCte = currentPotentialRecursiveCte; try { currentPotentialRecursiveCte = null; - if ( queryExpressionContext instanceof HqlParser.SetQueryGroupContext ) { - final HqlParser.SetQueryGroupContext setContext = (HqlParser.SetQueryGroupContext) queryExpressionContext; - // A recursive query is only possible if the child count is lower than 5 e.g. `withClause? q1 op q2` - if ( setContext.getChildCount() < 5 ) { - final SetOperator setOperator = (SetOperator) setContext.getChild( setContext.getChildCount() - 2 ) - .accept( this ); - switch ( setOperator ) { - case UNION: - case UNION_ALL: - final HqlParser.OrderedQueryContext nonRecursiveQueryContext; - final HqlParser.OrderedQueryContext recursiveQueryContext; - // On count == 4, we have a withClause at index 0 - if ( setContext.getChildCount() == 4 ) { - nonRecursiveQueryContext = (HqlParser.OrderedQueryContext) setContext.getChild( 1 ); - recursiveQueryContext = (HqlParser.OrderedQueryContext) setContext.getChild( 3 ); - } - else { - nonRecursiveQueryContext = (HqlParser.OrderedQueryContext) setContext.getChild( 0 ); - recursiveQueryContext = (HqlParser.OrderedQueryContext) setContext.getChild( 2 ); - } - // First visit the non-recursive part - nonRecursiveQueryContext.accept( this ); - - // Visiting the possibly recursive part must happen within the call to SqmCteContainer.with, - // because in there, the SqmCteStatement/JpaCteCriteria is available for use in the recursive part. - // The structure (SqmCteTable) for the SqmCteStatement is based on the non-recursive part, - // which is necessary to have, so that the SqmCteRoot/SqmCteJoin can resolve sub-paths. - final SqmSelectStatement recursivePart = new SqmSelectStatement<>( creationContext.getNodeBuilder() ); - - processingStateStack.pop(); - processingStateStack.push( - new SqmQueryPartCreationProcessingStateStandardImpl( - processingStateStack.getCurrent(), - recursivePart, - this - ) - ); - final JpaCteCriteria cteDefinition; - if ( setOperator == SetOperator.UNION ) { - cteDefinition = cteContainer.withRecursiveUnionDistinct( - name, - cte, - cteCriteria -> { - currentPotentialRecursiveCte = cteCriteria; - recursiveQueryContext.accept( this ); - return recursivePart; - } - ); - } - else { - cteDefinition = cteContainer.withRecursiveUnionAll( - name, - cte, - cteCriteria -> { - currentPotentialRecursiveCte = cteCriteria; - recursiveQueryContext.accept( this ); - return recursivePart; - } - ); - } - if ( materialization != null ) { - cteDefinition.setMaterialization( materialization ); - } - final ParseTree lastChild = ctx.getChild( ctx.getChildCount() - 1 ); - final ParseTree potentialSearchClause; - if ( lastChild instanceof HqlParser.CycleClauseContext ) { - applyCycleClause( cteDefinition, (HqlParser.CycleClauseContext) lastChild ); - potentialSearchClause = ctx.getChild( ctx.getChildCount() - 2 ); - } - else { - potentialSearchClause = lastChild; - } - if ( potentialSearchClause instanceof HqlParser.SearchClauseContext ) { - applySearchClause( cteDefinition, (HqlParser.SearchClauseContext) potentialSearchClause ); - } - return null; - } + // A recursive query is only possible if there are 2 ordered queries e.g. `q1 op q2` + if ( queryExpressionContext.orderedQuery().size() == 2 ) { + if ( handleRecursive( ctx, queryExpressionContext, cteContainer, name, cte, materialization ) ) { + return null; } } queryExpressionContext.accept( this ); - final JpaCteCriteria cteDefinition = cteContainer.with( name, cte ); + final JpaCteCriteria cteDefinition = cteContainer.with( name, cte ); if ( materialization != null ) { cteDefinition.setMaterialization( materialization ); } @@ -899,6 +827,76 @@ public Object visitCte(HqlParser.CteContext ctx) { return null; } + private boolean handleRecursive( + HqlParser.CteContext cteContext, + HqlParser.QueryExpressionContext setContext, + SqmCteContainer cteContainer, + String name, + SqmSelectQuery cte, + CteMaterialization materialization) { + final SetOperator setOperator = (SetOperator) setContext.setOperator(0).accept( this ); + switch ( setOperator ) { + case UNION: + case UNION_ALL: + final var nonRecursiveQueryContext = setContext.orderedQuery(0); + final var recursiveQueryContext = setContext.orderedQuery(1); + // First visit the non-recursive part + nonRecursiveQueryContext.accept( this ); + + // Visiting the possibly recursive part must happen within the call to SqmCteContainer.with, + // because in there, the SqmCteStatement/JpaCteCriteria is available for use in the recursive part. + // The structure (SqmCteTable) for the SqmCteStatement is based on the non-recursive part, + // which is necessary to have, so that the SqmCteRoot/SqmCteJoin can resolve sub-paths. + final SqmSelectStatement recursivePart = + new SqmSelectStatement<>( creationContext.getNodeBuilder() ); + + processingStateStack.pop(); + processingStateStack.push( + new SqmQueryPartCreationProcessingStateStandardImpl( + processingStateStack.getCurrent(), + recursivePart, + this + ) + ); + final JpaCteCriteria cteDefinition; + if ( setOperator == SetOperator.UNION ) { + cteDefinition = cteContainer.withRecursiveUnionDistinct( + name, + cte, + cteCriteria -> { + currentPotentialRecursiveCte = cteCriteria; + recursiveQueryContext.accept( this ); + return recursivePart; + } + ); + } + else { + cteDefinition = cteContainer.withRecursiveUnionAll( + name, + cte, + cteCriteria -> { + currentPotentialRecursiveCte = cteCriteria; + recursiveQueryContext.accept( this ); + return recursivePart; + } + ); + } + if ( materialization != null ) { + cteDefinition.setMaterialization( materialization ); + } + final var cycleClauseContext = cteContext.cycleClause(); + if ( cycleClauseContext != null ) { + applyCycleClause( cteDefinition, cycleClauseContext ); + } + final var searchClauseContext = cteContext.searchClause(); + if ( searchClauseContext != null ) { + applySearchClause( cteDefinition, searchClauseContext ); + } + return true; + } + return false; + } + private void applyCycleClause(JpaCteCriteria cteDefinition, HqlParser.CycleClauseContext ctx) { final HqlParser.CteAttributesContext attributesContext = ctx.cteAttributes(); final String cycleMarkAttributeName = visitIdentifier( (HqlParser.IdentifierContext) ctx.getChild( 3 ) ); @@ -1016,15 +1014,6 @@ private void applySearchClause(JpaCteCriteria cteDefinition, HqlParser.Search cteDefinition.search( kind, searchAttributeName, searchOrders ); } - @Override - public SqmQueryPart visitSimpleQueryGroup(HqlParser.SimpleQueryGroupContext ctx) { - final int lastChild = ctx.getChildCount() - 1; - if ( lastChild != 0 ) { - ctx.getChild( 0 ).accept( this ); - } - return (SqmQueryPart) ctx.getChild( lastChild ).accept( this ); - } - @Override public SqmQueryPart visitQueryOrderExpression(HqlParser.QueryOrderExpressionContext ctx) { final SqmQuerySpec sqmQuerySpec = currentQuerySpec(); @@ -1064,37 +1053,41 @@ public SqmQueryPart visitNestedQueryExpression(HqlParser.NestedQueryExpressio } @Override - public SqmQueryGroup visitSetQueryGroup(HqlParser.SetQueryGroupContext ctx) { - final List children = ctx.children; - final int firstIndex; - if ( children.get( 0 ) instanceof HqlParser.WithClauseContext ) { - children.get( 0 ).accept( this ); - firstIndex = 1; + public SqmQueryPart visitQueryExpression(HqlParser.QueryExpressionContext ctx) { + var withClauseContext = ctx.withClause(); + if ( withClauseContext != null ) { + withClauseContext.accept( this ); } - else { - firstIndex = 0; + final var orderedQueryContexts = ctx.orderedQuery(); + final SqmQueryPart firstQueryPart = + (SqmQueryPart) orderedQueryContexts.get( 0 ).accept( this ); + if ( orderedQueryContexts.size() == 1 ) { + return firstQueryPart; } if ( creationOptions.useStrictJpaCompliance() ) { throw new StrictJpaComplianceViolation( StrictJpaComplianceViolation.Type.SET_OPERATIONS ); } - final SqmQueryPart firstQueryPart = (SqmQueryPart) children.get( firstIndex ).accept( this ); SqmQueryGroup queryGroup; - if ( firstQueryPart instanceof SqmQueryGroup) { + if ( firstQueryPart instanceof SqmQueryGroup ) { queryGroup = (SqmQueryGroup) firstQueryPart; } else { queryGroup = new SqmQueryGroup<>( firstQueryPart ); } setCurrentQueryPart( queryGroup ); - final int size = children.size(); + final var setOperatorContexts = ctx.setOperator(); final SqmCreationProcessingState firstProcessingState = processingStateStack.pop(); - for ( int i = firstIndex + 1; i < size; i += 2 ) { - final SetOperator operator = visitSetOperator( (HqlParser.SetOperatorContext) children.get(i) ); - final HqlParser.OrderedQueryContext simpleQueryCtx = - (HqlParser.OrderedQueryContext) children.get( i + 1 ); - queryGroup = getSqmQueryGroup( operator, simpleQueryCtx, queryGroup, size, firstProcessingState, i ); + for ( int i = 0; i < setOperatorContexts.size(); i++ ) { + queryGroup = getSqmQueryGroup( + visitSetOperator( setOperatorContexts.get(i) ), + orderedQueryContexts.get( i + 1 ), + queryGroup, + setOperatorContexts.size(), + firstProcessingState, + i + ); } processingStateStack.push( firstProcessingState ); @@ -1108,8 +1101,6 @@ private SqmQueryGroup getSqmQueryGroup( int size, SqmCreationProcessingState firstProcessingState, int i) { - - final List> queryParts; processingStateStack.push( new SqmQueryPartCreationProcessingStateStandardImpl( processingStateStack.getCurrent(), @@ -1117,7 +1108,9 @@ private SqmQueryGroup getSqmQueryGroup( this ) ); - if ( queryGroup.getSetOperator() == null || queryGroup.getSetOperator() == operator ) { + final List> queryParts; + final SetOperator setOperator = queryGroup.getSetOperator(); + if ( setOperator == null || setOperator == operator ) { queryGroup.setSetOperator( operator ); queryParts = queryGroup.queryParts(); } @@ -1129,15 +1122,14 @@ private SqmQueryGroup getSqmQueryGroup( } try { - final List subChildren = simpleQueryCtx.children; - if ( subChildren.get( 0 ) instanceof HqlParser.QueryContext ) { + if ( simpleQueryCtx instanceof HqlParser.QuerySpecExpressionContext ) { final SqmQuerySpec querySpec = new SqmQuerySpec<>( creationContext.getNodeBuilder() ); queryParts.add( querySpec ); visitQuerySpecExpression( (HqlParser.QuerySpecExpressionContext) simpleQueryCtx ); } - else { + else if ( simpleQueryCtx instanceof HqlParser.NestedQueryExpressionContext ) { try { - final SqmSelectStatement selectStatement = + final SqmSelectStatement selectStatement = new SqmSelectStatement<>( creationContext.getNodeBuilder() ); processingStateStack.push( new SqmQueryPartCreationProcessingStateStandardImpl( @@ -1155,6 +1147,7 @@ private SqmQueryGroup getSqmQueryGroup( processingStateStack.pop(); } } + // else if QueryOrderExpressionContext, nothing to do } finally { processingStateStack.pop(); @@ -1898,8 +1891,34 @@ public Object visitGeneralPathExpression(HqlParser.GeneralPathExpressionContext } @Override - public SqmExpression visitFunctionExpression(HqlParser.FunctionExpressionContext ctx) { - return (SqmExpression) ctx.function().accept( this ); + public Object visitFunctionExpression(HqlParser.FunctionExpressionContext ctx) { + final var slicedFragmentsCtx = ctx.slicedPathAccessFragment(); + if ( slicedFragmentsCtx != null ) { + final List slicedFragments = slicedFragmentsCtx.expression(); + return getFunctionDescriptor( "array_slice" ).generateSqmExpression( + List.of( + (SqmTypedNode) visitFunction( ctx.function() ), + (SqmTypedNode) slicedFragments.get( 0 ).accept( this ), + (SqmTypedNode) slicedFragments.get( 1 ).accept( this ) + ), + null, + creationContext.getQueryEngine() + ); + } + else { + final var function = (SqmExpression) visitFunction( ctx.function() ); + final var indexedPathAccessFragment = ctx.indexedPathAccessFragment(); + final var pathContinuation = ctx.pathContinuation(); + if ( indexedPathAccessFragment == null && pathContinuation == null ) { + return function; + } + else { + return visitPathContinuation( + visitIndexedPathAccessFragment( (SemanticPathPart) function, indexedPathAccessFragment ), + pathContinuation + ); + } + } } @Override @@ -2435,90 +2454,135 @@ public SqmBetweenPredicate visitBetweenPredicate(HqlParser.BetweenPredicateConte ); } - @Override - public SqmNullnessPredicate visitIsNullPredicate(HqlParser.IsNullPredicateContext ctx) { - return new SqmNullnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - ctx.NOT() != null, - creationContext.getNodeBuilder() - ); + public SqmPredicate visitUnaryIsPredicate(HqlParser.UnaryIsPredicateContext ctx) { + final var expression = (SqmExpression) ctx.expression().accept( this ); + final var negated = ctx.NOT() != null; + final var nodeBuilder = creationContext.getNodeBuilder(); + switch ( ((TerminalNode) ctx.getChild( ctx.getChildCount() - 1 )).getSymbol().getType() ) { + case HqlParser.NULL: + return new SqmNullnessPredicate( expression, negated, nodeBuilder ); + case HqlParser.EMPTY: + if ( expression instanceof SqmPluralValuedSimplePath ) { + return new SqmEmptinessPredicate( (SqmPluralValuedSimplePath) expression, negated, nodeBuilder ); + } + else { + throw new SemanticException( "Operand of 'is empty' operator must be a plural path", query ); + } + case HqlParser.TRUE: + return new SqmTruthnessPredicate( expression, true, negated, nodeBuilder ); + case HqlParser.FALSE: + return new SqmTruthnessPredicate( expression, false, negated, nodeBuilder ); + default: + throw new AssertionError( "Unknown unary is predicate: " + ctx.getChild( ctx.getChildCount() - 1 ) ); + } } @Override - public SqmEmptinessPredicate visitIsEmptyPredicate(HqlParser.IsEmptyPredicateContext ctx) { - SqmExpression expression = (SqmExpression) ctx.expression().accept(this); - if ( expression instanceof SqmPluralValuedSimplePath ) { - return new SqmEmptinessPredicate( - (SqmPluralValuedSimplePath) expression, - ctx.NOT() != null, - creationContext.getNodeBuilder() - ); + public SqmPredicate visitBinaryExpressionPredicate(HqlParser.BinaryExpressionPredicateContext ctx) { + final var firstSymbol = ((TerminalNode) ctx.getChild( 1 )).getSymbol(); + final boolean negated; + final Token operationSymbol; + if ( firstSymbol.getType() == HqlParser.NOT ) { + negated = true; + operationSymbol = ((TerminalNode) ctx.getChild( 2 )).getSymbol(); } else { - throw new SemanticException( "Operand of 'is empty' operator must be a plural path", query ); - } - } - - @Override - public Object visitIsTruePredicate(HqlParser.IsTruePredicateContext ctx) { - return new SqmTruthnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - true, - ctx.NOT() != null, - creationContext.getNodeBuilder() - ); - } - - @Override - public Object visitIsFalsePredicate(HqlParser.IsFalsePredicateContext ctx) { - return new SqmTruthnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - false, - ctx.NOT() != null, - creationContext.getNodeBuilder() - ); - } - - @Override - public Object visitComparisonOperator(HqlParser.ComparisonOperatorContext ctx) { - final TerminalNode firstToken = (TerminalNode) ctx.getChild( 0 ); - switch ( firstToken.getSymbol().getType() ) { - case HqlLexer.EQUAL: - return ComparisonOperator.EQUAL; - case HqlLexer.NOT_EQUAL: - return ComparisonOperator.NOT_EQUAL; - case HqlLexer.LESS: - return ComparisonOperator.LESS_THAN; - case HqlLexer.LESS_EQUAL: - return ComparisonOperator.LESS_THAN_OR_EQUAL; - case HqlLexer.GREATER: - return ComparisonOperator.GREATER_THAN; - case HqlLexer.GREATER_EQUAL: - return ComparisonOperator.GREATER_THAN_OR_EQUAL; + negated = firstSymbol.getType() == HqlParser.IS + && ((TerminalNode) ctx.getChild( 2 )).getSymbol().getType() == HqlParser.NOT; + operationSymbol = firstSymbol; + } + final var expressions = ctx.expression(); + final var lhsCtx = expressions.get( 0 ); + final var rhsCtx = expressions.get( 1 ); + switch ( operationSymbol.getType() ) { + case HqlParser.CONTAINS: { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + if ( lhsExpressible != null && !(lhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "First operand for contains predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = getFunctionDescriptor( + "array_contains" ).generateSqmExpression( + asList( lhs, rhs ), + null, + creationContext.getQueryEngine() + ); + return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); + } + case HqlParser.INCLUDES: { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + final var rhsExpressible = rhs.getExpressible(); + if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "First operand for includes predicate must be a basic plural type expression, but found: " + + lhsExpressible.getSqmType(), + query + ); + } + if ( rhsExpressible != null && !( rhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "Second operand for includes predicate must be a basic plural type expression, but found: " + + rhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_includes" ).generateSqmExpression( + asList( lhs, rhs ), + null, + creationContext.getQueryEngine() + ); + return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); + } + case HqlParser.INTERSECTS: { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType ) ) { + throw new SemanticException( + "First operand for intersects predicate must be a basic plural type expression, but found: " + + lhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = + getFunctionDescriptor( "array_intersects" ) + .generateSqmExpression( + asList( lhs, rhs ), + null, + creationContext.getQueryEngine() + ); + return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); + } + case HqlParser.EQUAL: + return createComparisonPredicate( ComparisonOperator.EQUAL, lhsCtx, rhsCtx ); + case HqlParser.NOT_EQUAL: + return createComparisonPredicate( ComparisonOperator.NOT_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.LESS: + return createComparisonPredicate( ComparisonOperator.LESS_THAN, lhsCtx, rhsCtx ); + case HqlParser.LESS_EQUAL: + return createComparisonPredicate( ComparisonOperator.LESS_THAN_OR_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.GREATER: + return createComparisonPredicate( ComparisonOperator.GREATER_THAN, lhsCtx, rhsCtx ); + case HqlParser.GREATER_EQUAL: + return createComparisonPredicate( ComparisonOperator.GREATER_THAN_OR_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.IS: { + final ComparisonOperator comparisonOperator = !negated + ? ComparisonOperator.DISTINCT_FROM + : ComparisonOperator.NOT_DISTINCT_FROM; + return createComparisonPredicate( comparisonOperator, lhsCtx, rhsCtx ); + } default: - throw new ParsingException("Unrecognized comparison operator"); + throw new AssertionError( "Unknown binary expression predicate: " + operationSymbol ); } } - @Override - public SqmPredicate visitComparisonPredicate(HqlParser.ComparisonPredicateContext ctx) { - final ComparisonOperator comparisonOperator = (ComparisonOperator) ctx.comparisonOperator().accept( this ); - final HqlParser.ExpressionContext leftExpressionContext = ctx.expression( 0 ); - final HqlParser.ExpressionContext rightExpressionContext = ctx.expression( 1 ); - return createComparisonPredicate( comparisonOperator, leftExpressionContext, rightExpressionContext ); - } - - @Override - public SqmPredicate visitIsDistinctFromPredicate(HqlParser.IsDistinctFromPredicateContext ctx) { - final HqlParser.ExpressionContext leftExpressionContext = ctx.expression( 0 ); - final HqlParser.ExpressionContext rightExpressionContext = ctx.expression( 1 ); - final ComparisonOperator comparisonOperator = ctx.NOT() == null - ? ComparisonOperator.DISTINCT_FROM - : ComparisonOperator.NOT_DISTINCT_FROM; - return createComparisonPredicate( comparisonOperator, leftExpressionContext, rightExpressionContext ); - } - private SqmComparisonPredicate createComparisonPredicate( ComparisonOperator comparisonOperator, HqlParser.ExpressionContext leftExpressionContext, @@ -2645,73 +2709,6 @@ private String getPossibleEnumValue(HqlParser.ExpressionContext expressionContex return null; } - @Override - public SqmPredicate visitContainsPredicate(HqlParser.ContainsPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for contains predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_contains" ).generateSqmExpression( - asList( lhs, rhs ), - null, - creationContext.getQueryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); - } - - @Override - public SqmPredicate visitIncludesPredicate(HqlParser.IncludesPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - final SqmExpressible rhsExpressible = rhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for includes predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), - query - ); - } - if ( rhsExpressible != null && !( rhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "Second operand for includes predicate must be a basic plural type expression, but found: " + rhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_includes" ).generateSqmExpression( - asList( lhs, rhs ), - null, - creationContext.getQueryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); - } - - @Override - public SqmPredicate visitIntersectsPredicate(HqlParser.IntersectsPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for intersects predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_intersects" ).generateSqmExpression( - asList( lhs, rhs ), - null, - creationContext.getQueryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, creationContext.getNodeBuilder() ); - } - @Override public SqmPredicate visitLikePredicate(HqlParser.LikePredicateContext ctx) { final boolean negated = ctx.NOT() != null; @@ -3061,11 +3058,6 @@ else if ( attributes.size() >1 ) { throw new FunctionArgumentException( "Argument '" + sqmPath.getNavigablePath() + "' of 'naturalid()' does not resolve to an entity type" ); } -// -// @Override -// public Object visitToOneFkExpression(HqlParser.ToOneFkExpressionContext ctx) { -// return visitToOneFkReference( (HqlParser.ToOneFkReferenceContext) ctx.getChild( 0 ) ); -// } @Override public SqmFkExpression visitToOneFkReference(HqlParser.ToOneFkReferenceContext ctx) { @@ -5309,33 +5301,6 @@ else if ( ctx.collectionValueNavigablePath() != null ) { else if ( ctx.mapKeyNavigablePath() != null ) { return visitMapKeyNavigablePath( ctx.mapKeyNavigablePath() ); } - else if ( ctx.toOneFkReference() != null ) { - return visitToOneFkReference( ctx.toOneFkReference() ); - } - else if ( ctx.function() != null ) { - final HqlParser.SlicedPathAccessFragmentContext slicedFragmentsCtx = ctx.slicedPathAccessFragment(); - if ( slicedFragmentsCtx != null ) { - final List slicedFragments = slicedFragmentsCtx.expression(); - return getFunctionDescriptor( "array_slice" ).generateSqmExpression( - List.of( - (SqmTypedNode) visitFunction( ctx.function() ), - (SqmTypedNode) slicedFragments.get( 0 ).accept( this ), - (SqmTypedNode) slicedFragments.get( 1 ).accept( this ) - ), - null, - creationContext.getQueryEngine() - ); - } - else { - return visitPathContinuation( - visitIndexedPathAccessFragment( - (SemanticPathPart) visitFunction( ctx.function() ), - ctx.indexedPathAccessFragment() - ), - ctx.pathContinuation() - ); - } - } else if ( ctx.simplePath() != null && ctx.indexedPathAccessFragment() != null ) { return visitIndexedPathAccessFragment( visitSimplePath( ctx.simplePath() ), ctx.indexedPathAccessFragment() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java index 191a6a05940d..7a831a0e1b08 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java @@ -132,9 +132,14 @@ public void reportContextSensitivity(Parser recognizer, DFA dfa, int startIndex, try { return hqlParser.statement(); } - catch ( ParseCancellationException e) { + catch (ParseCancellationException e) { + // When resetting the parser, its CommonTokenStream will seek(0) i.e. restart emitting buffered tokens. + // This is enough when reusing the lexer and parser, and it would be wrong to also reset the lexer. + // Resetting the lexer causes it to hand out tokens again from the start, which will then append to the + // CommonTokenStream and cause a wrong parse + // hqlLexer.reset(); + // reset the input token stream and parser state - hqlLexer.reset(); hqlParser.reset(); // fall back to LL(k)-based parsing diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java new file mode 100644 index 000000000000..ec511f1471fd --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java @@ -0,0 +1,164 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.orm.test.hql; + +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.Id; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Table; +import org.hibernate.cfg.QuerySettings; +import org.hibernate.query.hql.HqlTranslator; +import org.hibernate.testing.memory.MemoryUsageUtil; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.Jira; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.Setting; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DomainModel( + annotatedClasses = { + HqlParserMemoryUsageTest.Address.class, + HqlParserMemoryUsageTest.AppUser.class, + HqlParserMemoryUsageTest.Category.class, + HqlParserMemoryUsageTest.Discount.class, + HqlParserMemoryUsageTest.Order.class, + HqlParserMemoryUsageTest.OrderItem.class, + HqlParserMemoryUsageTest.Product.class + } +) +@SessionFactory +@ServiceRegistry(settings = @Setting(name = QuerySettings.QUERY_PLAN_CACHE_ENABLED, value = "false")) +@Jira("https://hibernate.atlassian.net/browse/HHH-19240") +public class HqlParserMemoryUsageTest { + + private static final String HQL = "SELECT DISTINCT u.id\n" + + "FROM AppUser u\n" + + "LEFT JOIN u.addresses a\n" + + "LEFT JOIN u.orders o\n" + + "LEFT JOIN o.orderItems oi\n" + + "LEFT JOIN oi.product p\n" + + "LEFT JOIN p.discounts d\n" + + "WHERE u.id = :userId\n" + + "AND (\n" + + " CASE\n" + + " WHEN u.name = 'SPECIAL_USER' THEN TRUE\n" + + " ELSE (\n" + + " CASE\n" + + " WHEN a.city = 'New York' THEN TRUE\n" + + " ELSE (\n" + + " p.category.name = 'Electronics'\n" + + " OR d.code LIKE '%DISC%'\n" + + " OR u.id IN (\n" + + " SELECT u2.id\n" + + " FROM AppUser u2\n" + + " JOIN u2.orders o2\n" + + " JOIN o2.orderItems oi2\n" + + " JOIN oi2.product p2\n" + + " WHERE p2.price > (\n" + + " SELECT AVG(p3.price) FROM Product p3\n" + + " )\n" + + " )\n" + + " )\n" + + " END\n" + + " )\n" + + " END\n" + + ")\n"; + + + @Test + public void testParserMemoryUsage(SessionFactoryScope scope) { + final HqlTranslator hqlTranslator = scope.getSessionFactory().getQueryEngine().getHqlTranslator(); + + // Ensure classes and basic stuff is initialized in case this is the first test run + hqlTranslator.translate( "from AppUser", AppUser.class ); + + // During testing, before the fix for HHH-19240, the allocation was around 500+ MB, + // and after the fix it dropped to 170 - 250 MB + final long memoryUsage = MemoryUsageUtil.estimateMemoryUsage( () -> hqlTranslator.translate( HQL, Long.class ) ); + System.out.println( "Memory Consumption: " + (memoryUsage / 1024) + " KB" ); + assertTrue( memoryUsage < 256_000_000, "Parsing of queries consumes too much memory (" + ( memoryUsage / 1024 ) + " KB), when at most 256 MB are expected" ); + } + + @Entity(name = "Address") + @Table(name = "addresses") + public static class Address { + @Id + private Long id; + private String city; + @ManyToOne(fetch = FetchType.LAZY) + private AppUser user; + } + @Entity(name = "AppUser") + @Table(name = "app_users") + public static class AppUser { + @Id + private Long id; + private String name; + @OneToMany(mappedBy = "user") + private Set
addresses; + @OneToMany(mappedBy = "user") + private Set orders; + } + + @Entity(name = "Category") + @Table(name = "categories") + public static class Category { + @Id + private Long id; + private String name; + } + + @Entity(name = "Discount") + @Table(name = "discounts") + public static class Discount { + @Id + private Long id; + private String code; + @ManyToOne(fetch = FetchType.LAZY) + private Product product; + } + + @Entity(name = "Order") + @Table(name = "orders") + public static class Order { + @Id + private Long id; + @ManyToOne(fetch = FetchType.LAZY) + private AppUser user; + @OneToMany(mappedBy = "order") + private Set orderItems; + } + @Entity(name = "OrderItem") + @Table(name = "order_items") + public static class OrderItem { + @Id + private Long id; + @ManyToOne(fetch = FetchType.LAZY) + private Order order; + @ManyToOne(fetch = FetchType.LAZY) + private Product product; + } + + @Entity(name = "Product") + @Table(name = "products") + public static class Product { + @Id + private Long id; + private String name; + private Double price; + @ManyToOne(fetch = FetchType.LAZY) + private Category category; + @OneToMany(mappedBy = "product") + private Set discounts; + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java new file mode 100644 index 000000000000..05cabd0bc432 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java @@ -0,0 +1,100 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +import java.lang.management.ManagementFactory; +import java.lang.management.MemoryPoolMXBean; +import java.util.List; +import java.util.Objects; + +final class GlobalMemoryUsageSnapshotter implements MemoryAllocationSnapshotter { + + private static final GlobalMemoryUsageSnapshotter INSTANCE = new GlobalMemoryUsageSnapshotter( + ManagementFactory.getMemoryPoolMXBeans() + ); + + private final List heapPoolBeans; + private final Runnable gcAndWait; + + private GlobalMemoryUsageSnapshotter(List heapPoolBeans) { + this.heapPoolBeans = heapPoolBeans; + this.gcAndWait = () -> { + for (int i = 0; i < 3; i++) { + System.gc(); + try { + Thread.sleep( 50 ); + } + catch (InterruptedException ignored) { + } + } + }; + } + + public static GlobalMemoryUsageSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + final long peakUsage = heapPoolBeans.stream().mapToLong(p -> p.getPeakUsage().getUsed()).sum(); + gcAndWait.run(); + final long retainedUsage = heapPoolBeans.stream().mapToLong(p -> p.getUsage().getUsed()).sum(); + heapPoolBeans.forEach(MemoryPoolMXBean::resetPeakUsage); + return new GlobalMemoryAllocationSnapshot( peakUsage, retainedUsage ); + } + + final static class GlobalMemoryAllocationSnapshot implements MemoryAllocationSnapshot { + private final long peakUsage; + private final long retainedUsage; + + GlobalMemoryAllocationSnapshot(long peakUsage, long retainedUsage) { + this.peakUsage = peakUsage; + this.retainedUsage = retainedUsage; + } + + public long peakUsage() { + return peakUsage; + } + + public long retainedUsage() { + return retainedUsage; + } + + @Override + public long difference(MemoryAllocationSnapshot before) { + // When doing the "before" snapshot, the peak usage is reset. + // Since this object is the "after" snapshot, we can simply estimate the memory usage of an operation + // to be the peak usage of that operation minus the usage after GC + return peakUsage - retainedUsage; + } + + @Override + public boolean equals(Object obj) { + if ( obj == this ) { + return true; + } + if ( obj == null || obj.getClass() != this.getClass() ) { + return false; + } + var that = (GlobalMemoryAllocationSnapshot) obj; + return this.peakUsage == that.peakUsage && + this.retainedUsage == that.retainedUsage; + } + + @Override + public int hashCode() { + return Objects.hash( peakUsage, retainedUsage ); + } + + @Override + public String toString() { + return "GlobalMemoryAllocationSnapshot[" + + "peakUsage=" + peakUsage + ", " + + "retainedUsage=" + retainedUsage + ']'; + } + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java new file mode 100644 index 000000000000..a767d8a758d9 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java @@ -0,0 +1,169 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Objects; + +final class HotspotPerThreadAllocationSnapshotter implements MemoryAllocationSnapshotter { + + private static final @Nullable HotspotPerThreadAllocationSnapshotter INSTANCE; + private static final Method GET_THREAD_ALLOCATED_BYTES; + + static { + ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); + Method method = null; + try { + @SuppressWarnings("unchecked") + Class hotspotInterface = + (Class) Class.forName( "com.sun.management.ThreadMXBean" ); + try { + method = hotspotInterface.getMethod( "getThreadAllocatedBytes", long[].class ); + } + catch (Exception e) { + // Ignore + } + + if ( !hotspotInterface.isInstance( threadMXBean ) ) { + threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface ); + } + } + catch (Throwable e) { + // Ignore + } + + GET_THREAD_ALLOCATED_BYTES = method; + + HotspotPerThreadAllocationSnapshotter instance = null; + if ( method != null && threadMXBean != null ) { + try { + instance = new HotspotPerThreadAllocationSnapshotter( threadMXBean ); + instance.snapshot(); + } + catch (Exception e) { + instance = null; + } + } + INSTANCE = instance; + } + + public static @Nullable HotspotPerThreadAllocationSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + long[] threadIds = threadMXBean.getAllThreadIds(); + try { + return new PerThreadMemoryAllocationSnapshot( + threadIds, + (long[]) GET_THREAD_ALLOCATED_BYTES.invoke( threadMXBean, (Object) threadIds ) + ); + } + catch (Exception e) { + throw new RuntimeException( e ); + } + } + + final static class PerThreadMemoryAllocationSnapshot implements MemoryAllocationSnapshot { + private final long[] threadIds; + private final long[] threadAllocatedBytes; + + PerThreadMemoryAllocationSnapshot(long[] threadIds, long[] threadAllocatedBytes) { + this.threadIds = threadIds; + this.threadAllocatedBytes = threadAllocatedBytes; + } + + public long[] threadIds() { + return threadIds; + } + + public long[] threadAllocatedBytes() { + return threadAllocatedBytes; + } + + @Override + public long difference(MemoryAllocationSnapshot before) { + final PerThreadMemoryAllocationSnapshot other = (PerThreadMemoryAllocationSnapshot) before; + final HashMap previousThreadIdToIndexMap = new HashMap<>(); + for ( int i = 0; i < other.threadIds.length; i++ ) { + previousThreadIdToIndexMap.put( other.threadIds[i], i ); + } + long allocatedBytes = 0; + for ( int i = 0; i < threadIds.length; i++ ) { + allocatedBytes += threadAllocatedBytes[i]; + final Integer previousThreadIndex = previousThreadIdToIndexMap.get( threadIds[i] ); + if ( previousThreadIndex != null ) { + allocatedBytes -= other.threadAllocatedBytes[previousThreadIndex]; + } + } + return allocatedBytes; + } + + @Override + public boolean equals(Object obj) { + if ( obj == this ) { + return true; + } + if ( obj == null || obj.getClass() != this.getClass() ) { + return false; + } + var that = (PerThreadMemoryAllocationSnapshot) obj; + return Objects.equals( this.threadIds, that.threadIds ) && + Objects.equals( this.threadAllocatedBytes, that.threadAllocatedBytes ); + } + + @Override + public int hashCode() { + return Objects.hash( threadIds, threadAllocatedBytes ); + } + + @Override + public String toString() { + return "PerThreadMemoryAllocationSnapshot[" + + "threadIds=" + threadIds + ", " + + "threadAllocatedBytes=" + threadAllocatedBytes + ']'; + } + } + private final ThreadMXBean threadMXBean; + + HotspotPerThreadAllocationSnapshotter(ThreadMXBean threadMXBean) { + this.threadMXBean = threadMXBean; + } + + public ThreadMXBean threadMXBean() { + return threadMXBean; + } + + @Override + public boolean equals(Object obj) { + if ( obj == this ) { + return true; + } + if ( obj == null || obj.getClass() != this.getClass() ) { + return false; + } + var that = (HotspotPerThreadAllocationSnapshotter) obj; + return Objects.equals( this.threadMXBean, that.threadMXBean ); + } + + @Override + public int hashCode() { + return Objects.hash( threadMXBean ); + } + + @Override + public String toString() { + return "HotspotPerThreadAllocationSnapshotter[" + + "threadMXBean=" + threadMXBean + ']'; + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java new file mode 100644 index 000000000000..238eaeb7539f --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java @@ -0,0 +1,148 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.Method; +import java.util.Objects; + +final class HotspotTotalThreadBytesSnapshotter implements MemoryAllocationSnapshotter { + + private static final @Nullable HotspotTotalThreadBytesSnapshotter INSTANCE; + private static final Method GET_TOTAL_THREAD_ALLOCATED_BYTES; + + static { + ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); + Method method = null; + try { + @SuppressWarnings("unchecked") + Class hotspotInterface = + (Class) Class.forName( "com.sun.management.ThreadMXBean" ); + try { + method = hotspotInterface.getMethod( "getTotalThreadAllocatedBytes" ); + } + catch (Exception e) { + // Ignore + } + + if ( !hotspotInterface.isInstance( threadMXBean ) ) { + threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface ); + } + } + catch (Throwable e) { + // Ignore + } + + GET_TOTAL_THREAD_ALLOCATED_BYTES = method; + + HotspotTotalThreadBytesSnapshotter instance = null; + if ( method != null && threadMXBean != null ) { + try { + instance = new HotspotTotalThreadBytesSnapshotter( threadMXBean ); + instance.snapshot(); + } + catch (Exception e) { + instance = null; + } + } + INSTANCE = instance; + } + + public static @Nullable HotspotTotalThreadBytesSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + try { + return new GlobalMemoryAllocationSnapshot( (long) GET_TOTAL_THREAD_ALLOCATED_BYTES.invoke( threadMXBean ) ); + } + catch (Exception e) { + throw new RuntimeException( e ); + } + } + + final static class GlobalMemoryAllocationSnapshot implements MemoryAllocationSnapshot { + private final long allocatedBytes; + + GlobalMemoryAllocationSnapshot(long allocatedBytes) { + if ( allocatedBytes == -1L ) { + throw new IllegalArgumentException( "getTotalThreadAllocatedBytes is disabled" ); + } + this.allocatedBytes = allocatedBytes; + } + + @Override + public long difference(MemoryAllocationSnapshot before) { + final GlobalMemoryAllocationSnapshot other = (GlobalMemoryAllocationSnapshot) before; + return Math.max( allocatedBytes - other.allocatedBytes, 0L ); + } + + public long allocatedBytes() { + return allocatedBytes; + } + + @Override + public boolean equals(Object obj) { + if ( obj == this ) { + return true; + } + if ( obj == null || obj.getClass() != this.getClass() ) { + return false; + } + var that = (GlobalMemoryAllocationSnapshot) obj; + return this.allocatedBytes == that.allocatedBytes; + } + + @Override + public int hashCode() { + return Objects.hash( allocatedBytes ); + } + + @Override + public String toString() { + return "GlobalMemoryAllocationSnapshot[" + + "allocatedBytes=" + allocatedBytes + ']'; + } + } + + private final ThreadMXBean threadMXBean; + + HotspotTotalThreadBytesSnapshotter(ThreadMXBean threadMXBean) { + this.threadMXBean = threadMXBean; + } + + public ThreadMXBean threadMXBean() { + return threadMXBean; + } + + @Override + public boolean equals(Object obj) { + if ( obj == this ) { + return true; + } + if ( obj == null || obj.getClass() != this.getClass() ) { + return false; + } + var that = (HotspotTotalThreadBytesSnapshotter) obj; + return Objects.equals( this.threadMXBean, that.threadMXBean ); + } + + @Override + public int hashCode() { + return Objects.hash( threadMXBean ); + } + + @Override + public String toString() { + return "HotspotTotalThreadBytesSnapshotter[" + + "threadMXBean=" + threadMXBean + ']'; + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java new file mode 100644 index 000000000000..1de3f1afaf5d --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java @@ -0,0 +1,11 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +interface MemoryAllocationSnapshot { + long difference(MemoryAllocationSnapshot before); +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java new file mode 100644 index 000000000000..a0bf0192cb57 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java @@ -0,0 +1,11 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +interface MemoryAllocationSnapshotter { + MemoryAllocationSnapshot snapshot(); +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java new file mode 100644 index 000000000000..9f352de0da16 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java @@ -0,0 +1,29 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.testing.memory; + +public class MemoryUsageUtil { + + private static final MemoryAllocationSnapshotter SNAPSHOTTER; + + static { + MemoryAllocationSnapshotter snapshotter = HotspotTotalThreadBytesSnapshotter.getInstance(); + if ( snapshotter == null ) { + snapshotter = HotspotPerThreadAllocationSnapshotter.getInstance(); + } + if ( snapshotter == null ) { + snapshotter = GlobalMemoryUsageSnapshotter.getInstance(); + } + SNAPSHOTTER = snapshotter; + } + + public static long estimateMemoryUsage(Runnable runnable) { + final MemoryAllocationSnapshot beforeSnapshot = SNAPSHOTTER.snapshot(); + runnable.run(); + return SNAPSHOTTER.snapshot().difference( beforeSnapshot ); + } +}