From 34d5a2ae7aadfece890bff4a1446c9c893e83713 Mon Sep 17 00:00:00 2001 From: Nathan Xu Date: Tue, 24 Mar 2020 15:04:30 -0400 Subject: [PATCH] implement @Filter for loader --- .../cfg/annotations/CollectionBinder.java | 4 +- .../hibernate/engine/spi/QueryParameters.java | 74 ----- .../org/hibernate/internal/FilterHelper.java | 78 ++++- .../internal/CollectionLoaderBatchKey.java | 5 +- .../internal/CollectionLoaderSingleKey.java | 4 +- .../ast/internal/LoaderSelectBuilder.java | 72 ++++- .../loader/ast/internal/SingleIdLoadPlan.java | 2 + .../AbstractCollectionPersister.java | 5 +- .../entity/AbstractEntityPersister.java | 28 +- .../org/hibernate/sql/ast/SqlTreePrinter.java | 7 +- .../sql/ast/spi/AbstractSqlAstWalker.java | 9 +- .../sql/ast/spi/SqlAstTreeHelper.java | 1 + .../ast/tree/predicate/FilterPredicate.java | 31 +- .../sql/ast/tree/select/QuerySpec.java | 45 ++- .../loading/filter/FilterBasicsTests.java | 281 ++++++++++++++++ .../loading/filter/FilterJoinTableTests.java | 203 ++++++++++++ .../FilterOnEagerLoadedCollectionTests.java | 304 ++++++++++++++++++ 17 files changed, 1065 insertions(+), 88 deletions(-) create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterBasicsTests.java create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterJoinTableTests.java create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterOnEagerLoadedCollectionTests.java diff --git a/hibernate-core/src/main/java/org/hibernate/cfg/annotations/CollectionBinder.java b/hibernate-core/src/main/java/org/hibernate/cfg/annotations/CollectionBinder.java index 1f06e12bd3dd..f7d5e6ed4966 100644 --- a/hibernate-core/src/main/java/org/hibernate/cfg/annotations/CollectionBinder.java +++ b/hibernate-core/src/main/java/org/hibernate/cfg/annotations/CollectionBinder.java @@ -975,7 +975,7 @@ private void bindFilters(boolean hasAssociationTable) { FilterJoinTable simpleFilterJoinTable = property.getAnnotation( FilterJoinTable.class ); if ( simpleFilterJoinTable != null ) { if ( hasAssociationTable ) { - collection.addFilter(simpleFilterJoinTable.name(), simpleFilterJoinTable.condition(), + collection.addManyToManyFilter(simpleFilterJoinTable.name(), simpleFilterJoinTable.condition(), simpleFilterJoinTable.deduceAliasInjectionPoints(), toAliasTableMap(simpleFilterJoinTable.aliases()), toAliasEntityMap(simpleFilterJoinTable.aliases())); } @@ -990,7 +990,7 @@ private void bindFilters(boolean hasAssociationTable) { if ( filterJoinTables != null ) { for (FilterJoinTable filter : filterJoinTables.value()) { if ( hasAssociationTable ) { - collection.addFilter(filter.name(), filter.condition(), + collection.addManyToManyFilter(filter.name(), filter.condition(), filter.deduceAliasInjectionPoints(), toAliasTableMap(filter.aliases()), toAliasEntityMap(filter.aliases())); } diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/QueryParameters.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/QueryParameters.java index 7c8a69219351..0cb0b6a9ce76 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/spi/QueryParameters.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/QueryParameters.java @@ -515,80 +515,6 @@ public void setPassDistinctThrough(boolean passDistinctThrough) { this.passDistinctThrough = passDistinctThrough; } - public void processFilters(String sql, SharedSessionContractImplementor session) { - processFilters( sql, session.getLoadQueryInfluencers().getEnabledFilters(), session.getFactory() ); - } - - @SuppressWarnings( {"unchecked"}) - public void processFilters(String sql, Map filters, SessionFactoryImplementor factory) { - if ( filters.size() == 0 || !sql.contains( HQL_VARIABLE_PREFIX ) ) { - // HELLA IMPORTANT OPTIMIZATION!!! - processedPositionalParameterValues = getPositionalParameterValues(); - processedPositionalParameterTypes = getPositionalParameterTypes(); - processedSQL = sql; - } - else { - throw new NotYetImplementedFor6Exception( getClass() ); -// final StringTokenizer tokens = new StringTokenizer( sql, SYMBOLS, true ); -// StringBuilder result = new StringBuilder(); -// List parameters = new ArrayList(); -// List parameterTypes = new ArrayList(); -// int positionalIndex = 0; -// while ( tokens.hasMoreTokens() ) { -// final String token = tokens.nextToken(); -// if ( token.startsWith( ParserHelper.HQL_VARIABLE_PREFIX ) ) { -// final String filterParameterName = token.substring( 1 ); -// final String[] parts = LoadQueryInfluencers.parseFilterParameterName( filterParameterName ); -// final FilterImpl filter = (FilterImpl) filters.get( parts[0] ); -// final Object value = filter.getParameter( parts[1] ); -// final Type type = filter.getFilterDefinition().getParameterType( parts[1] ); -// if ( value != null && Collection.class.isAssignableFrom( value.getClass() ) ) { -// Iterator itr = ( (Collection) value ).iterator(); -// while ( itr.hasNext() ) { -// final Object elementValue = itr.next(); -// result.append( '?' ); -// parameters.add( elementValue ); -// parameterTypes.add( type ); -// if ( itr.hasNext() ) { -// result.append( ", " ); -// } -// } -// } -// else { -// result.append( '?' ); -// parameters.add( value ); -// parameterTypes.add( type ); -// } -// } -// else { -// result.append( token ); -// if ( "?".equals( token ) && positionalIndex < getPositionalParameterValues().length ) { -// final Type type = getPositionalParameterTypes()[positionalIndex]; -// if ( type.isComponentType() ) { -// // should process tokens till reaching the number of "?" corresponding to the -// // numberOfParametersCoveredBy of the compositeType -// int paramIndex = 1; -// final int numberOfParametersCoveredBy = getNumberOfParametersCoveredBy( ((ComponentType) type).getSubtypes() ); -// while ( paramIndex < numberOfParametersCoveredBy ) { -// final String nextToken = tokens.nextToken(); -// if ( "?".equals( nextToken ) ) { -// paramIndex++; -// } -// result.append( nextToken ); -// } -// } -// parameters.add( getPositionalParameterValues()[positionalIndex] ); -// parameterTypes.add( type ); -// positionalIndex++; -// } -// } -// } -// processedPositionalParameterValues = parameters.toArray(); -// processedPositionalParameterTypes = ( Type[] ) parameterTypes.toArray( new Type[parameterTypes.size()] ); -// processedSQL = result.toString(); - } - } - private int getNumberOfParametersCoveredBy(Type[] subtypes) { int numberOfParameters = 0; for ( Type type : subtypes ) { diff --git a/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java b/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java index 9d2be96ceee6..62d3763ca587 100644 --- a/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/internal/FilterHelper.java @@ -6,13 +6,19 @@ */ package org.hibernate.internal; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.hibernate.Filter; +import org.hibernate.HibernateException; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.internal.util.StringHelper; import org.hibernate.internal.util.collections.CollectionHelper; import org.hibernate.sql.Template; +import org.hibernate.type.Type; import static org.hibernate.internal.util.StringHelper.safeInterning; @@ -21,9 +27,12 @@ * * @author Steve Ebersole * @author Rob Worsnop + * @author Nathan Xu */ public class FilterHelper { + private static Pattern FILTER_PARAMETER_PATTERN = Pattern.compile( ":(\\w+)\\.(\\w+)" ); + private final String[] filterNames; private final String[] filterConditions; private final boolean[] filterAutoAliasFlags; @@ -99,7 +108,10 @@ public void render(StringBuilder buffer, FilterAliasGenerator aliasGenerator, Ma if ( enabledFilters.containsKey( filterNames[i] ) ) { final String condition = filterConditions[i]; if ( StringHelper.isNotEmpty( condition ) ) { - buffer.append( " and " ).append( render( aliasGenerator, i ) ); + if ( buffer.length() > 0 ) { + buffer.append( " and " ); + } + buffer.append( render( aliasGenerator, i ) ); } } } @@ -128,4 +140,68 @@ else if ( isTableFromPersistentClass( aliasTableMap ) ) { return condition; } } + + public static class TypedValue { + private final Type type; + private final Object value; + + public TypedValue(Type type, Object value) { + this.type = type; + this.value = value; + } + + public Type getType() { + return type; + } + + public Object getValue() { + return value; + } + } + + public static class TransformResult { + private final String transformedFilterFragment; + private final List parameters; + + public TransformResult( + String transformedFilterFragment, + List parameters) { + this.transformedFilterFragment = transformedFilterFragment; + this.parameters = parameters; + } + + public String getTransformedFilterFragment() { + return transformedFilterFragment; + } + + public List getParameters() { + return parameters; + } + } + + public static TransformResult transformToPositionalParameters(String filterFragment, Map enabledFilters) { + final Matcher matcher = FILTER_PARAMETER_PATTERN.matcher( filterFragment ); + final StringBuilder sb = new StringBuilder(); + int pos = 0; + final List parameters = new ArrayList<>( matcher.groupCount() ); + while( matcher.find() ) { + sb.append( filterFragment, pos, matcher.start() ); + pos = matcher.end(); + sb.append( "?" ); + final String filterName = matcher.group( 1 ); + final String parameterName = matcher.group( 2 ); + final FilterImpl enabledFilter = (FilterImpl) enabledFilters.get( filterName ); + if ( enabledFilter == null ) { + throw new HibernateException( String.format( "unknown filter [%s]", filterName ) ); + } + final Type parameterType = enabledFilter.getFilterDefinition().getParameterType( parameterName ); + final Object parameterValue = enabledFilter.getParameter( parameterName ); + if ( parameterValue == null ) { + throw new HibernateException( String.format( "unknown parameter [%s] for filter [%s]", parameterName, filterName ) ); + } + parameters.add( new TypedValue( parameterType, parameterValue ) ); + } + sb.append( filterFragment, pos, filterFragment.length() ); + return new TransformResult( sb.toString(), parameters ); + } } diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java index 8c869cce5475..a436b9029a93 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderBatchKey.java @@ -26,11 +26,11 @@ import org.hibernate.query.spi.QueryParameterBindings; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.SqlAstTranslatorFactory; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.exec.internal.JdbcParameterBindingsImpl; import org.hibernate.sql.exec.spi.Callback; import org.hibernate.sql.exec.spi.ExecutionContext; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.exec.spi.JdbcParameterBinding; import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.exec.spi.JdbcSelect; @@ -163,6 +163,9 @@ private void batchLoad( final JdbcSelect jdbcSelect = sqlAstTranslatorFactory.buildSelectTranslator( sessionFactory ).translate( sqlAst ); final JdbcParameterBindings jdbcParameterBindings = new JdbcParameterBindingsImpl( keyJdbcCount * smallBatchLength ); + + sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); + final Iterator paramItr = jdbcParameters.iterator(); for ( int i = smallBatchStart; i < smallBatchStart + smallBatchLength; i++ ) { diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java index 9c79c339aa93..20558503993d 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/CollectionLoaderSingleKey.java @@ -25,11 +25,11 @@ import org.hibernate.query.spi.QueryParameterBindings; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.SqlAstTranslatorFactory; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.exec.internal.JdbcParameterBindingsImpl; import org.hibernate.sql.exec.spi.Callback; import org.hibernate.sql.exec.spi.ExecutionContext; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.exec.spi.JdbcParameterBinding; import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.exec.spi.JdbcSelect; @@ -127,6 +127,8 @@ public Object getBindValue() { ); assert !paramItr.hasNext(); + sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); + jdbcServices.getJdbcSelectExecutor().list( jdbcSelect, jdbcParameterBindings, diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java index 3368f4632700..72d656c13a07 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/LoaderSelectBuilder.java @@ -23,6 +23,8 @@ import org.hibernate.engine.spi.LoadQueryInfluencers; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.engine.spi.SubselectFetch; +import org.hibernate.internal.FilterHelper; +import org.hibernate.internal.FilterHelper.TransformResult; import org.hibernate.loader.ast.spi.Loadable; import org.hibernate.loader.ast.spi.Loader; import org.hibernate.metamodel.mapping.BasicValuedModelPart; @@ -33,6 +35,8 @@ import org.hibernate.metamodel.mapping.PluralAttributeMapping; import org.hibernate.metamodel.mapping.internal.SimpleForeignKeyDescriptor; import org.hibernate.metamodel.mapping.ordering.OrderByFragment; +import org.hibernate.persister.collection.AbstractCollectionPersister; +import org.hibernate.persister.entity.Joinable; import org.hibernate.query.ComparisonOperator; import org.hibernate.query.NavigablePath; import org.hibernate.sql.ast.spi.SimpleFromClauseAccessImpl; @@ -44,8 +48,11 @@ import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.SqlTuple; import org.hibernate.sql.ast.tree.from.TableGroup; +import org.hibernate.sql.ast.tree.from.TableGroupJoin; import org.hibernate.sql.ast.tree.from.TableReference; +import org.hibernate.sql.ast.tree.from.TableReferenceJoin; import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate; +import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.sql.ast.tree.predicate.InListPredicate; import org.hibernate.sql.ast.tree.predicate.InSubQueryPredicate; import org.hibernate.sql.ast.tree.select.QuerySpec; @@ -222,6 +229,7 @@ private SelectStatement generateSelect() { sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup ); if ( loadable instanceof PluralAttributeMapping ) { + applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable ); applyOrdering( rootTableGroup, (PluralAttributeMapping) loadable ); } @@ -372,6 +380,66 @@ private void applyKeyRestriction( } } + private void applyFiltering( + QuerySpec querySpec, + LoadQueryInfluencers loadQueryInfluencers, + PluralAttributeMapping pluralAttributeMapping) { + if ( loadQueryInfluencers.hasEnabledFilters() ) { + final Joinable joinable = pluralAttributeMapping + .getCollectionDescriptor() + .getCollectionType() + .getAssociatedJoinable( creationContext.getSessionFactory() ); + assert joinable instanceof AbstractCollectionPersister; + final AbstractCollectionPersister collectionPersister = (AbstractCollectionPersister) joinable; + querySpec.getFromClause().getRoots().forEach( tableGroup -> consumeTableAliasByTableExpression( + tableGroup, + joinable.getTableName(), + alias -> { + final boolean isManyToMany = collectionPersister.isManyToMany(); + String filterFragment; + if ( isManyToMany ) { + filterFragment = collectionPersister.getManyToManyFilterFragment( + alias, + loadQueryInfluencers.getEnabledFilters() + ); + } + else { + filterFragment = collectionPersister.filterFragment( + alias, + loadQueryInfluencers.getEnabledFilters() + ); + } + final TransformResult transformResult = FilterHelper.transformToPositionalParameters( + filterFragment, loadQueryInfluencers.getEnabledFilters() + ); + filterFragment = transformResult.getTransformedFilterFragment(); + final FilterPredicate filterPredicate = new FilterPredicate( + filterFragment, transformResult.getParameters() + ); + querySpec.applyPredicate( filterPredicate ); + querySpec.addFilterPredicate( filterPredicate ); + } + ) + ); + } + } + + private void consumeTableAliasByTableExpression(TableGroup tableGroup, String tableExpression, Consumer aliasConsumer) { + if ( tableExpression.equals( tableGroup.getPrimaryTableReference().getTableExpression() ) ) { + aliasConsumer.accept( tableGroup.getPrimaryTableReference().getIdentificationVariable() ); + } + else { + for ( TableReferenceJoin referenceJoin : tableGroup.getTableReferenceJoins() ) { + if ( tableExpression.equals( referenceJoin.getJoinedTableReference().getTableExpression() ) ) { + aliasConsumer.accept( referenceJoin.getJoinedTableReference().getIdentificationVariable() ); + } + } + for ( TableGroupJoin tableGroupJoin : tableGroup.getTableGroupJoins() ) { + consumeTableAliasByTableExpression( tableGroupJoin.getJoinedGroup(), tableExpression, aliasConsumer ); + } + } + } + private void applyOrdering(TableGroup tableGroup, PluralAttributeMapping pluralAttributeMapping) { if ( pluralAttributeMapping.getOrderByFragment() != null ) { applyOrdering( tableGroup, pluralAttributeMapping.getOrderByFragment() ); @@ -478,7 +546,8 @@ else if ( fetchDepth > maximumFetchDepth ) { ); fetches.add( fetch ); - if ( fetchable instanceof PluralAttributeMapping && fetchTiming == FetchTiming.IMMEDIATE ) { + if ( fetchable instanceof PluralAttributeMapping && fetchTiming == FetchTiming.IMMEDIATE && joined ) { + applyFiltering( querySpec, loadQueryInfluencers, (PluralAttributeMapping) fetchable ); applyOrdering( querySpec, fetchablePath, @@ -553,6 +622,7 @@ private SelectStatement generateSelect(SubselectFetch subselect) { sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup ); // NOTE : no need to check - we are explicitly processing a plural-attribute + applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable ); applyOrdering( rootTableGroup, attributeMapping ); // generate and apply the restriction diff --git a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java index 09b5e9d13361..2ea2d131a281 100644 --- a/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java +++ b/hibernate-core/src/main/java/org/hibernate/loader/ast/internal/SingleIdLoadPlan.java @@ -122,6 +122,8 @@ public Object getBindValue() { ); } + sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings ); + final List list = JdbcSelectExecutorStandardImpl.INSTANCE.list( jdbcSelect, jdbcParameterBindings, diff --git a/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java index 25ef03db8718..bfb4757e6689 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java @@ -16,6 +16,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import org.hibernate.AssertionFailure; import org.hibernate.FetchMode; @@ -95,12 +96,12 @@ import org.hibernate.persister.walking.spi.EntityDefinition; import org.hibernate.pretty.MessageHelper; import org.hibernate.sql.Alias; -import org.hibernate.sql.Insert; -import org.hibernate.sql.Update; import org.hibernate.sql.Delete; +import org.hibernate.sql.Insert; import org.hibernate.sql.SelectFragment; import org.hibernate.sql.SimpleSelect; import org.hibernate.sql.Template; +import org.hibernate.sql.Update; import org.hibernate.sql.results.graph.DomainResult; import org.hibernate.type.AnyType; import org.hibernate.type.AssociationType; diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java index 9203e4297862..89f213309b4e 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java @@ -33,6 +33,7 @@ import org.hibernate.AssertionFailure; import org.hibernate.EntityMode; import org.hibernate.FetchMode; +import org.hibernate.Filter; import org.hibernate.HibernateException; import org.hibernate.JDBCException; import org.hibernate.LockMode; @@ -152,6 +153,7 @@ import org.hibernate.metamodel.spi.EntityRepresentationStrategy; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.persister.collection.CollectionPersister; +import org.hibernate.persister.collection.QueryableCollection; import org.hibernate.persister.spi.PersisterCreationContext; import org.hibernate.persister.walking.internal.EntityIdentifierDefinitionHelper; import org.hibernate.persister.walking.spi.AttributeDefinition; @@ -201,6 +203,7 @@ import org.hibernate.tuple.InMemoryValueGenerationStrategy; import org.hibernate.tuple.NonIdentifierAttribute; import org.hibernate.tuple.ValueGeneration; +import org.hibernate.tuple.entity.EntityBasedAssociationAttribute; import org.hibernate.tuple.entity.EntityMetamodel; import org.hibernate.tuple.entity.EntityTuplizer; import org.hibernate.type.AnyType; @@ -4568,8 +4571,29 @@ public boolean isAffectedByEnabledFetchProfiles(LoadQueryInfluencers loadQueryIn @Override public boolean isAffectedByEnabledFilters(LoadQueryInfluencers loadQueryInfluencers) { - return loadQueryInfluencers.hasEnabledFilters() - && filterHelper.isAffectedBy( loadQueryInfluencers.getEnabledFilters() ); + if ( loadQueryInfluencers.hasEnabledFilters() ) { + if ( filterHelper.isAffectedBy( loadQueryInfluencers.getEnabledFilters() ) ) { + return true; + } + // we still need to verify collection fields to be eagerly loaded by 'join' + final NonIdentifierAttribute[] attributes = entityMetamodel.getProperties(); + for ( NonIdentifierAttribute attribute : attributes ) { + if ( attribute instanceof EntityBasedAssociationAttribute ) { + final AssociationType associationType = ( (EntityBasedAssociationAttribute) attribute ).getType(); + if ( associationType instanceof CollectionType ) { + final Joinable joinable = associationType.getAssociatedJoinable( getFactory() ); + if ( joinable.isCollection() ) { + final QueryableCollection collectionPersister = (QueryableCollection) joinable; + if ( collectionPersister.getFetchMode() == FetchMode.JOIN + && collectionPersister.isAffectedByEnabledFilters( loadQueryInfluencers ) ) { + return true; + } + } + } + } + } + } + return false; } public final boolean isAllNull(Object[] array, int tableNumber) { diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlTreePrinter.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlTreePrinter.java index 228ef18fa634..682e01facc8d 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlTreePrinter.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlTreePrinter.java @@ -526,7 +526,12 @@ public void visitBetweenPredicate(BetweenPredicate betweenPredicate) { @Override public void visitFilterPredicate(FilterPredicate filterPredicate) { - throw new NotYetImplementedFor6Exception(); + logNode( + "filter-predicate", + () -> { + logNode( filterPredicate.getFilterFragment() ); + } + ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java index 6f89f49b47f3..82aa5289ed79 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstWalker.java @@ -70,6 +70,7 @@ import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectClause; import org.hibernate.sql.ast.tree.select.SortSpecification; +import org.hibernate.sql.exec.internal.AbstractJdbcParameter; import org.hibernate.sql.exec.internal.JdbcParametersImpl; import org.hibernate.sql.exec.spi.JdbcParameterBinder; import org.hibernate.type.descriptor.sql.SqlTypeDescriptorIndicators; @@ -1033,7 +1034,13 @@ public void visitBetweenPredicate(BetweenPredicate betweenPredicate) { @Override public void visitFilterPredicate(FilterPredicate filterPredicate) { - throw new NotYetImplementedFor6Exception(); + if ( filterPredicate.getFilterFragment() != null ) { + appendSql( filterPredicate.getFilterFragment() ); + for (JdbcParameter jdbcParameter : filterPredicate.getJdbcParameters()) { + parameterBinders.add( (AbstractJdbcParameter) jdbcParameter ); + jdbcParameters.addParameter( jdbcParameter ); + } + } } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTreeHelper.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTreeHelper.java index b279b6033e9b..7b9fbbacbd64 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTreeHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTreeHelper.java @@ -49,6 +49,7 @@ public static Predicate combinePredicates(Predicate baseRestriction, Predicate i } else { combinedPredicate = new Junction( Junction.Nature.CONJUNCTION ); + combinedPredicate.add( baseRestriction ); } combinedPredicate.add( incomingRestriction ); diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java index 9fbd52255ac8..9bf6e5506c24 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/predicate/FilterPredicate.java @@ -6,7 +6,13 @@ */ package org.hibernate.sql.ast.tree.predicate; +import java.util.ArrayList; +import java.util.List; + +import org.hibernate.internal.FilterHelper; import org.hibernate.sql.ast.SqlAstWalker; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; +import org.hibernate.sql.exec.internal.JdbcParameterImpl; /** * Represents a filter applied to an entity/collection. @@ -16,7 +22,18 @@ * @author Steve Ebersole */ public class FilterPredicate implements Predicate { - // todo : need to "carry forward" the FilterConfiguration information into the ImprovedEntityPersister so we have access to the alias injections + private final String filterFragment; + private final List jdbcParameters; + private final List jdbcParameterTypedValues; + + public FilterPredicate(String filterFragment, List jdbcParameterTypedValues) { + this.filterFragment = filterFragment; + jdbcParameters = new ArrayList<>( jdbcParameterTypedValues.size() ); + this.jdbcParameterTypedValues = jdbcParameterTypedValues; + for (int i = 0; i < jdbcParameterTypedValues.size(); i++) { + jdbcParameters.add( new JdbcParameterImpl( null ) ); + } + } @Override public boolean isEmpty() { @@ -27,4 +44,16 @@ public boolean isEmpty() { public void accept(SqlAstWalker sqlTreeWalker) { sqlTreeWalker.visitFilterPredicate( this ); } + + public String getFilterFragment() { + return filterFragment; + } + + public List getJdbcParameters() { + return jdbcParameters; + } + + public List getJdbcParameterTypedValues() { + return jdbcParameterTypedValues; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java index 2868eabff349..ed4fb7cc8363 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/select/QuerySpec.java @@ -10,18 +10,25 @@ import java.util.List; import java.util.function.Consumer; +import org.hibernate.HibernateException; +import org.hibernate.internal.FilterHelper; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.MappingModelExpressable; import org.hibernate.query.sqm.sql.internal.DomainResultProducer; -import org.hibernate.sql.ast.spi.SqlAstTreeHelper; import org.hibernate.sql.ast.SqlAstWalker; +import org.hibernate.sql.ast.spi.SqlAstTreeHelper; import org.hibernate.sql.ast.spi.SqlExpressionResolver; import org.hibernate.sql.ast.spi.SqlSelection; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.cte.CteConsumer; import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.from.FromClause; +import org.hibernate.sql.ast.tree.predicate.FilterPredicate; import org.hibernate.sql.ast.tree.predicate.Predicate; import org.hibernate.sql.ast.tree.predicate.PredicateContainer; +import org.hibernate.sql.exec.spi.JdbcParameterBinding; +import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.results.graph.DomainResult; import org.hibernate.sql.results.graph.DomainResultCreationState; import org.hibernate.sql.results.graph.basic.BasicResult; @@ -38,6 +45,7 @@ public class QuerySpec implements SqlAstNode, PredicateContainer, Expression, Ct private final SelectClause selectClause = new SelectClause(); private Predicate whereClauseRestrictions; + private List filterPredicates; private List sortSpecifications; private Expression limitClauseExpression; private Expression offsetClauseExpression; @@ -77,6 +85,13 @@ public void applyPredicate(Predicate predicate) { this.whereClauseRestrictions = SqlAstTreeHelper.combinePredicates( this.whereClauseRestrictions, predicate ); } + public void addFilterPredicate(FilterPredicate filterPredicate) { + if ( filterPredicates == null ) { + filterPredicates = new ArrayList<>(); + } + filterPredicates.add( filterPredicate ); + } + public List getSortSpecifications() { return sortSpecifications; } @@ -155,4 +170,32 @@ public DomainResult createDomainResult(String resultVariable, DomainResultCreati descriptor ); } + + public void bindFilterPredicateParameters(JdbcParameterBindings jdbcParameterBindings) { + if ( filterPredicates != null && !filterPredicates.isEmpty() ) { + for ( FilterPredicate filterPredicate : filterPredicates ) { + for ( int i = 0; i < filterPredicate.getJdbcParameters().size(); i++ ) { + final JdbcParameter parameter = filterPredicate.getJdbcParameters().get( i ); + final FilterHelper.TypedValue parameterTypedValue = filterPredicate.getJdbcParameterTypedValues().get( i ); + if ( !(parameterTypedValue.getType() instanceof JdbcMapping ) ) { + throw new HibernateException( String.format( "Filter parameter type [%s] did not implement JdbcMapping", parameterTypedValue.getType() ) ); + } + jdbcParameterBindings.addBinding( + parameter, + new JdbcParameterBinding() { + @Override + public JdbcMapping getBindType() { + return (JdbcMapping) parameterTypedValue.getType(); + } + + @Override + public Object getBindValue() { + return parameterTypedValue.getValue(); + } + } + ); + } + } + } + } } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterBasicsTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterBasicsTests.java new file mode 100644 index 000000000000..590df99aed95 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterBasicsTests.java @@ -0,0 +1,281 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.loading.filter; + +import java.util.ArrayList; +import java.util.List; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; +import org.hibernate.cfg.AvailableSettings; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterBasicsTests.Client.class, + FilterBasicsTests.Account.class + } +) +@ServiceRegistry( + settings = @ServiceRegistry.Setting( + name = AvailableSettings.HBM2DDL_AUTO, + value = "create-drop" + ) +) +@SessionFactory +public class FilterBasicsTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnEntity(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ) + .setParameter( "active", true ); + } + Account account1 = session.find( Account.class, 1L ); + Account account2 = session.find( Account.class, 2L ); + assertThat( account1, notNullValue() ); + assertThat( account2, notNullValue() ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ) + .setParameter( "active", true ); + } + Client client = session.find( Client.class, 1L ); + + if ( enableFilter ) { + assertThat( client.getAccounts().size(), is( 2 ) ); + } + else { + assertThat( client.getAccounts().size(), is( 3 ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterJoinTableTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterJoinTableTests.java new file mode 100644 index 000000000000..290d1f445889 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterJoinTableTests.java @@ -0,0 +1,203 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.loading.filter; + +import java.util.ArrayList; +import java.util.List; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.Id; +import javax.persistence.JoinTable; +import javax.persistence.ManyToMany; +import javax.persistence.OrderColumn; + +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.FilterJoinTable; +import org.hibernate.annotations.ParamDef; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterJoinTableTests.Client.class, + FilterJoinTableTests.Account.class + } +) +@SessionFactory +public class FilterJoinTableTests { + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + ); + + session.persist( client ); + } ); + } + + @Test + void testLoadFilterOnCollectionField(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.enableFilter( "firstAccounts" ) + .setParameter( "maxOrderId", 1); + Client client = session.find( Client.class, 1L ); + assertThat( client.getAccounts().size(), is( 2 ) ); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + @FilterDef( + name="firstAccounts", + parameters=@ParamDef( + name="maxOrderId", + type="int" + ) + ) + @Filter( + name="firstAccounts", + condition="order_id <= :maxOrderId" + ) + public static class Client { + + @Id + private Long id; + + private String name; + + @ManyToMany(cascade = CascadeType.ALL) + @JoinTable + @OrderColumn(name = "order_id") + @FilterJoinTable( + name="firstAccounts", + condition="order_id <= :maxOrderId" + ) + private List accounts = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccounts() { + return accounts; + } + + public void addAccount(Account account) { + this.accounts.add( account ); + } + } + + @Entity(name = "Account") + public static class Account { + + @Id + private Long id; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterOnEagerLoadedCollectionTests.java b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterOnEagerLoadedCollectionTests.java new file mode 100644 index 000000000000..4ae8d790293f --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/loading/filter/FilterOnEagerLoadedCollectionTests.java @@ -0,0 +1,304 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html + */ +package org.hibernate.orm.test.loading.filter; + +import java.util.ArrayList; +import java.util.List; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.FetchType; +import javax.persistence.Id; +import javax.persistence.ManyToOne; +import javax.persistence.OneToMany; + +import org.hibernate.annotations.Fetch; +import org.hibernate.annotations.FetchMode; +import org.hibernate.annotations.Filter; +import org.hibernate.annotations.FilterDef; +import org.hibernate.annotations.ParamDef; +import org.hibernate.cfg.AvailableSettings; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hibernate.testing.hamcrest.CollectionMatchers.hasSize; +import static org.junit.Assert.assertThat; + +/** + * @author Nathan Xu + */ +@DomainModel( + annotatedClasses = { + FilterOnEagerLoadedCollectionTests.Client.class, + FilterOnEagerLoadedCollectionTests.Account.class + } +) +@ServiceRegistry( + settings = @ServiceRegistry.Setting( + name = AvailableSettings.HBM2DDL_AUTO, + value = "create-drop" + ) +) +@SessionFactory +public class FilterOnEagerLoadedCollectionTests implements SessionFactoryScopeAware { + + private SessionFactoryScope scope; + + @Override + public void injectSessionFactoryScope(SessionFactoryScope scope) { + this.scope = scope; + } + + @BeforeEach + void setUp(SessionFactoryScope scope) { + scope.inTransaction( session -> { + Client client = new Client() + .setId( 1L ) + .setName( "John Doe" ); + + client.addAccount( + new Account() + .setId( 1L ) + .setType( AccountType.CREDIT ) + .setAmount( 5000d ) + .setRate( 1.25 / 100 ) + .setActive( true ) + ); + + client.addAccount( + new Account() + .setId( 2L ) + .setType( AccountType.DEBIT ) + .setAmount( 0d ) + .setRate( 1.05 / 100 ) + .setActive( false ) + ); + + client.addAccount( + new Account() + .setType( AccountType.DEBIT ) + .setId( 3L ) + .setAmount( 250d ) + .setRate( 1.05 / 100 ) + .setActive( true ) + ); + session.persist( client ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnEntity(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ) + .setParameter( "active", true ); + } + Account account1 = session.find( Account.class, 1L ); + Account account2 = session.find( Account.class, 2L ); + assertThat( account1, notNullValue() ); + assertThat( account2, notNullValue() ); + } ); + } + + @ParameterizedTest + @ValueSource( strings = { "true", "false" } ) + void testLoadFilterOnCollectionField(boolean enableFilter) { + scope.inTransaction( session -> { + if ( enableFilter ) { + session.enableFilter( "activeAccount" ) + .setParameter( "active", true ); + } + Client client = session.find( Client.class, 1L ); + + if ( enableFilter ) { + assertThat( client.getAccountsFetchedBySelect(), hasSize( 2 ) ); + assertThat( client.getAccountsFetchedByJoin(), hasSize( 2 ) ); + } + else { + assertThat( client.getAccountsFetchedBySelect(), hasSize( 3 ) ); + assertThat( client.getAccountsFetchedByJoin(), hasSize( 3 ) ); + } + } ); + } + + @AfterEach + void tearDown(SessionFactoryScope scope) { + scope.inTransaction( session -> { + session.createQuery( "delete from Account" ).executeUpdate(); + session.createQuery( "delete from Client" ).executeUpdate(); + } ); + } + + public enum AccountType { + DEBIT, + CREDIT + } + + @Entity(name = "Client") + public static class Client { + + @Id + private Long id; + + private String name; + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL, + fetch = FetchType.EAGER + ) + @Fetch( FetchMode.SELECT ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accountsFetchedBySelect = new ArrayList<>(); + + @OneToMany( + mappedBy = "client", + cascade = CascadeType.ALL, + fetch = FetchType.EAGER + ) + @Fetch( FetchMode.JOIN ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + private List accountsFetchedByJoin = new ArrayList<>(); + + public Long getId() { + return id; + } + + public Client setId(Long id) { + this.id = id; + return this; + } + + public String getName() { + return name; + } + + public Client setName(String name) { + this.name = name; + return this; + } + + public List getAccountsFetchedBySelect() { + return accountsFetchedBySelect; + } + + public List getAccountsFetchedByJoin() { + return accountsFetchedByJoin; + } + + public void addAccount(Account account) { + account.setClient( this ); + this.accountsFetchedBySelect.add( account ); + this.accountsFetchedByJoin.add( account ); + } + } + + @Entity(name = "Account") + @FilterDef( + name="activeAccount", + parameters = @ParamDef( + name="active", + type="boolean" + ) + ) + @Filter( + name="activeAccount", + condition="active_status = :active" + ) + public static class Account { + + @Id + private Long id; + + @ManyToOne(fetch = FetchType.LAZY) + private Client client; + + @Column(name = "account_type") + @Enumerated(EnumType.STRING) + private AccountType type; + + private Double amount; + + private Double rate; + + @Column(name = "active_status") + private boolean active; + + public Long getId() { + return id; + } + + public Account setId(Long id) { + this.id = id; + return this; + } + + public Client getClient() { + return client; + } + + public Account setClient(Client client) { + this.client = client; + return this; + } + + public AccountType getType() { + return type; + } + + public Account setType(AccountType type) { + this.type = type; + return this; + } + + public Double getAmount() { + return amount; + } + + public Account setAmount(Double amount) { + this.amount = amount; + return this; + } + + public Double getRate() { + return rate; + } + + public Account setRate(Double rate) { + this.rate = rate; + return this; + } + + public boolean isActive() { + return active; + } + + public Account setActive(boolean active) { + this.active = active; + return this; + } + } + +}