Skip to content

Commit

Permalink
HHH-16858 improve typechecking for comparisons/assignments
Browse files Browse the repository at this point in the history
In particular, correctly typecheck comparisons between enums
and other enums, and literal integers / strings. Actually
I'm not a great fan of comparing enums with int/string literals
but since we used to support it in 5, and kinda mostly support
it in earlier releases of 6, on balance we might as well continue
to allow it.
  • Loading branch information
gavinking committed Jun 28, 2023
1 parent dfa26e0 commit 1a5b75f
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import org.hibernate.type.descriptor.jdbc.JdbcType;

/**
* Specialization of DomainType for types that can be used as a
* parameter output for a {@link org.hibernate.procedure.ProcedureCall}
* Specialization of {@link org.hibernate.metamodel.model.domain.DomainType} for types that
* can be used as a parameter output for a {@link org.hibernate.procedure.ProcedureCall}.
*
* @apiNote We assume a type that maps to exactly one SQL value, hence {@link #getJdbcType()}
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,40 +434,21 @@ public SqmSelectStatement<R> visitSelectStatement(HqlParser.SelectStatementConte

@Override
public SqmRoot<R> visitTargetEntity(HqlParser.TargetEntityContext dmlTargetContext) {
final HqlParser.EntityNameContext entityNameContext = (HqlParser.EntityNameContext) dmlTargetContext.getChild( 0 );
final String identificationVariable;
if ( dmlTargetContext.getChildCount() == 1 ) {
identificationVariable = null;
}
else {
identificationVariable = applyJpaCompliance(
visitVariable(
(HqlParser.VariableContext) dmlTargetContext.getChild( 1 )
)
);
}
final HqlParser.EntityNameContext entityNameContext = dmlTargetContext.entityName();
final HqlParser.VariableContext variable = dmlTargetContext.variable();
//noinspection unchecked
return new SqmRoot<>(
(EntityDomainType<R>) visitEntityName( entityNameContext ),
identificationVariable,
variable == null ? null : applyJpaCompliance( visitVariable( variable ) ),
false,
creationContext.getNodeBuilder()
);
}

@Override
public SqmInsertStatement<R> visitInsertStatement(HqlParser.InsertStatementContext ctx) {
final int dmlTargetIndex;
if ( ctx.getChild( 1 ) instanceof HqlParser.TargetEntityContext ) {
dmlTargetIndex = 1;
}
else {
dmlTargetIndex = 2;
}
final HqlParser.TargetEntityContext dmlTargetContext = (HqlParser.TargetEntityContext) ctx.getChild( dmlTargetIndex );
final HqlParser.TargetFieldsContext targetFieldsSpecContext = (HqlParser.TargetFieldsContext) ctx.getChild(
dmlTargetIndex + 1
);
final HqlParser.TargetEntityContext dmlTargetContext = ctx.targetEntity();
final HqlParser.TargetFieldsContext targetFieldsSpecContext = ctx.targetFields();
final SqmRoot<R> root = visitTargetEntity( dmlTargetContext );
if ( root.getModel() instanceof SqmPolymorphicRootDescriptor<?> ) {
throw new SemanticException(
Expand All @@ -480,7 +461,8 @@ public SqmInsertStatement<R> visitInsertStatement(HqlParser.InsertStatementConte

final HqlParser.QueryExpressionContext queryExpressionContext = ctx.queryExpression();
if ( queryExpressionContext != null ) {
final SqmInsertSelectStatement<R> insertStatement = new SqmInsertSelectStatement<>( root, creationContext.getNodeBuilder() );
final SqmInsertSelectStatement<R> insertStatement =
new SqmInsertSelectStatement<>( root, creationContext.getNodeBuilder() );
parameterCollector = insertStatement;
final SqmDmlCreationProcessingState processingState = new SqmDmlCreationProcessingState(
insertStatement,
Expand Down Expand Up @@ -517,7 +499,8 @@ public SqmInsertStatement<R> visitInsertStatement(HqlParser.InsertStatementConte

}
else {
final SqmInsertValuesStatement<R> insertStatement = new SqmInsertValuesStatement<>( root, creationContext.getNodeBuilder() );
final SqmInsertValuesStatement<R> insertStatement =
new SqmInsertValuesStatement<>( root, creationContext.getNodeBuilder() );
parameterCollector = insertStatement;
final SqmDmlCreationProcessingState processingState = new SqmDmlCreationProcessingState(
insertStatement,
Expand Down Expand Up @@ -553,9 +536,8 @@ public SqmInsertStatement<R> visitInsertStatement(HqlParser.InsertStatementConte

@Override
public SqmUpdateStatement<R> visitUpdateStatement(HqlParser.UpdateStatementContext ctx) {
final boolean versioned = !( ctx.getChild( 1 ) instanceof HqlParser.TargetEntityContext );
final int dmlTargetIndex = versioned ? 2 : 1;
final HqlParser.TargetEntityContext dmlTargetContext = (HqlParser.TargetEntityContext) ctx.getChild( dmlTargetIndex );
final boolean versioned = ctx.VERSIONED() != null;
final HqlParser.TargetEntityContext dmlTargetContext = ctx.targetEntity();
final SqmRoot<R> root = visitTargetEntity( dmlTargetContext );
if ( root.getModel() instanceof SqmPolymorphicRootDescriptor<?> ) {
throw new SemanticException(
Expand All @@ -577,12 +559,12 @@ public SqmUpdateStatement<R> visitUpdateStatement(HqlParser.UpdateStatementConte

try {
updateStatement.versioned( versioned );
final HqlParser.SetClauseContext setClauseCtx = (HqlParser.SetClauseContext) ctx.getChild( dmlTargetIndex + 1 );
final HqlParser.SetClauseContext setClauseCtx = ctx.setClause();
for ( ParseTree subCtx : setClauseCtx.children ) {
if ( subCtx instanceof HqlParser.AssignmentContext ) {
final HqlParser.AssignmentContext assignmentContext = (HqlParser.AssignmentContext) subCtx;
//noinspection unchecked
final SqmPath<Object> targetPath = (SqmPath<Object>) consumeDomainPath( (HqlParser.SimplePathContext) assignmentContext.getChild( 0 ) );
final SqmPath<Object> targetPath = (SqmPath<Object>) consumeDomainPath( assignmentContext.simplePath() );
final Class<?> targetPathJavaType = targetPath.getJavaType();
final boolean isEnum = targetPathJavaType != null && targetPathJavaType.isEnum();
final ParseTree rightSide = assignmentContext.getChild( 2 );
Expand All @@ -604,10 +586,9 @@ public SqmUpdateStatement<R> visitUpdateStatement(HqlParser.UpdateStatementConte
}
}

if ( dmlTargetIndex + 2 <= ctx.getChildCount() ) {
updateStatement.applyPredicate(
visitWhereClause( (HqlParser.WhereClauseContext) ctx.getChild( dmlTargetIndex + 2 ) )
);
final HqlParser.WhereClauseContext whereClauseContext = ctx.whereClause();
if ( whereClauseContext != null ) {
updateStatement.applyPredicate( visitWhereClause( whereClauseContext ) );
}

return updateStatement;
Expand All @@ -619,14 +600,7 @@ public SqmUpdateStatement<R> visitUpdateStatement(HqlParser.UpdateStatementConte

@Override
public SqmDeleteStatement<R> visitDeleteStatement(HqlParser.DeleteStatementContext ctx) {
final int dmlTargetIndex;
if ( ctx.getChild( 1 ) instanceof HqlParser.TargetEntityContext ) {
dmlTargetIndex = 1;
}
else {
dmlTargetIndex = 2;
}
final HqlParser.TargetEntityContext dmlTargetContext = (HqlParser.TargetEntityContext) ctx.getChild( dmlTargetIndex );
final HqlParser.TargetEntityContext dmlTargetContext = ctx.targetEntity();
final SqmRoot<R> root = visitTargetEntity( dmlTargetContext );

final SqmDeleteStatement<R> deleteStatement = new SqmDeleteStatement<>( root, SqmQuerySource.HQL, creationContext.getNodeBuilder() );
Expand All @@ -642,10 +616,9 @@ public SqmDeleteStatement<R> visitDeleteStatement(HqlParser.DeleteStatementConte

processingStateStack.push( sqmDeleteCreationState );
try {
if ( dmlTargetIndex + 1 <= ctx.getChildCount() ) {
deleteStatement.applyPredicate(
visitWhereClause( (HqlParser.WhereClauseContext) ctx.getChild( dmlTargetIndex + 1 ) )
);
final HqlParser.WhereClauseContext whereClauseContext = ctx.whereClause();
if ( whereClauseContext != null ) {
deleteStatement.applyPredicate( visitWhereClause( whereClauseContext ) );
}

return deleteStatement;
Expand Down Expand Up @@ -2472,50 +2445,37 @@ private SqmComparisonPredicate createComparisonPredicate(
HqlParser.ExpressionContext rightExpressionContext) {
final SqmExpression<?> right;
final SqmExpression<?> left;
switch ( comparisonOperator ) {
case EQUAL:
case NOT_EQUAL:
case DISTINCT_FROM:
case NOT_DISTINCT_FROM: {
Map<Class<?>, Enum<?>> possibleEnumValues;
if ( ( possibleEnumValues = getPossibleEnumValues( leftExpressionContext ) ) != null ) {
right = (SqmExpression<?>) rightExpressionContext.accept( this );
left = resolveEnumShorthandLiteral(
leftExpressionContext,
possibleEnumValues,
right.getJavaType()
);
break;
}
else if ( ( possibleEnumValues = getPossibleEnumValues( rightExpressionContext ) ) != null ) {
left = (SqmExpression<?>) leftExpressionContext.accept( this );
right = resolveEnumShorthandLiteral(
rightExpressionContext,
possibleEnumValues,
left.getJavaType()
);
break;
}
final SqmExpression<?> l = (SqmExpression<?>) leftExpressionContext.accept( this );
final SqmExpression<?> r = (SqmExpression<?>) rightExpressionContext.accept( this );
if ( l instanceof AnyDiscriminatorSqmPath && r instanceof SqmLiteralEntityType ) {
left = l;
right = createDiscriminatorValue( (AnyDiscriminatorSqmPath<?>) left, rightExpressionContext );
}
else if ( r instanceof AnyDiscriminatorSqmPath && l instanceof SqmLiteralEntityType ) {
left = createDiscriminatorValue( (AnyDiscriminatorSqmPath<?>) r, leftExpressionContext );
right = r;
}
else {
left = l;
right = r;
}
break;
Map<Class<?>, Enum<?>> possibleEnumValues;
if ( ( possibleEnumValues = getPossibleEnumValues( leftExpressionContext ) ) != null ) {
right = (SqmExpression<?>) rightExpressionContext.accept( this );
left = resolveEnumShorthandLiteral(
leftExpressionContext,
possibleEnumValues,
right.getJavaType()
);
}
else if ( ( possibleEnumValues = getPossibleEnumValues( rightExpressionContext ) ) != null ) {
left = (SqmExpression<?>) leftExpressionContext.accept( this );
right = resolveEnumShorthandLiteral(
rightExpressionContext,
possibleEnumValues,
left.getJavaType()
);
}
else {
final SqmExpression<?> l = (SqmExpression<?>) leftExpressionContext.accept( this );
final SqmExpression<?> r = (SqmExpression<?>) rightExpressionContext.accept( this );
if ( l instanceof AnyDiscriminatorSqmPath && r instanceof SqmLiteralEntityType ) {
left = l;
right = createDiscriminatorValue( (AnyDiscriminatorSqmPath<?>) left, rightExpressionContext );
}
default: {
left = (SqmExpression<?>) leftExpressionContext.accept( this );
right = (SqmExpression<?>) rightExpressionContext.accept( this );
break;
else if ( r instanceof AnyDiscriminatorSqmPath && l instanceof SqmLiteralEntityType ) {
left = createDiscriminatorValue( (AnyDiscriminatorSqmPath<?>) r, leftExpressionContext );
right = r;
}
else {
left = l;
right = r;
}
}
SqmCriteriaNodeBuilder.assertComparable( left, right );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
import org.hibernate.internal.CoreLogging;
import org.hibernate.internal.CoreMessageLogger;
import org.hibernate.metamodel.mapping.EntityIdentifierMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.internal.SingleAttributeIdentifierMapping;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.metamodel.model.domain.EntityDomainType;
import org.hibernate.persister.entity.AbstractEntityPersister;
import org.hibernate.persister.entity.EntityPersister;
Expand Down Expand Up @@ -79,6 +81,7 @@
import org.hibernate.query.spi.SelectQueryPlan;
import org.hibernate.query.sqm.NodeBuilder;
import org.hibernate.query.sqm.SortOrder;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.SqmPathSource;
import org.hibernate.query.sqm.internal.SqmInterpretationsKey.InterpretationsKeySource;
import org.hibernate.query.sqm.mutation.spi.SqmMultiTableMutationStrategy;
Expand Down Expand Up @@ -113,6 +116,8 @@
import jakarta.persistence.Parameter;
import jakarta.persistence.PersistenceException;
import jakarta.persistence.TemporalType;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.JdbcType;

import static org.hibernate.jpa.HibernateHints.HINT_CACHEABLE;
import static org.hibernate.jpa.HibernateHints.HINT_CACHE_MODE;
Expand All @@ -130,6 +135,7 @@
import static org.hibernate.query.sqm.internal.SqmInterpretationsKey.generateNonSelectKey;
import static org.hibernate.query.sqm.internal.SqmUtil.isSelect;
import static org.hibernate.query.sqm.internal.SqmUtil.verifyIsNonSelectStatement;
import static org.hibernate.type.descriptor.java.JavaTypeHelper.isUnknown;

/**
* {@link Query} implementation based on an SQM
Expand Down Expand Up @@ -351,17 +357,13 @@ private void verifyUpdateTypesMatch(String hqlString, SqmUpdateStatement<R> sqmS
final SqmAssignment<?> assignment = assignments.get( i );
final SqmPath<?> targetPath = assignment.getTargetPath();
final SqmExpression<?> expression = assignment.getValue();
if ( targetPath.getNodeJavaType() == null || expression.getNodeJavaType() == null ) {
continue;
}
if ( targetPath.getNodeJavaType() != expression.getNodeJavaType()
&& !targetPath.getNodeJavaType().isWider( expression.getNodeJavaType() ) ) {
if ( !isAssignable( targetPath, expression ) ) {
throw new SemanticException(
String.format(
"The assignment expression type [%s] did not match the assignment path type [%s] for the path [%s]",
"Cannot assign expression of type '%s' to target path '%s' of type '%s'",
expression.getNodeJavaType().getJavaType().getTypeName(),
targetPath.getNodeJavaType().getJavaType().getTypeName(),
targetPath.toHqlString()
targetPath.toHqlString(),
targetPath.getNodeJavaType().getJavaType().getTypeName()
),
hqlString,
null
Expand All @@ -370,6 +372,34 @@ private void verifyUpdateTypesMatch(String hqlString, SqmUpdateStatement<R> sqmS
}
}

/**
* @see SqmCriteriaNodeBuilder#areTypesComparable(SqmExpressible, SqmExpressible)
*/
public static boolean isAssignable(SqmPath<?> targetPath, SqmExpression<?> expression) {
DomainType<?> lhsDomainType = targetPath.getExpressible().getSqmType();
DomainType<?> rhsDomainType = expression.getExpressible().getSqmType();
if ( lhsDomainType instanceof JdbcMapping && rhsDomainType instanceof JdbcMapping ) {
JdbcType lhsJdbcType = ((JdbcMapping) lhsDomainType).getJdbcType();
JdbcType rhsJdbcType = ((JdbcMapping) rhsDomainType).getJdbcType();
if ( lhsJdbcType.getJdbcTypeCode() == rhsJdbcType.getJdbcTypeCode()
|| lhsJdbcType.isString() && rhsJdbcType.isString()
|| lhsJdbcType.isInteger() && rhsJdbcType.isInteger() ) {
return true;
}
}

JavaType<?> targetType = targetPath.getNodeJavaType();
JavaType<?> assignedType = expression.getNodeJavaType();
return targetType == assignedType
// If we don't know the java types, let's just be lenient
|| isUnknown( targetType)
|| isUnknown( assignedType )
// Assume we can coerce one to another
|| targetType.isWider( assignedType )
// Enum assignment, other strange user type mappings
|| targetType.getJavaTypeClass().isAssignableFrom( assignedType.getJavaTypeClass() );
}

private void verifyInsertTypesMatch(String hqlString, SqmInsertStatement<R> sqmStatement) {
final List<SqmPath<?>> insertionTargetPaths = sqmStatement.getInsertionTargetPaths();
if ( sqmStatement instanceof SqmInsertValuesStatement<?> ) {
Expand Down

0 comments on commit 1a5b75f

Please sign in to comment.