Skip to content

Commit

Permalink
implement @filter for loader
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanQingyangXu authored and dreab8 committed Apr 14, 2020
1 parent c23ded5 commit 34d5a2a
Show file tree
Hide file tree
Showing 17 changed files with 1,065 additions and 88 deletions.
Expand Up @@ -975,7 +975,7 @@ private void bindFilters(boolean hasAssociationTable) {
FilterJoinTable simpleFilterJoinTable = property.getAnnotation( FilterJoinTable.class );
if ( simpleFilterJoinTable != null ) {
if ( hasAssociationTable ) {
collection.addFilter(simpleFilterJoinTable.name(), simpleFilterJoinTable.condition(),
collection.addManyToManyFilter(simpleFilterJoinTable.name(), simpleFilterJoinTable.condition(),
simpleFilterJoinTable.deduceAliasInjectionPoints(),
toAliasTableMap(simpleFilterJoinTable.aliases()), toAliasEntityMap(simpleFilterJoinTable.aliases()));
}
Expand All @@ -990,7 +990,7 @@ private void bindFilters(boolean hasAssociationTable) {
if ( filterJoinTables != null ) {
for (FilterJoinTable filter : filterJoinTables.value()) {
if ( hasAssociationTable ) {
collection.addFilter(filter.name(), filter.condition(),
collection.addManyToManyFilter(filter.name(), filter.condition(),
filter.deduceAliasInjectionPoints(),
toAliasTableMap(filter.aliases()), toAliasEntityMap(filter.aliases()));
}
Expand Down
Expand Up @@ -515,80 +515,6 @@ public void setPassDistinctThrough(boolean passDistinctThrough) {
this.passDistinctThrough = passDistinctThrough;
}

public void processFilters(String sql, SharedSessionContractImplementor session) {
processFilters( sql, session.getLoadQueryInfluencers().getEnabledFilters(), session.getFactory() );
}

@SuppressWarnings( {"unchecked"})
public void processFilters(String sql, Map filters, SessionFactoryImplementor factory) {
if ( filters.size() == 0 || !sql.contains( HQL_VARIABLE_PREFIX ) ) {
// HELLA IMPORTANT OPTIMIZATION!!!
processedPositionalParameterValues = getPositionalParameterValues();
processedPositionalParameterTypes = getPositionalParameterTypes();
processedSQL = sql;
}
else {
throw new NotYetImplementedFor6Exception( getClass() );
// final StringTokenizer tokens = new StringTokenizer( sql, SYMBOLS, true );
// StringBuilder result = new StringBuilder();
// List parameters = new ArrayList();
// List parameterTypes = new ArrayList();
// int positionalIndex = 0;
// while ( tokens.hasMoreTokens() ) {
// final String token = tokens.nextToken();
// if ( token.startsWith( ParserHelper.HQL_VARIABLE_PREFIX ) ) {
// final String filterParameterName = token.substring( 1 );
// final String[] parts = LoadQueryInfluencers.parseFilterParameterName( filterParameterName );
// final FilterImpl filter = (FilterImpl) filters.get( parts[0] );
// final Object value = filter.getParameter( parts[1] );
// final Type type = filter.getFilterDefinition().getParameterType( parts[1] );
// if ( value != null && Collection.class.isAssignableFrom( value.getClass() ) ) {
// Iterator itr = ( (Collection) value ).iterator();
// while ( itr.hasNext() ) {
// final Object elementValue = itr.next();
// result.append( '?' );
// parameters.add( elementValue );
// parameterTypes.add( type );
// if ( itr.hasNext() ) {
// result.append( ", " );
// }
// }
// }
// else {
// result.append( '?' );
// parameters.add( value );
// parameterTypes.add( type );
// }
// }
// else {
// result.append( token );
// if ( "?".equals( token ) && positionalIndex < getPositionalParameterValues().length ) {
// final Type type = getPositionalParameterTypes()[positionalIndex];
// if ( type.isComponentType() ) {
// // should process tokens till reaching the number of "?" corresponding to the
// // numberOfParametersCoveredBy of the compositeType
// int paramIndex = 1;
// final int numberOfParametersCoveredBy = getNumberOfParametersCoveredBy( ((ComponentType) type).getSubtypes() );
// while ( paramIndex < numberOfParametersCoveredBy ) {
// final String nextToken = tokens.nextToken();
// if ( "?".equals( nextToken ) ) {
// paramIndex++;
// }
// result.append( nextToken );
// }
// }
// parameters.add( getPositionalParameterValues()[positionalIndex] );
// parameterTypes.add( type );
// positionalIndex++;
// }
// }
// }
// processedPositionalParameterValues = parameters.toArray();
// processedPositionalParameterTypes = ( Type[] ) parameterTypes.toArray( new Type[parameterTypes.size()] );
// processedSQL = result.toString();
}
}

private int getNumberOfParametersCoveredBy(Type[] subtypes) {
int numberOfParameters = 0;
for ( Type type : subtypes ) {
Expand Down
Expand Up @@ -6,13 +6,19 @@
*/
package org.hibernate.internal;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.hibernate.Filter;
import org.hibernate.HibernateException;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.internal.util.StringHelper;
import org.hibernate.internal.util.collections.CollectionHelper;
import org.hibernate.sql.Template;
import org.hibernate.type.Type;

import static org.hibernate.internal.util.StringHelper.safeInterning;

Expand All @@ -21,9 +27,12 @@
*
* @author Steve Ebersole
* @author Rob Worsnop
* @author Nathan Xu
*/
public class FilterHelper {

private static Pattern FILTER_PARAMETER_PATTERN = Pattern.compile( ":(\\w+)\\.(\\w+)" );

private final String[] filterNames;
private final String[] filterConditions;
private final boolean[] filterAutoAliasFlags;
Expand Down Expand Up @@ -99,7 +108,10 @@ public void render(StringBuilder buffer, FilterAliasGenerator aliasGenerator, Ma
if ( enabledFilters.containsKey( filterNames[i] ) ) {
final String condition = filterConditions[i];
if ( StringHelper.isNotEmpty( condition ) ) {
buffer.append( " and " ).append( render( aliasGenerator, i ) );
if ( buffer.length() > 0 ) {
buffer.append( " and " );
}
buffer.append( render( aliasGenerator, i ) );
}
}
}
Expand Down Expand Up @@ -128,4 +140,68 @@ else if ( isTableFromPersistentClass( aliasTableMap ) ) {
return condition;
}
}

public static class TypedValue {
private final Type type;
private final Object value;

public TypedValue(Type type, Object value) {
this.type = type;
this.value = value;
}

public Type getType() {
return type;
}

public Object getValue() {
return value;
}
}

public static class TransformResult {
private final String transformedFilterFragment;
private final List<TypedValue> parameters;

public TransformResult(
String transformedFilterFragment,
List<TypedValue> parameters) {
this.transformedFilterFragment = transformedFilterFragment;
this.parameters = parameters;
}

public String getTransformedFilterFragment() {
return transformedFilterFragment;
}

public List<TypedValue> getParameters() {
return parameters;
}
}

public static TransformResult transformToPositionalParameters(String filterFragment, Map<String, Filter> enabledFilters) {
final Matcher matcher = FILTER_PARAMETER_PATTERN.matcher( filterFragment );
final StringBuilder sb = new StringBuilder();
int pos = 0;
final List<TypedValue> parameters = new ArrayList<>( matcher.groupCount() );
while( matcher.find() ) {
sb.append( filterFragment, pos, matcher.start() );
pos = matcher.end();
sb.append( "?" );
final String filterName = matcher.group( 1 );
final String parameterName = matcher.group( 2 );
final FilterImpl enabledFilter = (FilterImpl) enabledFilters.get( filterName );
if ( enabledFilter == null ) {
throw new HibernateException( String.format( "unknown filter [%s]", filterName ) );
}
final Type parameterType = enabledFilter.getFilterDefinition().getParameterType( parameterName );
final Object parameterValue = enabledFilter.getParameter( parameterName );
if ( parameterValue == null ) {
throw new HibernateException( String.format( "unknown parameter [%s] for filter [%s]", parameterName, filterName ) );
}
parameters.add( new TypedValue( parameterType, parameterValue ) );
}
sb.append( filterFragment, pos, filterFragment.length() );
return new TransformResult( sb.toString(), parameters );
}
}
Expand Up @@ -26,11 +26,11 @@
import org.hibernate.query.spi.QueryParameterBindings;
import org.hibernate.sql.ast.Clause;
import org.hibernate.sql.ast.SqlAstTranslatorFactory;
import org.hibernate.sql.ast.tree.expression.JdbcParameter;
import org.hibernate.sql.ast.tree.select.SelectStatement;
import org.hibernate.sql.exec.internal.JdbcParameterBindingsImpl;
import org.hibernate.sql.exec.spi.Callback;
import org.hibernate.sql.exec.spi.ExecutionContext;
import org.hibernate.sql.ast.tree.expression.JdbcParameter;
import org.hibernate.sql.exec.spi.JdbcParameterBinding;
import org.hibernate.sql.exec.spi.JdbcParameterBindings;
import org.hibernate.sql.exec.spi.JdbcSelect;
Expand Down Expand Up @@ -163,6 +163,9 @@ private void batchLoad(
final JdbcSelect jdbcSelect = sqlAstTranslatorFactory.buildSelectTranslator( sessionFactory ).translate( sqlAst );

final JdbcParameterBindings jdbcParameterBindings = new JdbcParameterBindingsImpl( keyJdbcCount * smallBatchLength );

sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings );

final Iterator<JdbcParameter> paramItr = jdbcParameters.iterator();

for ( int i = smallBatchStart; i < smallBatchStart + smallBatchLength; i++ ) {
Expand Down
Expand Up @@ -25,11 +25,11 @@
import org.hibernate.query.spi.QueryParameterBindings;
import org.hibernate.sql.ast.Clause;
import org.hibernate.sql.ast.SqlAstTranslatorFactory;
import org.hibernate.sql.ast.tree.expression.JdbcParameter;
import org.hibernate.sql.ast.tree.select.SelectStatement;
import org.hibernate.sql.exec.internal.JdbcParameterBindingsImpl;
import org.hibernate.sql.exec.spi.Callback;
import org.hibernate.sql.exec.spi.ExecutionContext;
import org.hibernate.sql.ast.tree.expression.JdbcParameter;
import org.hibernate.sql.exec.spi.JdbcParameterBinding;
import org.hibernate.sql.exec.spi.JdbcParameterBindings;
import org.hibernate.sql.exec.spi.JdbcSelect;
Expand Down Expand Up @@ -127,6 +127,8 @@ public Object getBindValue() {
);
assert !paramItr.hasNext();

sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings );

jdbcServices.getJdbcSelectExecutor().list(
jdbcSelect,
jdbcParameterBindings,
Expand Down
Expand Up @@ -23,6 +23,8 @@
import org.hibernate.engine.spi.LoadQueryInfluencers;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.engine.spi.SubselectFetch;
import org.hibernate.internal.FilterHelper;
import org.hibernate.internal.FilterHelper.TransformResult;
import org.hibernate.loader.ast.spi.Loadable;
import org.hibernate.loader.ast.spi.Loader;
import org.hibernate.metamodel.mapping.BasicValuedModelPart;
Expand All @@ -33,6 +35,8 @@
import org.hibernate.metamodel.mapping.PluralAttributeMapping;
import org.hibernate.metamodel.mapping.internal.SimpleForeignKeyDescriptor;
import org.hibernate.metamodel.mapping.ordering.OrderByFragment;
import org.hibernate.persister.collection.AbstractCollectionPersister;
import org.hibernate.persister.entity.Joinable;
import org.hibernate.query.ComparisonOperator;
import org.hibernate.query.NavigablePath;
import org.hibernate.sql.ast.spi.SimpleFromClauseAccessImpl;
Expand All @@ -44,8 +48,11 @@
import org.hibernate.sql.ast.tree.expression.JdbcParameter;
import org.hibernate.sql.ast.tree.expression.SqlTuple;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.from.TableReference;
import org.hibernate.sql.ast.tree.from.TableReferenceJoin;
import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate;
import org.hibernate.sql.ast.tree.predicate.FilterPredicate;
import org.hibernate.sql.ast.tree.predicate.InListPredicate;
import org.hibernate.sql.ast.tree.predicate.InSubQueryPredicate;
import org.hibernate.sql.ast.tree.select.QuerySpec;
Expand Down Expand Up @@ -222,6 +229,7 @@ private SelectStatement generateSelect() {
sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup );

if ( loadable instanceof PluralAttributeMapping ) {
applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable );
applyOrdering( rootTableGroup, (PluralAttributeMapping) loadable );
}

Expand Down Expand Up @@ -372,6 +380,66 @@ private void applyKeyRestriction(
}
}

private void applyFiltering(
QuerySpec querySpec,
LoadQueryInfluencers loadQueryInfluencers,
PluralAttributeMapping pluralAttributeMapping) {
if ( loadQueryInfluencers.hasEnabledFilters() ) {
final Joinable joinable = pluralAttributeMapping
.getCollectionDescriptor()
.getCollectionType()
.getAssociatedJoinable( creationContext.getSessionFactory() );
assert joinable instanceof AbstractCollectionPersister;
final AbstractCollectionPersister collectionPersister = (AbstractCollectionPersister) joinable;
querySpec.getFromClause().getRoots().forEach( tableGroup -> consumeTableAliasByTableExpression(
tableGroup,
joinable.getTableName(),
alias -> {
final boolean isManyToMany = collectionPersister.isManyToMany();
String filterFragment;
if ( isManyToMany ) {
filterFragment = collectionPersister.getManyToManyFilterFragment(
alias,
loadQueryInfluencers.getEnabledFilters()
);
}
else {
filterFragment = collectionPersister.filterFragment(
alias,
loadQueryInfluencers.getEnabledFilters()
);
}
final TransformResult transformResult = FilterHelper.transformToPositionalParameters(
filterFragment, loadQueryInfluencers.getEnabledFilters()
);
filterFragment = transformResult.getTransformedFilterFragment();
final FilterPredicate filterPredicate = new FilterPredicate(
filterFragment, transformResult.getParameters()
);
querySpec.applyPredicate( filterPredicate );
querySpec.addFilterPredicate( filterPredicate );
}
)
);
}
}

private void consumeTableAliasByTableExpression(TableGroup tableGroup, String tableExpression, Consumer<String> aliasConsumer) {
if ( tableExpression.equals( tableGroup.getPrimaryTableReference().getTableExpression() ) ) {
aliasConsumer.accept( tableGroup.getPrimaryTableReference().getIdentificationVariable() );
}
else {
for ( TableReferenceJoin referenceJoin : tableGroup.getTableReferenceJoins() ) {
if ( tableExpression.equals( referenceJoin.getJoinedTableReference().getTableExpression() ) ) {
aliasConsumer.accept( referenceJoin.getJoinedTableReference().getIdentificationVariable() );
}
}
for ( TableGroupJoin tableGroupJoin : tableGroup.getTableGroupJoins() ) {
consumeTableAliasByTableExpression( tableGroupJoin.getJoinedGroup(), tableExpression, aliasConsumer );
}
}
}

private void applyOrdering(TableGroup tableGroup, PluralAttributeMapping pluralAttributeMapping) {
if ( pluralAttributeMapping.getOrderByFragment() != null ) {
applyOrdering( tableGroup, pluralAttributeMapping.getOrderByFragment() );
Expand Down Expand Up @@ -478,7 +546,8 @@ else if ( fetchDepth > maximumFetchDepth ) {
);
fetches.add( fetch );

if ( fetchable instanceof PluralAttributeMapping && fetchTiming == FetchTiming.IMMEDIATE ) {
if ( fetchable instanceof PluralAttributeMapping && fetchTiming == FetchTiming.IMMEDIATE && joined ) {
applyFiltering( querySpec, loadQueryInfluencers, (PluralAttributeMapping) fetchable );
applyOrdering(
querySpec,
fetchablePath,
Expand Down Expand Up @@ -553,6 +622,7 @@ private SelectStatement generateSelect(SubselectFetch subselect) {
sqlAstCreationState.getFromClauseAccess().registerTableGroup( rootNavigablePath, rootTableGroup );

// NOTE : no need to check - we are explicitly processing a plural-attribute
applyFiltering( rootQuerySpec, loadQueryInfluencers, (PluralAttributeMapping) loadable );
applyOrdering( rootTableGroup, attributeMapping );

// generate and apply the restriction
Expand Down
Expand Up @@ -122,6 +122,8 @@ public Object getBindValue() {
);
}

sqlAst.getQuerySpec().bindFilterPredicateParameters( jdbcParameterBindings );

final List list = JdbcSelectExecutorStandardImpl.INSTANCE.list(
jdbcSelect,
jdbcParameterBindings,
Expand Down

0 comments on commit 34d5a2a

Please sign in to comment.