From 73d3ac79bce7f192839f78e6cae72ae6c8bfd1f7 Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Tue, 5 Nov 2024 19:49:46 +0100 Subject: [PATCH 1/7] HHH-18793 Add JSON aggregate support for MySQL --- .../community/dialect/MySQLLegacyDialect.java | 14 +- .../process/spi/MetadataBuildingProcess.java | 10 + .../org/hibernate/dialect/JsonHelper.java | 1 + .../org/hibernate/dialect/MySQLDialect.java | 18 +- .../dialect/MySQLSqlAstTranslator.java | 18 +- .../aggregate/AggregateSupportImpl.java | 6 +- .../aggregate/MySQLAggregateSupport.java | 320 ++++++++++++++++++ .../aggregate/OracleAggregateSupport.java | 1 + .../aggregate/PostgreSQLAggregateSupport.java | 1 + .../internal/EmbeddableMappingTypeImpl.java | 6 +- 10 files changed, 384 insertions(+), 11 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MySQLLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MySQLLegacyDialect.java index ba4c01f6c3f3..21b07ca2e852 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MySQLLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MySQLLegacyDialect.java @@ -16,6 +16,8 @@ import org.hibernate.boot.model.TypeContributions; import org.hibernate.cfg.Environment; import org.hibernate.dialect.*; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.MySQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.identity.MySQLIdentityColumnSupport; @@ -263,7 +265,10 @@ protected String castType(int sqlTypeCode) { //MySQL doesn't let you cast to DOUBLE/FLOAT //but don't just return 'decimal' because //the default scale is 0 (no decimal places) - return "decimal($p,$s)"; + return getMySQLVersion().isSameOrAfter( 8, 0, 17 ) + // In newer versions of MySQL, casting to float/double is supported + ? super.castType( sqlTypeCode ) + : "decimal($p,$s)"; case CHAR: case NCHAR: case VARCHAR: @@ -385,6 +390,13 @@ protected void registerColumnTypes(TypeContributions typeContributions, ServiceR ddlTypeRegistry.addDescriptor( new NativeOrdinalEnumDdlTypeImpl( this ) ); } + @Override + public AggregateSupport getAggregateSupport() { + return getMySQLVersion().isSameOrAfter( 5, 7 ) + ? MySQLAggregateSupport.JSON_INSTANCE + : super.getAggregateSupport(); + } + @Deprecated protected static int getCharacterSetBytesPerCharacter(DatabaseMetaData databaseMetaData) { if ( databaseMetaData != null ) { diff --git a/hibernate-core/src/main/java/org/hibernate/boot/model/process/spi/MetadataBuildingProcess.java b/hibernate-core/src/main/java/org/hibernate/boot/model/process/spi/MetadataBuildingProcess.java index dfe8ac963574..dae2bb9137ee 100644 --- a/hibernate-core/src/main/java/org/hibernate/boot/model/process/spi/MetadataBuildingProcess.java +++ b/hibernate-core/src/main/java/org/hibernate/boot/model/process/spi/MetadataBuildingProcess.java @@ -83,6 +83,7 @@ import org.hibernate.type.descriptor.java.CharacterArrayJavaType; import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeConstructor; import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcTypeConstructor; import org.hibernate.type.descriptor.jdbc.JsonAsStringArrayJdbcTypeConstructor; import org.hibernate.type.descriptor.jdbc.JsonAsStringJdbcType; @@ -101,6 +102,7 @@ import jakarta.persistence.AttributeConverter; import static org.hibernate.internal.util.collections.CollectionHelper.mutableJoin; +import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForArray; import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForDuration; import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForInstant; import static org.hibernate.internal.util.config.ConfigurationHelper.getPreferredSqlTypeCodeForUuid; @@ -771,6 +773,14 @@ public void contributeType(CompositeUserType type) { jdbcTypeRegistry.addTypeConstructor( XmlAsStringArrayJdbcTypeConstructor.INSTANCE ); } } + if ( jdbcTypeRegistry.getConstructor( SqlTypes.ARRAY ) == null ) { + // Default the array constructor to e.g. JSON_ARRAY/XML_ARRAY if needed + final JdbcTypeConstructor constructor = + jdbcTypeRegistry.getConstructor( getPreferredSqlTypeCodeForArray( serviceRegistry ) ); + if ( constructor != null ) { + jdbcTypeRegistry.addTypeConstructor( SqlTypes.ARRAY, constructor ); + } + } final int preferredSqlTypeCodeForDuration = getPreferredSqlTypeCodeForDuration( serviceRegistry ); if ( preferredSqlTypeCodeForDuration != SqlTypes.INTERVAL_SECOND ) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/JsonHelper.java b/hibernate-core/src/main/java/org/hibernate/dialect/JsonHelper.java index 3685fba17948..c06ab3b1eeec 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/JsonHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/JsonHelper.java @@ -322,6 +322,7 @@ private static void convertedBasicValueToString( appender.append( '"' ); break; case SqlTypes.ARRAY: + case SqlTypes.JSON_ARRAY: final int length = Array.getLength( value ); appender.append( '[' ); if ( length != 0 ) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java index 93ad2d709f59..61a362e3c89c 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java @@ -21,6 +21,8 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.cfg.Environment; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.MySQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.identity.MySQLIdentityColumnSupport; @@ -316,12 +318,15 @@ protected String castType(int sqlTypeCode) { // MySQL doesn't let you cast to DOUBLE/FLOAT // but don't just return 'decimal' because // the default scale is 0 (no decimal places) - case FLOAT, REAL, DOUBLE -> "decimal($p,$s)"; + case FLOAT, REAL, DOUBLE -> getMySQLVersion().isSameOrAfter( 8, 0, 17 ) + // In newer versions of MySQL, casting to float/double is supported + ? super.castType( sqlTypeCode ) + : "decimal($p,$s)"; // MySQL doesn't let you cast to TEXT/LONGTEXT - case CHAR, VARCHAR, LONG32VARCHAR -> "char"; - case NCHAR, NVARCHAR, LONG32NVARCHAR -> "char character set utf8mb4"; + case CHAR, VARCHAR, LONG32VARCHAR, CLOB -> "char"; + case NCHAR, NVARCHAR, LONG32NVARCHAR, NCLOB -> "char character set utf8mb4"; // MySQL doesn't let you cast to BLOB/TINYBLOB/LONGBLOB - case BINARY, VARBINARY, LONG32VARBINARY -> "binary"; + case BINARY, VARBINARY, LONG32VARBINARY, BLOB -> "binary"; default -> super.castType(sqlTypeCode); }; } @@ -433,6 +438,11 @@ protected void registerColumnTypes(TypeContributions typeContributions, ServiceR ddlTypeRegistry.addDescriptor( new NativeOrdinalEnumDdlTypeImpl( this ) ); } + @Override + public AggregateSupport getAggregateSupport() { + return MySQLAggregateSupport.valueOf( this ); + } + @Deprecated(since="6.4") protected static int getCharacterSetBytesPerCharacter(DatabaseMetaData databaseMetaData) { if ( databaseMetaData != null ) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLSqlAstTranslator.java index b672914dfaf4..1da63038e605 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLSqlAstTranslator.java @@ -48,6 +48,11 @@ */ public class MySQLSqlAstTranslator extends AbstractSqlAstTranslator { + /** + * On MySQL, 1GB or {@code 2^30 - 1} is the maximum size that a char value can be casted. + */ + private static final int MAX_CHAR_SIZE = (1 << 30) - 1; + public MySQLSqlAstTranslator(SessionFactoryImplementor sessionFactory, Statement statement) { super( sessionFactory, statement ); } @@ -64,7 +69,7 @@ public static String getSqlType(CastTarget castTarget, SessionFactoryImplementor private static String getSqlType(CastTarget castTarget, String sqlType, Dialect dialect) { if ( sqlType != null ) { int parenthesesIndex = sqlType.indexOf( '(' ); - final String baseName = parenthesesIndex == -1 ? sqlType : sqlType.substring( 0, parenthesesIndex ); + final String baseName = parenthesesIndex == -1 ? sqlType : sqlType.substring( 0, parenthesesIndex ).trim(); switch ( baseName.toLowerCase( Locale.ROOT ) ) { case "bit": return "unsigned"; @@ -76,6 +81,9 @@ private static String getSqlType(CastTarget castTarget, String sqlType, Dialect case "float": case "real": case "double precision": + if ( ((MySQLDialect) dialect).getMySQLVersion().isSameOrAfter( 8, 0, 17 ) ) { + return sqlType; + } final int precision = castTarget.getPrecision() == null ? dialect.getDefaultDecimalPrecision() : castTarget.getPrecision(); @@ -85,6 +93,10 @@ private static String getSqlType(CastTarget castTarget, String sqlType, Dialect case "varchar": case "nchar": case "nvarchar": + case "text": + case "mediumtext": + case "longtext": + case "enum": if ( castTarget.getLength() == null ) { // TODO: this is ugly and fragile, but could easily be handled in a DdlType if ( castTarget.getJdbcMapping().getJdbcJavaType().getJavaType() == Character.class ) { @@ -94,9 +106,11 @@ private static String getSqlType(CastTarget castTarget, String sqlType, Dialect return "char"; } } - return "char(" + castTarget.getLength() + ")"; + return castTarget.getLength() > MAX_CHAR_SIZE ? "char" : "char(" + castTarget.getLength() + ")"; case "binary": case "varbinary": + case "mediumblob": + case "longblob": return castTarget.getLength() == null ? "binary" : "binary(" + castTarget.getLength() + ")"; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/AggregateSupportImpl.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/AggregateSupportImpl.java index 4f1c78887e85..0fa815973239 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/AggregateSupportImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/AggregateSupportImpl.java @@ -13,6 +13,7 @@ import org.hibernate.mapping.Column; import org.hibernate.metamodel.mapping.SelectableMapping; import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.type.SqlTypes; import org.hibernate.type.spi.TypeConfiguration; public class AggregateSupportImpl implements AggregateSupport { @@ -76,7 +77,10 @@ public List aggregateAuxiliaryDatabaseObjects( @Override public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) { - return columnSqlTypeCode; + return switch (aggregateColumnSqlTypeCode) { + case SqlTypes.JSON -> columnSqlTypeCode == SqlTypes.ARRAY ? SqlTypes.JSON_ARRAY : columnSqlTypeCode; + default -> columnSqlTypeCode; + }; } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java new file mode 100644 index 000000000000..920a28238539 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java @@ -0,0 +1,320 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.aggregate; + +import org.hibernate.dialect.Dialect; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.mapping.Column; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SelectableMapping; +import org.hibernate.metamodel.mapping.SelectablePath; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.hibernate.type.SqlTypes.BIGINT; +import static org.hibernate.type.SqlTypes.BINARY; +import static org.hibernate.type.SqlTypes.BIT; +import static org.hibernate.type.SqlTypes.BLOB; +import static org.hibernate.type.SqlTypes.BOOLEAN; +import static org.hibernate.type.SqlTypes.CHAR; +import static org.hibernate.type.SqlTypes.CLOB; +import static org.hibernate.type.SqlTypes.ENUM; +import static org.hibernate.type.SqlTypes.INTEGER; +import static org.hibernate.type.SqlTypes.JSON; +import static org.hibernate.type.SqlTypes.JSON_ARRAY; +import static org.hibernate.type.SqlTypes.LONG32NVARCHAR; +import static org.hibernate.type.SqlTypes.LONG32VARBINARY; +import static org.hibernate.type.SqlTypes.LONG32VARCHAR; +import static org.hibernate.type.SqlTypes.NCHAR; +import static org.hibernate.type.SqlTypes.NCLOB; +import static org.hibernate.type.SqlTypes.NVARCHAR; +import static org.hibernate.type.SqlTypes.SMALLINT; +import static org.hibernate.type.SqlTypes.TIMESTAMP; +import static org.hibernate.type.SqlTypes.TIMESTAMP_UTC; +import static org.hibernate.type.SqlTypes.TINYINT; +import static org.hibernate.type.SqlTypes.VARBINARY; +import static org.hibernate.type.SqlTypes.VARCHAR; + +public class MySQLAggregateSupport extends AggregateSupportImpl { + + private static final AggregateSupport INSTANCE = new MySQLAggregateSupport(); + + public static AggregateSupport valueOf(Dialect dialect) { + return MySQLAggregateSupport.INSTANCE; + } + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + int aggregateColumnTypeCode, + SqlTypedMapping column) { + switch ( aggregateColumnTypeCode ) { + case JSON_ARRAY: + case JSON: + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + queryExpression( aggregateParentReadExpression, columnExpression ) + ); + case BOOLEAN: + return template.replace( + placeholder, + "case " + queryExpression( aggregateParentReadExpression, columnExpression ) + " when 'true' then true when 'false' then false end" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + // We encode binary data as hex, so we have to decode here + return template.replace( + placeholder, + "unhex(json_unquote(" + queryExpression( aggregateParentReadExpression, columnExpression ) + "))" + ); + default: + return template.replace( + placeholder, + valueExpression( aggregateParentReadExpression, columnExpression, columnCastType( column ) ) + ); + } + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private static String columnCastType(SqlTypedMapping column) { + return switch (column.getJdbcMapping().getJdbcType().getDdlTypeCode()) { + // special case for casting to Boolean + case BOOLEAN, BIT -> "unsigned"; + // MySQL doesn't let you cast to INTEGER/BIGINT/TINYINT + case TINYINT, SMALLINT, INTEGER, BIGINT -> "signed"; + // MySQL doesn't let you cast to TEXT/LONGTEXT + case CHAR, VARCHAR, LONG32VARCHAR, CLOB, ENUM -> "char"; + case NCHAR, NVARCHAR, LONG32NVARCHAR, NCLOB -> "char character set utf8mb4"; + // MySQL doesn't let you cast to BLOB/TINYBLOB/LONGBLOB + case BINARY, VARBINARY, LONG32VARBINARY, BLOB -> "binary"; + default -> column.getColumnDefinition(); + }; + } + + private static String valueExpression(String aggregateParentReadExpression, String columnExpression, String columnType) { + return "cast(json_unquote(" + queryExpression( aggregateParentReadExpression, columnExpression ) + ") as " + columnType + ')'; + } + + private static String queryExpression(String aggregateParentReadExpression, String columnExpression) { + return "nullif(json_extract(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),cast('null' as json))"; + } + + private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) { + final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex + return "hex(" + customWriteExpression + ")"; + case BOOLEAN: + return "(" + customWriteExpression + ")=true"; + case TIMESTAMP: + return "date_format(" + customWriteExpression + ",'%Y-%m-%dT%T.%f')"; + case TIMESTAMP_UTC: + return "date_format(" + customWriteExpression + ",'%Y-%m-%dT%T.%fZ')"; + default: + return customWriteExpression; + } + } + + @Override + public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) { + return super.aggregateComponentSqlTypeCode( aggregateColumnSqlTypeCode, columnSqlTypeCode ); + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + int aggregateColumnTypeCode, + Column column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + switch ( aggregateSqlTypeCode ) { + case JSON: + return true; + } + return false; + } + + @Override + public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( + SelectableMapping aggregateColumn, + SelectableMapping[] columnsToUpdate, + TypeConfiguration typeConfiguration) { + final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + switch ( aggregateSqlTypeCode ) { + case JSON: + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate ); + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); + } + + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns) { + return new RootJsonWriteExpression( aggregateColumn, columns ); + } + + interface JsonWriteExpression { + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + + protected void initializeSubExpressions(SelectableMapping[] columns) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + for ( int i = 1; i < parts.length - 1; i++ ) { + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression() + ); + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() ) + ) + ); + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + final String subPath = queryExpression( path, column ); + sb.append( ',' ); + if ( value instanceof AggregateJsonWriteExpression ) { + sb.append( "'$." ); + sb.append( column ); + sb.append( "',json_set(coalesce(" ); + sb.append( subPath ); + sb.append( ",json_object())" ); + value.append( sb, subPath, translator, expression ); + sb.append( ')' ); + } + else { + value.append( sb, subPath, translator, expression ); + } + } + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final boolean nullable; + private final String path; + + RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + this.nullable = aggregateColumn.isNullable(); + this.path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( columns ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + sqlAppender.appendSql( "json_set(" ); + if ( nullable ) { + sqlAppender.append( "coalesce(" ); + sqlAppender.append( basePath ); + sqlAppender.append( ",json_object())" ); + } + else { + sqlAppender.append( basePath ); + } + append( sqlAppender, basePath, translator, aggregateColumnWriteExpression ); + sqlAppender.append( ')' ); + } + } + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( "'$." ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "'," ); + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java index 6564dea9b97d..172b1511bad2 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java @@ -218,6 +218,7 @@ public String aggregateComponentCustomReadExpression( ); } case JSON: + case JSON_ARRAY: return template.replace( placeholder, "json_query(" + parentPartExpression + columnExpression + "' returning " + jsonTypeName + ")" diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/PostgreSQLAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/PostgreSQLAggregateSupport.java index f177fc62d72e..3b6cc39553ed 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/PostgreSQLAggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/PostgreSQLAggregateSupport.java @@ -58,6 +58,7 @@ public String aggregateComponentCustomReadExpression( case JSON: switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { case JSON: + case JSON_ARRAY: return template.replace( placeholder, aggregateParentReadExpression + "->'" + columnExpression + "'" diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java index 75cf134c9449..ae07f819407f 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java @@ -65,8 +65,8 @@ import org.hibernate.type.CollectionType; import org.hibernate.type.CompositeType; import org.hibernate.type.EntityType; -import org.hibernate.type.SqlTypes; import org.hibernate.type.Type; +import org.hibernate.type.descriptor.JdbcTypeNameMapper; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan; import org.hibernate.type.descriptor.java.JavaType; @@ -319,9 +319,9 @@ private JdbcMapping resolveJdbcMapping(Component bootDescriptor, RuntimeModelCre } final BasicType resolvedJdbcMapping; if ( isArray ) { - final JdbcTypeConstructor arrayConstructor = jdbcTypeRegistry.getConstructor( SqlTypes.ARRAY ); + final JdbcTypeConstructor arrayConstructor = jdbcTypeRegistry.getConstructor( aggregateColumnSqlTypeCode ); if ( arrayConstructor == null ) { - throw new IllegalArgumentException( "No JdbcTypeConstructor registered for SqlTypes.ARRAY" ); + throw new IllegalArgumentException( "No JdbcTypeConstructor registered for SqlTypes." + JdbcTypeNameMapper.getTypeName( aggregateColumnSqlTypeCode ) ); } //noinspection rawtypes,unchecked final BasicType arrayType = ( (BasicPluralJavaType) resolution.getDomainJavaType() ).resolveType( From 10c8582a647560c4ed507b02e1e9142d7971a5e2 Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Tue, 5 Nov 2024 20:58:36 +0100 Subject: [PATCH 2/7] HHH-18794 Add JSON aggregate support for MariaDB --- .../dialect/MariaDBLegacyDialect.java | 24 +++++++-- .../MariaDBLegacySqlAstTranslator.java | 50 ++++++++++++++++++- .../MariaDBCastingJsonArrayJdbcType.java | 29 +++++++++++ ...DBCastingJsonArrayJdbcTypeConstructor.java | 43 ++++++++++++++++ .../dialect/MariaDBCastingJsonJdbcType.java | 43 ++++++++++++++++ .../org/hibernate/dialect/MariaDBDialect.java | 23 +++++++-- .../dialect/MariaDBSqlAstTranslator.java | 50 ++++++++++++++++++- .../org/hibernate/dialect/MySQLDialect.java | 2 +- .../aggregate/MySQLAggregateSupport.java | 40 +++++++++++---- .../org/hibernate/query/sqm/CastType.java | 1 + .../ast/tree/expression/ColumnReference.java | 34 ++++++------- .../sql/ast/tree/expression/Expression.java | 5 +- .../type/descriptor/jdbc/JdbcType.java | 3 ++ .../orm/test/function/json/JsonTableTest.java | 44 ++++++++++++++++ .../orm/test/type/BasicListTest.java | 2 + .../orm/test/type/BasicSortedSetTest.java | 2 + .../orm/test/type/BooleanArrayTest.java | 2 + .../orm/test/type/DateArrayTest.java | 2 + .../orm/test/type/DoubleArrayTest.java | 2 + .../orm/test/type/EnumArrayTest.java | 2 + .../orm/test/type/EnumSetConverterTest.java | 2 + .../hibernate/orm/test/type/EnumSetTest.java | 2 + .../orm/test/type/FloatArrayTest.java | 2 + .../orm/test/type/IntegerArrayTest.java | 2 + .../orm/test/type/LongArrayTest.java | 2 + .../orm/test/type/ShortArrayTest.java | 2 + .../orm/test/type/StringArrayTest.java | 2 + .../orm/test/type/TimeArrayTest.java | 2 + .../orm/test/type/TimestampArrayTest.java | 2 + 29 files changed, 379 insertions(+), 42 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcType.java create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcTypeConstructor.java create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonJdbcType.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacyDialect.java index 564f2ad33488..21e177efc6b2 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacyDialect.java @@ -10,6 +10,9 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.*; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.AggregateSupportImpl; +import org.hibernate.dialect.aggregate.MySQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.sequence.MariaDBSequenceSupport; import org.hibernate.dialect.sequence.SequenceSupport; @@ -18,6 +21,7 @@ import org.hibernate.engine.jdbc.env.spi.IdentifierHelper; import org.hibernate.engine.jdbc.env.spi.IdentifierHelperBuilder; import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.query.sqm.CastType; import org.hibernate.service.ServiceRegistry; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.SqlAstTranslatorFactory; @@ -29,8 +33,6 @@ import org.hibernate.type.SqlTypes; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.descriptor.jdbc.JdbcType; -import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcTypeConstructor; -import org.hibernate.type.descriptor.jdbc.JsonJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; @@ -122,6 +124,13 @@ protected void registerColumnTypes(TypeContributions typeContributions, ServiceR } } + @Override + public AggregateSupport getAggregateSupport() { + return getVersion().isSameOrAfter( 10, 2 ) + ? MySQLAggregateSupport.LONGTEXT_INSTANCE + : AggregateSupportImpl.INSTANCE; + } + @Override public JdbcType resolveSqlTypeDescriptor( String columnTypeName, @@ -150,8 +159,8 @@ public JdbcType resolveSqlTypeDescriptor( public void contributeTypes(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final JdbcTypeRegistry jdbcTypeRegistry = typeContributions.getTypeConfiguration().getJdbcTypeRegistry(); // Make sure we register the JSON type descriptor before calling super, because MariaDB does not need casting - jdbcTypeRegistry.addDescriptorIfAbsent( SqlTypes.JSON, JsonJdbcType.INSTANCE ); - jdbcTypeRegistry.addTypeConstructorIfAbsent( JsonArrayJdbcTypeConstructor.INSTANCE ); + jdbcTypeRegistry.addDescriptorIfAbsent( SqlTypes.JSON, MariaDBCastingJsonJdbcType.INSTANCE ); + jdbcTypeRegistry.addTypeConstructorIfAbsent( MariaDBCastingJsonArrayJdbcTypeConstructor.INSTANCE ); super.contributeTypes( typeContributions, serviceRegistry ); if ( getVersion().isSameOrAfter( 10, 7 ) ) { @@ -159,6 +168,13 @@ public void contributeTypes(TypeContributions typeContributions, ServiceRegistry } } + @Override + public String castPattern(CastType from, CastType to) { + return to == CastType.JSON + ? "json_extract(?1,'$')" + : super.castPattern( from, to ); + } + @Override public SqlAstTranslatorFactory getSqlAstTranslatorFactory() { return new StandardSqlAstTranslatorFactory() { diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacySqlAstTranslator.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacySqlAstTranslator.java index fb021d624f49..06c10143b303 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacySqlAstTranslator.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/MariaDBLegacySqlAstTranslator.java @@ -11,6 +11,7 @@ import org.hibernate.dialect.MySQLSqlAstTranslator; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.internal.util.collections.Stack; +import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.sqm.ComparisonOperator; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; @@ -292,7 +293,54 @@ public void visitOffsetFetchClause(QueryPart queryPart) { @Override protected void renderComparison(Expression lhs, ComparisonOperator operator, Expression rhs) { - renderComparisonDistinctOperator( lhs, operator, rhs ); + final JdbcMappingContainer lhsExpressionType = lhs.getExpressionType(); + if ( lhsExpressionType != null && lhsExpressionType.getJdbcTypeCount() == 1 + && lhsExpressionType.getSingleJdbcMapping().getJdbcType().isJson() ) { + switch ( operator ) { + case DISTINCT_FROM: + appendSql( "case when json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1 or " ); + lhs.accept( this ); + appendSql( " is null and " ); + rhs.accept( this ); + appendSql( " is null then 0 else 1 end=1" ); + break; + case NOT_DISTINCT_FROM: + appendSql( "case when json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1 or " ); + lhs.accept( this ); + appendSql( " is null and " ); + rhs.accept( this ); + appendSql( " is null then 0 else 1 end=0" ); + break; + case NOT_EQUAL: + appendSql( "json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=0" ); + break; + case EQUAL: + appendSql( "json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1" ); + break; + default: + renderComparisonDistinctOperator( lhs, operator, rhs ); + break; + } + } + else { + renderComparisonDistinctOperator( lhs, operator, rhs ); + } } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcType.java new file mode 100644 index 000000000000..712da2bdcc98 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcType.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect; + +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType; + +/** + * @author Christian Beikov + */ +public class MariaDBCastingJsonArrayJdbcType extends JsonArrayJdbcType { + + public MariaDBCastingJsonArrayJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public void appendWriteExpression( + String writeExpression, + SqlAppender appender, + Dialect dialect) { + appender.append( "json_extract(" ); + appender.append( writeExpression ); + appender.append( ",'$')" ); + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcTypeConstructor.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcTypeConstructor.java new file mode 100644 index 000000000000..003792b8ba9b --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonArrayJdbcTypeConstructor.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect; + +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeConstructor; +import org.hibernate.type.spi.TypeConfiguration; + +/** + * Factory for {@link MariaDBCastingJsonArrayJdbcType}. + */ +public class MariaDBCastingJsonArrayJdbcTypeConstructor implements JdbcTypeConstructor { + + public static final MariaDBCastingJsonArrayJdbcTypeConstructor INSTANCE = new MariaDBCastingJsonArrayJdbcTypeConstructor(); + + @Override + public JdbcType resolveType( + TypeConfiguration typeConfiguration, + Dialect dialect, + BasicType elementType, + ColumnTypeInformation columnTypeInformation) { + return resolveType( typeConfiguration, dialect, elementType.getJdbcType(), columnTypeInformation ); + } + + @Override + public JdbcType resolveType( + TypeConfiguration typeConfiguration, + Dialect dialect, + JdbcType elementType, + ColumnTypeInformation columnTypeInformation) { + return new MariaDBCastingJsonArrayJdbcType( elementType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.JSON_ARRAY; + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonJdbcType.java new file mode 100644 index 000000000000..28f616a732c4 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBCastingJsonJdbcType.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect; + +import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.spi.RuntimeModelCreationContext; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; +import org.hibernate.type.descriptor.jdbc.JsonJdbcType; + +/** + * @author Christian Beikov + */ +public class MariaDBCastingJsonJdbcType extends JsonJdbcType { + /** + * Singleton access + */ + public static final JsonJdbcType INSTANCE = new MariaDBCastingJsonJdbcType( null ); + + public MariaDBCastingJsonJdbcType(EmbeddableMappingType embeddableMappingType) { + super( embeddableMappingType ); + } + + @Override + public AggregateJdbcType resolveAggregateJdbcType( + EmbeddableMappingType mappingType, + String sqlType, + RuntimeModelCreationContext creationContext) { + return new MariaDBCastingJsonJdbcType( mappingType ); + } + + @Override + public void appendWriteExpression( + String writeExpression, + SqlAppender appender, + Dialect dialect) { + appender.append( "json_extract(" ); + appender.append( writeExpression ); + appender.append( ",'$')" ); + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java index 5aadd93192d5..938f893bfbcc 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java @@ -9,6 +9,8 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.MySQLAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.identity.IdentityColumnSupport; import org.hibernate.dialect.identity.MariaDBIdentityColumnSupport; @@ -19,6 +21,7 @@ import org.hibernate.engine.jdbc.env.spi.IdentifierHelper; import org.hibernate.engine.jdbc.env.spi.IdentifierHelperBuilder; import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.query.sqm.CastType; import org.hibernate.service.ServiceRegistry; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.SqlAstTranslatorFactory; @@ -30,8 +33,6 @@ import org.hibernate.type.SqlTypes; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.descriptor.jdbc.JdbcType; -import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcTypeConstructor; -import org.hibernate.type.descriptor.jdbc.JsonJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; @@ -122,6 +123,11 @@ protected void registerColumnTypes(TypeContributions typeContributions, ServiceR } } + @Override + public AggregateSupport getAggregateSupport() { + return MySQLAggregateSupport.LONGTEXT_INSTANCE; + } + @Override protected void registerKeyword(String word) { // The MariaDB driver reports that "STRING" is a keyword, but @@ -156,9 +162,9 @@ public JdbcType resolveSqlTypeDescriptor( @Override public void contributeTypes(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final JdbcTypeRegistry jdbcTypeRegistry = typeContributions.getTypeConfiguration().getJdbcTypeRegistry(); - // Make sure we register the JSON type descriptor before calling super, because MariaDB does not need casting - jdbcTypeRegistry.addDescriptorIfAbsent( SqlTypes.JSON, JsonJdbcType.INSTANCE ); - jdbcTypeRegistry.addTypeConstructorIfAbsent( JsonArrayJdbcTypeConstructor.INSTANCE ); + // Make sure we register the JSON type descriptor before calling super, because MariaDB needs special casting + jdbcTypeRegistry.addDescriptorIfAbsent( SqlTypes.JSON, MariaDBCastingJsonJdbcType.INSTANCE ); + jdbcTypeRegistry.addTypeConstructorIfAbsent( MariaDBCastingJsonArrayJdbcTypeConstructor.INSTANCE ); super.contributeTypes( typeContributions, serviceRegistry ); if ( getVersion().isSameOrAfter( 10, 7 ) ) { @@ -166,6 +172,13 @@ public void contributeTypes(TypeContributions typeContributions, ServiceRegistry } } + @Override + public String castPattern(CastType from, CastType to) { + return to == CastType.JSON + ? "json_extract(?1,'$')" + : super.castPattern( from, to ); + } + @Override public SqlAstTranslatorFactory getSqlAstTranslatorFactory() { return new StandardSqlAstTranslatorFactory() { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBSqlAstTranslator.java index ea540352d8c5..e6ca4499af5c 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBSqlAstTranslator.java @@ -9,6 +9,7 @@ import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.internal.util.collections.Stack; +import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.sqm.ComparisonOperator; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; @@ -295,7 +296,54 @@ public void visitOffsetFetchClause(QueryPart queryPart) { @Override protected void renderComparison(Expression lhs, ComparisonOperator operator, Expression rhs) { - renderComparisonDistinctOperator( lhs, operator, rhs ); + final JdbcMappingContainer lhsExpressionType = lhs.getExpressionType(); + if ( lhsExpressionType != null && lhsExpressionType.getJdbcTypeCount() == 1 + && lhsExpressionType.getSingleJdbcMapping().getJdbcType().isJson() ) { + switch ( operator ) { + case DISTINCT_FROM: + appendSql( "case when json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1 or " ); + lhs.accept( this ); + appendSql( " is null and " ); + rhs.accept( this ); + appendSql( " is null then 0 else 1 end=1" ); + break; + case NOT_DISTINCT_FROM: + appendSql( "case when json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1 or " ); + lhs.accept( this ); + appendSql( " is null and " ); + rhs.accept( this ); + appendSql( " is null then 0 else 1 end=0" ); + break; + case NOT_EQUAL: + appendSql( "json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=0" ); + break; + case EQUAL: + appendSql( "json_equals(" ); + lhs.accept( this ); + appendSql( ',' ); + rhs.accept( this ); + appendSql( ")=1" ); + break; + default: + renderComparisonDistinctOperator( lhs, operator, rhs ); + break; + } + } + else { + renderComparisonDistinctOperator( lhs, operator, rhs ); + } } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java index 61a362e3c89c..b7003b2729a4 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java @@ -440,7 +440,7 @@ protected void registerColumnTypes(TypeContributions typeContributions, ServiceR @Override public AggregateSupport getAggregateSupport() { - return MySQLAggregateSupport.valueOf( this ); + return MySQLAggregateSupport.JSON_INSTANCE; } @Deprecated(since="6.4") diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java index 920a28238539..9e0fe03d2ac1 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/MySQLAggregateSupport.java @@ -4,7 +4,6 @@ */ package org.hibernate.dialect.aggregate; -import org.hibernate.dialect.Dialect; import org.hibernate.internal.util.StringHelper; import org.hibernate.mapping.Column; import org.hibernate.metamodel.mapping.JdbcMapping; @@ -26,7 +25,9 @@ import static org.hibernate.type.SqlTypes.BOOLEAN; import static org.hibernate.type.SqlTypes.CHAR; import static org.hibernate.type.SqlTypes.CLOB; +import static org.hibernate.type.SqlTypes.DOUBLE; import static org.hibernate.type.SqlTypes.ENUM; +import static org.hibernate.type.SqlTypes.FLOAT; import static org.hibernate.type.SqlTypes.INTEGER; import static org.hibernate.type.SqlTypes.JSON; import static org.hibernate.type.SqlTypes.JSON_ARRAY; @@ -36,6 +37,7 @@ import static org.hibernate.type.SqlTypes.NCHAR; import static org.hibernate.type.SqlTypes.NCLOB; import static org.hibernate.type.SqlTypes.NVARCHAR; +import static org.hibernate.type.SqlTypes.REAL; import static org.hibernate.type.SqlTypes.SMALLINT; import static org.hibernate.type.SqlTypes.TIMESTAMP; import static org.hibernate.type.SqlTypes.TIMESTAMP_UTC; @@ -45,10 +47,13 @@ public class MySQLAggregateSupport extends AggregateSupportImpl { - private static final AggregateSupport INSTANCE = new MySQLAggregateSupport(); + public static final AggregateSupport JSON_INSTANCE = new MySQLAggregateSupport( true ); + public static final AggregateSupport LONGTEXT_INSTANCE = new MySQLAggregateSupport( false ); - public static AggregateSupport valueOf(Dialect dialect) { - return MySQLAggregateSupport.INSTANCE; + private final boolean jsonType; + + public MySQLAggregateSupport(boolean jsonType) { + this.jsonType = jsonType; } @Override @@ -72,7 +77,9 @@ public String aggregateComponentCustomReadExpression( case BOOLEAN: return template.replace( placeholder, - "case " + queryExpression( aggregateParentReadExpression, columnExpression ) + " when 'true' then true when 'false' then false end" + jsonType + ? "case " + queryExpression( aggregateParentReadExpression, columnExpression ) + " when cast('true' as json) then true when cast('false' as json) then false end" + : "case " + queryExpression( aggregateParentReadExpression, columnExpression ) + " when 'true' then true when 'false' then false end" ); case BINARY: case VARBINARY: @@ -92,12 +99,18 @@ public String aggregateComponentCustomReadExpression( throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); } - private static String columnCastType(SqlTypedMapping column) { + private String columnCastType(SqlTypedMapping column) { return switch (column.getJdbcMapping().getJdbcType().getDdlTypeCode()) { // special case for casting to Boolean case BOOLEAN, BIT -> "unsigned"; // MySQL doesn't let you cast to INTEGER/BIGINT/TINYINT case TINYINT, SMALLINT, INTEGER, BIGINT -> "signed"; + case REAL -> "float"; + case DOUBLE -> "double"; + case FLOAT -> jsonType + // In newer versions of MySQL, casting to float/double is supported + ? column.getColumnDefinition() + : column.getPrecision() == null || column.getPrecision() == 53 ? "double" : "float"; // MySQL doesn't let you cast to TEXT/LONGTEXT case CHAR, VARCHAR, LONG32VARCHAR, CLOB, ENUM -> "char"; case NCHAR, NVARCHAR, LONG32NVARCHAR, NCLOB -> "char character set utf8mb4"; @@ -107,12 +120,17 @@ private static String columnCastType(SqlTypedMapping column) { }; } - private static String valueExpression(String aggregateParentReadExpression, String columnExpression, String columnType) { + private String valueExpression(String aggregateParentReadExpression, String columnExpression, String columnType) { return "cast(json_unquote(" + queryExpression( aggregateParentReadExpression, columnExpression ) + ") as " + columnType + ')'; } - private static String queryExpression(String aggregateParentReadExpression, String columnExpression) { - return "nullif(json_extract(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),cast('null' as json))"; + private String queryExpression(String aggregateParentReadExpression, String columnExpression) { + if ( jsonType ) { + return "nullif(json_extract(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),cast('null' as json))"; + } + else { + return "nullif(json_extract(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),'null')"; + } } private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) { @@ -190,7 +208,7 @@ void append( SqlAstTranslator translator, AggregateColumnWriteExpression expression); } - private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private class AggregateJsonWriteExpression implements JsonWriteExpression { private final LinkedHashMap subExpressions = new LinkedHashMap<>(); protected void initializeSubExpressions(SelectableMapping[] columns) { @@ -242,7 +260,7 @@ public void append( } } - private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + private class RootJsonWriteExpression extends AggregateJsonWriteExpression implements WriteExpressionRenderer { private final boolean nullable; private final String path; diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/CastType.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/CastType.java index da57813bec19..7c5babc93ade 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/CastType.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/CastType.java @@ -34,6 +34,7 @@ public enum CastType { INTEGER, LONG, FLOAT, DOUBLE, FIXED, DATE, TIME, TIMESTAMP, OFFSET_TIMESTAMP, ZONE_TIMESTAMP, + JSON, NULL, OTHER; diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/ColumnReference.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/ColumnReference.java index f130ac493587..d4578f74d266 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/ColumnReference.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/ColumnReference.java @@ -33,9 +33,9 @@ * @author Yanming Zhou */ public class ColumnReference implements Expression, Assignable { - private final String qualifier; + private final @Nullable String qualifier; private final String columnExpression; - private final SelectablePath selectablePath; + private final @Nullable SelectablePath selectablePath; private final boolean isFormula; private final @Nullable String readExpression; private final JdbcMapping jdbcMapping; @@ -62,7 +62,7 @@ public ColumnReference(TableReference tableReference, String mapping, JdbcMappin ); } - public ColumnReference(String qualifier, SelectableMapping selectableMapping) { + public ColumnReference(@Nullable String qualifier, SelectableMapping selectableMapping) { this( qualifier, selectableMapping.getSelectionExpression(), @@ -73,7 +73,7 @@ public ColumnReference(String qualifier, SelectableMapping selectableMapping) { ); } - public ColumnReference(String qualifier, SelectableMapping selectableMapping, JdbcMapping jdbcMapping) { + public ColumnReference(@Nullable String qualifier, SelectableMapping selectableMapping, JdbcMapping jdbcMapping) { this( qualifier, selectableMapping.getSelectionExpression(), @@ -88,7 +88,7 @@ public ColumnReference( TableReference tableReference, String columnExpression, boolean isFormula, - String customReadExpression, + @Nullable String customReadExpression, JdbcMapping jdbcMapping) { this( tableReference.getIdentificationVariable(), @@ -101,20 +101,20 @@ public ColumnReference( } public ColumnReference( - String qualifier, + @Nullable String qualifier, String columnExpression, boolean isFormula, - String customReadExpression, + @Nullable String customReadExpression, JdbcMapping jdbcMapping) { this( qualifier, columnExpression, null, isFormula, customReadExpression, jdbcMapping ); } public ColumnReference( - String qualifier, + @Nullable String qualifier, String columnExpression, - SelectablePath selectablePath, + @Nullable SelectablePath selectablePath, boolean isFormula, - String customReadExpression, + @Nullable String customReadExpression, JdbcMapping jdbcMapping) { this.qualifier = nullIfEmpty( qualifier ); @@ -141,7 +141,7 @@ public ColumnReference getColumnReference() { return this; } - public String getQualifier() { + public @Nullable String getQualifier() { return qualifier; } @@ -153,11 +153,11 @@ public String getColumnExpression() { return readExpression; } - public String getSelectableName() { - return selectablePath.getSelectableName(); + public @Nullable String getSelectableName() { + return selectablePath == null ? null : selectablePath.getSelectableName(); } - public SelectablePath getSelectablePath() { + public @Nullable SelectablePath getSelectablePath() { return selectablePath; } @@ -175,7 +175,7 @@ public void appendReadExpression(SqlAppender appender) { appendReadExpression( appender, qualifier ); } - public void appendReadExpression(String qualifier, Consumer appender) { + public void appendReadExpression(@Nullable String qualifier, Consumer appender) { if ( isFormula ) { appender.accept( columnExpression ); } @@ -193,7 +193,7 @@ else if ( readExpression != null ) { } } - public void appendReadExpression(SqlAppender appender, String qualifier) { + public void appendReadExpression(SqlAppender appender, @Nullable String qualifier) { appendReadExpression( qualifier, appender::appendSql ); } @@ -201,7 +201,7 @@ public void appendColumnForWrite(SqlAppender appender) { appendColumnForWrite( appender, qualifier ); } - public void appendColumnForWrite(SqlAppender appender, String qualifier) { + public void appendColumnForWrite(SqlAppender appender, @Nullable String qualifier) { if ( qualifier != null ) { appender.append( qualifier ); appender.append( '.' ); diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/Expression.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/Expression.java index 19aff7a545c6..244bd13ba08a 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/Expression.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/Expression.java @@ -4,6 +4,7 @@ */ package org.hibernate.sql.ast.tree.expression; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.sql.ast.spi.SqlSelection; import org.hibernate.sql.ast.spi.SqlSelectionProducer; @@ -21,9 +22,9 @@ public interface Expression extends SqlAstNode, SqlSelectionProducer { /** * The type for this expression */ - JdbcMappingContainer getExpressionType(); + @Nullable JdbcMappingContainer getExpressionType(); - default ColumnReference getColumnReference() { + default @Nullable ColumnReference getColumnReference() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java index bff937c0592e..af6b38062af7 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java @@ -347,6 +347,9 @@ static CastType getCastType(int typeCode) { return CastType.TIMESTAMP; case TIMESTAMP_WITH_TIMEZONE: return CastType.OFFSET_TIMESTAMP; + case JSON: + case JSON_ARRAY: + return CastType.JSON; case NULL: return CastType.NULL; default: diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/function/json/JsonTableTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/function/json/JsonTableTest.java index ee56cb7854a8..1a767e46da63 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/function/json/JsonTableTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/function/json/JsonTableTest.java @@ -161,6 +161,50 @@ public void testNodeBuilderJsonTableObject(SessionFactoryScope scope) { } ); } + @Test + public void testArray(SessionFactoryScope scope) { + scope.inSession( em -> { + final String query = """ + select + t.idx, + t.val + from json_table('[1,2]','$[*]' columns(val Integer path '$', idx for ordinality)) t + order by t.idx + """; + List resultList = em.createQuery( query, Tuple.class ).getResultList(); + + assertEquals( 2, resultList.size() ); + + assertEquals( 1L, resultList.get( 0 ).get( 0 ) ); + assertEquals( 1, resultList.get( 0 ).get( 1 ) ); + assertEquals( 2L, resultList.get( 1 ).get( 0 ) ); + assertEquals( 2, resultList.get( 1 ).get( 1 ) ); + } ); + } + + @Test + public void testArrayParam(SessionFactoryScope scope) { + scope.inSession( em -> { + final String query = """ + select + t.idx, + t.val + from json_table(:arr,'$[*]' columns(val Integer path '$', idx for ordinality)) t + order by t.idx + """; + List resultList = em.createQuery( query, Tuple.class ) + .setParameter( "arr", "[1,2]" ) + .getResultList(); + + assertEquals( 2, resultList.size() ); + + assertEquals( 1L, resultList.get( 0 ).get( 0 ) ); + assertEquals( 1, resultList.get( 0 ).get( 1 ) ); + assertEquals( 2L, resultList.get( 1 ).get( 0 ) ); + assertEquals( 2, resultList.get( 1 ).get( 1 ) ); + } ); + } + private static void assertTupleEquals(Tuple tuple, long arrayIndex, String arrayValue) { assertEquals( 1, tuple.get( 0 ) ); assertEquals( 0.1F, tuple.get( 1 ) ); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java index f40e6f623a72..f6c605a2c3e3 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java @@ -12,6 +12,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -126,6 +127,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java index 8a308e51bc5d..4260bea45583 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java @@ -13,6 +13,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -127,6 +128,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java index ace8ed13341e..8b4446f96281 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -133,6 +134,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java index 123c4380a07c..239451eae15e 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java @@ -10,6 +10,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.PostgresPlusDialect; import org.hibernate.dialect.SQLServerDialect; @@ -142,6 +143,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") @SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Seems that comparing date[] through JDBC is buggy. ERROR: operator does not exist: timestamp without time zone[] = date[]") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java index 4df6ebca5171..dd73ac51486f 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java @@ -9,6 +9,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -137,6 +138,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java index cf49e56385d7..7bbebe7e8d30 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java @@ -10,6 +10,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.MySQLDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; @@ -129,6 +130,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") @SkipForDialect(dialectClass = MySQLDialect.class ) @SkipForDialect(dialectClass = DerbyDialect.class ) @SkipForDialect(dialectClass = DB2Dialect.class ) diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java index 37f1e8d9746d..ff25e7c6a5cc 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java @@ -12,6 +12,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -135,6 +136,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java index 56edcc72c81e..d15213190a4e 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java @@ -12,6 +12,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -128,6 +129,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java index e606515a6983..bea29d689744 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -125,6 +126,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java index b15f463e5c86..ca0ca8d829e1 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -125,6 +126,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java index 86506f8e13d6..ed8f9b8c5303 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -130,6 +131,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java index e1bfa9778d57..d6be24f839bc 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -125,6 +126,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java index 37f475f77d8e..dd12a5fa576b 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java @@ -8,6 +8,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -125,6 +126,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java index a6dd1e601f17..9f83633fea24 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java @@ -10,6 +10,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -139,6 +140,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java index bb7dca23e42b..8742e0dcbbcf 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java @@ -11,6 +11,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.dialect.HANADialect; import org.hibernate.dialect.HSQLDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.dialect.SybaseASEDialect; @@ -144,6 +145,7 @@ public void testNativeQueryById(SessionFactoryScope scope) { @SkipForDialect(dialectClass = SQLServerDialect.class, reason = "SQL Server requires a special function to compare XML") @SkipForDialect(dialectClass = SybaseASEDialect.class, reason = "Sybase ASE requires a special function to compare XML") @SkipForDialect(dialectClass = HANADialect.class, reason = "HANA requires a special function to compare LOBs") + @SkipForDialect(dialectClass = MariaDBDialect.class, reason = "MariaDB requires a special function to compare LOBs") public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); From f202cdc96832bfb682f9d1a429aea38d4e4312e2 Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Wed, 6 Nov 2024 10:08:11 +0100 Subject: [PATCH 3/7] HHH-18795 Add JSON aggregate support for CockroachDB --- .../dialect/CockroachLegacyDialect.java | 7 + .../community/dialect/H2LegacyDialect.java | 7 + .../hibernate/dialect/CockroachDialect.java | 7 + .../CockroachDBAggregateSupport.java | 314 ++++++++++++++++++ .../dialect/aggregate/H2AggregateSupport.java | 2 +- .../array/PostgreSQLUnnestFunction.java | 10 +- 6 files changed, 342 insertions(+), 5 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/aggregate/CockroachDBAggregateSupport.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java index c6161e95a165..27aec0cc6ab8 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java @@ -26,6 +26,8 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.*; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.CockroachDBAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.FormatFunction; import org.hibernate.dialect.function.PostgreSQLTruncFunction; @@ -699,6 +701,11 @@ public NationalizationSupport getNationalizationSupport() { return NationalizationSupport.IMPLICIT; } + @Override + public AggregateSupport getAggregateSupport() { + return CockroachDBAggregateSupport.valueOf( this ); + } + @Override public int getMaxIdentifierLength() { return 63; diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/H2LegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/H2LegacyDialect.java index b254a74f5b7f..0e0f821ccd8d 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/H2LegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/H2LegacyDialect.java @@ -20,6 +20,8 @@ import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.*; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.H2AggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.identity.H2FinalTableIdentityColumnSupport; import org.hibernate.dialect.identity.H2IdentityColumnSupport; @@ -301,6 +303,11 @@ public void contributeTypes(TypeContributions typeContributions, ServiceRegistry jdbcTypeRegistry.addDescriptor( OrdinalEnumJdbcType.INSTANCE ); } + @Override + public AggregateSupport getAggregateSupport() { + return H2AggregateSupport.valueOf( this ); + } + @Override public int getDefaultStatementBatchSize() { return 15; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java index 6fda9b7061c6..bbea3db39216 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java @@ -26,6 +26,8 @@ import org.hibernate.QueryTimeoutException; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.CockroachDBAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.FormatFunction; import org.hibernate.dialect.function.PostgreSQLTruncFunction; @@ -667,6 +669,11 @@ public NationalizationSupport getNationalizationSupport() { return NationalizationSupport.IMPLICIT; } + @Override + public AggregateSupport getAggregateSupport() { + return CockroachDBAggregateSupport.valueOf( this ); + } + @Override public int getMaxIdentifierLength() { return 63; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/CockroachDBAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/CockroachDBAggregateSupport.java new file mode 100644 index 000000000000..be72ff699282 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/CockroachDBAggregateSupport.java @@ -0,0 +1,314 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.aggregate; + +import org.hibernate.dialect.Dialect; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.mapping.Column; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SelectableMapping; +import org.hibernate.metamodel.mapping.SelectablePath; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.BasicPluralType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.hibernate.type.SqlTypes.ARRAY; +import static org.hibernate.type.SqlTypes.BIGINT; +import static org.hibernate.type.SqlTypes.BINARY; +import static org.hibernate.type.SqlTypes.BOOLEAN; +import static org.hibernate.type.SqlTypes.DOUBLE; +import static org.hibernate.type.SqlTypes.FLOAT; +import static org.hibernate.type.SqlTypes.INTEGER; +import static org.hibernate.type.SqlTypes.JSON; +import static org.hibernate.type.SqlTypes.JSON_ARRAY; +import static org.hibernate.type.SqlTypes.LONG32VARBINARY; +import static org.hibernate.type.SqlTypes.SMALLINT; +import static org.hibernate.type.SqlTypes.TINYINT; +import static org.hibernate.type.SqlTypes.VARBINARY; + +public class CockroachDBAggregateSupport extends AggregateSupportImpl { + + private static final AggregateSupport INSTANCE = new CockroachDBAggregateSupport(); + + public static AggregateSupport valueOf(Dialect dialect) { + return CockroachDBAggregateSupport.INSTANCE; + } + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + int aggregateColumnTypeCode, + SqlTypedMapping column) { + switch ( aggregateColumnTypeCode ) { + case JSON_ARRAY: + case JSON: + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + aggregateParentReadExpression + "->'" + columnExpression + "'" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + // We encode binary data as hex, so we have to decode here + return template.replace( + placeholder, + "decode(" + aggregateParentReadExpression + "->>'" + columnExpression + "','hex')" + ); + case ARRAY: + final BasicPluralType pluralType = (BasicPluralType) column.getJdbcMapping(); + switch ( pluralType.getElementType().getJdbcType().getDefaultSqlTypeCode() ) { + case BOOLEAN: + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case FLOAT: + case DOUBLE: + // For types that are natively supported in jsonb we can use jsonb_array_elements, + // but note that we can't use that for string types, + // because casting a jsonb[] to text[] will not omit the quotes of the jsonb text values + return template.replace( + placeholder, + "cast(array(select jsonb_array_elements(" + aggregateParentReadExpression + "->'" + columnExpression + "')) as " + column.getColumnDefinition() + ')' + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + // We encode binary data as hex, so we have to decode here + return template.replace( + placeholder, + "array(select decode(jsonb_array_elements_text(" + aggregateParentReadExpression + "->'" + columnExpression + "'),'hex'))" + ); + default: + return template.replace( + placeholder, + "cast(array(select jsonb_array_elements_text(" + aggregateParentReadExpression + "->'" + columnExpression + "')) as " + column.getColumnDefinition() + ')' + ); + } + default: + return template.replace( + placeholder, + "cast(" + aggregateParentReadExpression + "->>'" + columnExpression + "' as " + column.getColumnDefinition() + ')' + ); + } + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) { + final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + // We encode binary data as hex + return "to_jsonb(encode(" + customWriteExpression + ",'hex'))"; + case ARRAY: + final BasicPluralType pluralType = (BasicPluralType) jdbcMapping; + switch ( pluralType.getElementType().getJdbcType().getDefaultSqlTypeCode() ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + // We encode binary data as hex + return "to_jsonb(array(select encode(unnest(" + customWriteExpression + "),'hex')))"; + default: + return "to_jsonb(" + customWriteExpression + ")"; + } + default: + return "to_jsonb(" + customWriteExpression + ")"; + } + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + int aggregateColumnTypeCode, + Column column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + switch ( aggregateSqlTypeCode ) { + case JSON: + return true; + } + return false; + } + + @Override + public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( + SelectableMapping aggregateColumn, + SelectableMapping[] columnsToUpdate, + TypeConfiguration typeConfiguration) { + final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + switch ( aggregateSqlTypeCode ) { + case JSON: + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate ); + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); + } + + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns) { + return new RootJsonWriteExpression( aggregateColumn, columns ); + } + + interface JsonWriteExpression { + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + + protected void initializeSubExpressions(SelectableMapping[] columns) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + for ( int i = 1; i < parts.length - 1; i++ ) { + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression() + ); + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() ) + ) + ); + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( "||jsonb_build_object" ); + char separator = '('; + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + final String subPath = path + "->'" + column + "'"; + sb.append( separator ); + if ( value instanceof AggregateJsonWriteExpression ) { + sb.append( '\'' ); + sb.append( column ); + sb.append( "',coalesce(" ); + sb.append( subPath ); + sb.append( ",'{}')" ); + value.append( sb, subPath, translator, expression ); + } + else { + value.append( sb, subPath, translator, expression ); + } + separator = ','; + } + sb.append( ')' ); + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final boolean nullable; + private final String path; + + RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + this.nullable = aggregateColumn.isNullable(); + this.path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( columns ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + if ( nullable ) { + sqlAppender.append( "coalesce(" ); + sqlAppender.append( basePath ); + sqlAppender.append( ",'{}')" ); + } + else { + sqlAppender.append( basePath ); + } + append( sqlAppender, basePath, translator, aggregateColumnWriteExpression ); + } + } + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( '\'' ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "'," ); + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/H2AggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/H2AggregateSupport.java index 001b01e0c12f..134113c66e45 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/H2AggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/H2AggregateSupport.java @@ -38,7 +38,7 @@ public class H2AggregateSupport extends AggregateSupportImpl { public static @Nullable AggregateSupport valueOf(Dialect dialect) { return dialect.getVersion().isSameOrAfter( 2, 2, 220 ) ? H2AggregateSupport.INSTANCE - : null; + : AggregateSupportImpl.INSTANCE; } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/PostgreSQLUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/PostgreSQLUnnestFunction.java index 575e19aac72a..0b11cda338a0 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/PostgreSQLUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/PostgreSQLUnnestFunction.java @@ -45,13 +45,13 @@ protected void renderJsonTable( sqlAppender.append( ',' ); } if ( CollectionPart.Nature.INDEX.getName().equals( selectableMapping.getSelectableName() ) ) { - sqlAppender.appendSql( "t.ordinality" ); + sqlAppender.appendSql( "t.i" ); } else { sqlAppender.append( aggregateSupport.aggregateComponentCustomReadExpression( "", "", - "t.value", + "t.v", selectableMapping.getSelectableName(), SqlTypes.JSON, selectableMapping @@ -64,8 +64,10 @@ protected void renderJsonTable( array.accept( walker ); sqlAppender.appendSql( ')' ); if ( tupleType.findSubPart( CollectionPart.Nature.INDEX.getName(), null ) != null ) { - sqlAppender.appendSql( " with ordinality" ); + sqlAppender.appendSql( " with ordinality t(v,i))" ); + } + else { + sqlAppender.appendSql( " t(v))" ); } - sqlAppender.appendSql( " t)" ); } } From 5478209f24ec27bf186b6c98d2f2c496697f78ae Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Wed, 6 Nov 2024 17:29:29 +0100 Subject: [PATCH 4/7] HHH-18796 Add JSON aggregate support for DB2 --- .../community/dialect/DB2LegacyDialect.java | 8 +- .../org/hibernate/dialect/DB2Dialect.java | 8 +- .../aggregate/DB2AggregateSupport.java | 302 +++++++++++++++++- .../function/CommonFunctionFactory.java | 11 +- .../function/CteGenerateSeriesFunction.java | 6 +- .../function/array/DB2UnnestFunction.java | 156 +++++++++ .../function/json/DB2JsonTableFunction.java | 189 ++++++++--- .../function/json/JsonTableFunction.java | 2 +- .../org/hibernate/mapping/BasicValue.java | 2 +- .../descriptor/java/JdbcDateJavaType.java | 2 +- .../descriptor/java/JdbcTimeJavaType.java | 2 +- .../java/JdbcTimestampJavaType.java | 2 +- .../java/OffsetDateTimeJavaType.java | 27 +- 13 files changed, 646 insertions(+), 71 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java index 4acbcc1e192c..e28b49c3f3bd 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java @@ -439,7 +439,7 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio functionFactory.jsonArray_db2(); functionFactory.jsonArrayAgg_db2(); functionFactory.jsonObjectAgg_db2(); - functionFactory.jsonTable_db2(); + functionFactory.jsonTable_db2( getMaximumSeriesSize() ); } } @@ -459,7 +459,7 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio functionFactory.xmlagg(); functionFactory.xmltable_db2(); - functionFactory.unnest_emulated(); + functionFactory.unnest_db2( getMaximumSeriesSize() ); if ( supportsRecursiveCTE() ) { functionFactory.generateSeries_recursive( getMaximumSeriesSize(), false, true ); } @@ -1007,7 +1007,9 @@ public ValueExtractor getExtractor(JavaType javaType) { @Override public AggregateSupport getAggregateSupport() { - return DB2AggregateSupport.INSTANCE; + return getDB2Version().isSameOrAfter( 11 ) + ? DB2AggregateSupport.JSON_INSTANCE + : DB2AggregateSupport.INSTANCE; } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java index f1202b3a183c..c6b2dfde742f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java @@ -410,7 +410,7 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio functionFactory.jsonArray_db2(); functionFactory.jsonArrayAgg_db2(); functionFactory.jsonObjectAgg_db2(); - functionFactory.jsonTable_db2(); + functionFactory.jsonTable_db2( getMaximumSeriesSize() ); } functionFactory.xmlelement(); @@ -429,7 +429,7 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio functionFactory.xmlagg(); functionFactory.xmltable_db2(); - functionFactory.unnest_emulated(); + functionFactory.unnest_db2( getMaximumSeriesSize() ); functionFactory.generateSeries_recursive( getMaximumSeriesSize(), false, true ); } @@ -1066,7 +1066,9 @@ public ValueExtractor getExtractor(JavaType javaType) { @Override public AggregateSupport getAggregateSupport() { - return DB2AggregateSupport.INSTANCE; + return getDB2Version().isSameOrAfter( 11 ) + ? DB2AggregateSupport.JSON_INSTANCE + : DB2AggregateSupport.INSTANCE; } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/DB2AggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/DB2AggregateSupport.java index 8a1a5567fc31..bfbb9c7c4820 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/DB2AggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/DB2AggregateSupport.java @@ -23,6 +23,7 @@ import org.hibernate.mapping.AggregateColumn; import org.hibernate.mapping.Column; import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.SelectableMapping; import org.hibernate.metamodel.mapping.SelectablePath; import org.hibernate.metamodel.mapping.SqlExpressible; @@ -31,16 +32,36 @@ import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; import org.hibernate.type.descriptor.sql.DdlType; import org.hibernate.type.spi.TypeConfiguration; +import static org.hibernate.type.SqlTypes.ARRAY; +import static org.hibernate.type.SqlTypes.BINARY; +import static org.hibernate.type.SqlTypes.BLOB; import static org.hibernate.type.SqlTypes.BOOLEAN; +import static org.hibernate.type.SqlTypes.JSON; +import static org.hibernate.type.SqlTypes.JSON_ARRAY; +import static org.hibernate.type.SqlTypes.LONG32VARBINARY; import static org.hibernate.type.SqlTypes.SMALLINT; import static org.hibernate.type.SqlTypes.STRUCT; +import static org.hibernate.type.SqlTypes.TIME; +import static org.hibernate.type.SqlTypes.TIMESTAMP; +import static org.hibernate.type.SqlTypes.TIMESTAMP_UTC; +import static org.hibernate.type.SqlTypes.TIMESTAMP_WITH_TIMEZONE; +import static org.hibernate.type.SqlTypes.VARBINARY; public class DB2AggregateSupport extends AggregateSupportImpl { - public static final AggregateSupport INSTANCE = new DB2AggregateSupport(); + public static final AggregateSupport INSTANCE = new DB2AggregateSupport( false ); + public static final AggregateSupport JSON_INSTANCE = new DB2AggregateSupport( true ); + + private final boolean jsonSupport; + + public DB2AggregateSupport(boolean jsonSupport) { + this.jsonSupport = jsonSupport; + } @Override public String aggregateComponentCustomReadExpression( @@ -51,12 +72,83 @@ public String aggregateComponentCustomReadExpression( int aggregateColumnTypeCode, SqlTypedMapping column) { switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + if ( !jsonSupport ) { + break; + } + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case BOOLEAN: + if ( SqlTypes.isNumericType( column.getJdbcMapping().getJdbcType().getDdlTypeCode() ) ) { + return template.replace( + placeholder, + "decode(json_value(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),'true',1,'false',0)" + ); + } + else { + return template.replace( + placeholder, + "decode(json_value(" + aggregateParentReadExpression + ",'$." + columnExpression + "'),'true',true,'false',false)" + ); + } + case TIMESTAMP_WITH_TIMEZONE: + case TIMESTAMP_UTC: + return template.replace( + placeholder, + "cast(trim(trailing 'Z' from json_value(" + aggregateParentReadExpression + ",'$." + columnExpression + "' returning varchar(35))) as " + column.getColumnDefinition() + ")" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex, so we have to decode here + return template.replace( + placeholder, + "hextoraw(json_value(" + aggregateParentReadExpression + ",'$." + columnExpression + "'))" + ); + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + "json_query(" + aggregateParentReadExpression + ",'$." + columnExpression + "')" + ); + default: + return template.replace( + placeholder, + "json_value(" + aggregateParentReadExpression + ",'$." + columnExpression + "' returning " + column.getColumnDefinition() + ")" + ); + } case STRUCT: return template.replace( placeholder, aggregateParentReadExpression + ".." + columnExpression ); } throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); } + private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) { + final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex + return "hex(" + customWriteExpression + ")"; + case ARRAY: + case JSON_ARRAY: + return "(" + customWriteExpression + ") format json"; +// case BOOLEAN: +// return "(" + customWriteExpression + ")=true"; + case TIME: + return "varchar_format(timestamp('1970-01-01'," + customWriteExpression + "),'HH24:MI:SS')"; + case TIMESTAMP: + return "replace(varchar_format(" + customWriteExpression + ",'YYYY-MM-DD HH24:MI:SS.FF9'),' ','T')"; + case TIMESTAMP_UTC: + return "replace(varchar_format(" + customWriteExpression + ",'YYYY-MM-DD HH24:MI:SS.FF9'),' ','T')||'Z'"; + default: + return customWriteExpression; + } + } + @Override public String aggregateComponentAssignmentExpression( String aggregateParentAssignmentExpression, @@ -64,6 +156,13 @@ public String aggregateComponentAssignmentExpression( int aggregateColumnTypeCode, Column column) { switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + if ( jsonSupport ) { + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + break; case STRUCT: return aggregateParentAssignmentExpression + ".." + columnExpression; } @@ -74,7 +173,16 @@ public String aggregateComponentAssignmentExpression( public String aggregateCustomWriteExpression( AggregateColumn aggregateColumn, List aggregatedColumns) { - switch ( aggregateColumn.getTypeCode() ) { + // We need to know what array this is STRUCT_ARRAY/JSON_ARRAY/XML_ARRAY, + // which we can easily get from the type code of the aggregate column + final int sqlTypeCode = aggregateColumn.getType().getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode == SqlTypes.ARRAY ? aggregateColumn.getTypeCode() : sqlTypeCode ) { + case JSON: + case JSON_ARRAY: + if ( jsonSupport ) { + return null; + } + break; case STRUCT: final StringBuilder sb = new StringBuilder(); appendStructCustomWriteExpression( aggregateColumn, aggregatedColumns, sb ); @@ -107,16 +215,21 @@ private static void appendStructCustomWriteExpression( @Override public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) { - if ( aggregateColumnSqlTypeCode == STRUCT && columnSqlTypeCode == BOOLEAN ) { + if ( aggregateColumnSqlTypeCode == STRUCT ) { // DB2 doesn't support booleans in structs - return SMALLINT; + return columnSqlTypeCode == BOOLEAN ? SMALLINT : columnSqlTypeCode; + } + else if ( aggregateColumnSqlTypeCode == JSON ) { + return columnSqlTypeCode == ARRAY ? JSON_ARRAY : columnSqlTypeCode; + } + else { + return columnSqlTypeCode; } - return columnSqlTypeCode; } @Override public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { - return aggregateSqlTypeCode == STRUCT; + return aggregateSqlTypeCode == STRUCT || aggregateSqlTypeCode == JSON; } @Override @@ -126,12 +239,23 @@ public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( TypeConfiguration typeConfiguration) { final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); switch ( aggregateSqlTypeCode ) { + case JSON: + if ( jsonSupport ) { + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate ); + } + break; case STRUCT: return structAggregateColumnWriter( aggregateColumn, columnsToUpdate, typeConfiguration ); } throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); } + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns) { + return new RootJsonWriteExpression( aggregateColumn, columns ); + } + private WriteExpressionRenderer structAggregateColumnWriter( SelectableMapping aggregateColumn, SelectableMapping[] columns, @@ -473,4 +597,170 @@ private static boolean needsVarcharForBitDataCast(String columnType) { || columTypeLC.startsWith( "char" ) && columTypeLC.endsWith( " bit data" ); } + interface JsonWriteExpression { + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + + protected void initializeSubExpressions(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + for ( int i = 1; i < parts.length - 1; i++ ) { + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression() + ); + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() ) + ) + ); + } + passThroughUnsetSubExpressions( aggregateColumn ); + } + + protected void passThroughUnsetSubExpressions(SelectableMapping aggregateColumn) { + final AggregateJdbcType aggregateJdbcType = (AggregateJdbcType) aggregateColumn.getJdbcMapping().getJdbcType(); + final EmbeddableMappingType embeddableMappingType = aggregateJdbcType.getEmbeddableMappingType(); + final int jdbcValueCount = embeddableMappingType.getJdbcValueCount(); + for ( int i = 0; i < jdbcValueCount; i++ ) { + final SelectableMapping selectableMapping = embeddableMappingType.getJdbcValueSelectable( i ); + + final JsonWriteExpression jsonWriteExpression = subExpressions.get( selectableMapping.getSelectableName() ); + if ( jsonWriteExpression == null ) { + subExpressions.put( + selectableMapping.getSelectableName(), + new PassThroughExpression( selectableMapping ) + ); + } + else if ( jsonWriteExpression instanceof AggregateJsonWriteExpression writeExpression ) { + writeExpression.passThroughUnsetSubExpressions( selectableMapping ); + } + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( "json_object" ); + char separator = '('; + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + final String subPath = "json_query(" + path + ",'$." + column + "') format json"; + sb.append( separator ); + if ( value instanceof AggregateJsonWriteExpression ) { + sb.append( '\'' ); + sb.append( column ); + sb.append( "' value coalesce(" ); + value.append( sb, subPath, translator, expression ); + sb.append( ",json_object())" ); + } + else { + value.append( sb, subPath, translator, expression ); + } + separator = ','; + } + sb.append( ')' ); + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final String path; + + RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + this.path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( aggregateColumn, columns ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + append( sqlAppender, basePath, translator, aggregateColumnWriteExpression ); + } + } + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( '\'' ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "' value " ); + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + + private static class PassThroughExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + + PassThroughExpression(SelectableMapping selectableMapping) { + this.selectableMapping = selectableMapping; + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( '\'' ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "' value " ); + sb.append( path ); + } + } + } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java index 7e2b8772cf86..23a60ad3d129 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java @@ -4225,6 +4225,13 @@ public void unnest_hana() { functionRegistry.register( "unnest", new HANAUnnestFunction() ); } + /** + * DB2 unnest() function + */ + public void unnest_db2(int maximumArraySize) { + functionRegistry.register( "unnest", new DB2UnnestFunction( maximumArraySize ) ); + } + /** * Standard generate_series() function */ @@ -4305,8 +4312,8 @@ public void jsonTable_mysql() { /** * DB2 json_table() function */ - public void jsonTable_db2() { - functionRegistry.register( "json_table", new DB2JsonTableFunction( typeConfiguration ) ); + public void jsonTable_db2(int maximumSeriesSize) { + functionRegistry.register( "json_table", new DB2JsonTableFunction( maximumSeriesSize, typeConfiguration ) ); } /** diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CteGenerateSeriesFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CteGenerateSeriesFunction.java index 2b1bb5435341..2f897c6db29a 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CteGenerateSeriesFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CteGenerateSeriesFunction.java @@ -126,7 +126,7 @@ public TableGroup convertToSqlAst(NavigablePath navigablePath, String identifier }; } - protected static class CteGenerateSeriesQueryTransformer extends NumberSeriesQueryTransformer { + public static class CteGenerateSeriesQueryTransformer extends NumberSeriesQueryTransformer { public static final String NAME = "max_series"; protected final int maxSeriesSize; @@ -146,6 +146,10 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo } protected CteStatement createSeriesCte(SqmToSqlAstConverter converter) { + return createSeriesCte( maxSeriesSize, converter ); + } + + public static CteStatement createSeriesCte(int maxSeriesSize, SqmToSqlAstConverter converter) { final BasicType longType = converter.getCreationContext().getTypeConfiguration() .getBasicTypeForJavaType( Long.class ); final Expression one = new UnparsedNumericLiteral<>( "1", NumericTypeCategory.LONG, longType ); diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java new file mode 100644 index 000000000000..9ccc8817192c --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java @@ -0,0 +1,156 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.function.array; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.function.CteGenerateSeriesFunction; +import org.hibernate.dialect.function.json.DB2JsonTableFunction; +import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.metamodel.mapping.BasicValuedModelPart; +import org.hibernate.metamodel.mapping.CollectionPart; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.ModelPart; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.query.derived.AnonymousTupleTableGroupProducer; +import org.hibernate.query.spi.QueryEngine; +import org.hibernate.query.sqm.function.SelfRenderingSqmSetReturningFunction; +import org.hibernate.query.sqm.sql.SqmToSqlAstConverter; +import org.hibernate.query.sqm.tree.SqmTypedNode; +import org.hibernate.spi.NavigablePath; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.sql.ast.tree.from.TableGroup; +import org.hibernate.type.BasicPluralType; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; + +import java.util.List; + +/** + * DB2 unnest function. + * Unnesting JSON arrays requires more effort since DB2 doesn't support arrays in {@code json_table()}. + * See {@link org.hibernate.dialect.function.json.DB2JsonTableFunction} for more details. + * + * @see org.hibernate.dialect.function.json.DB2JsonTableFunction + */ +public class DB2UnnestFunction extends UnnestFunction { + + private final int maximumArraySize; + + public DB2UnnestFunction(int maximumArraySize) { + super( "v", "i" ); + this.maximumArraySize = maximumArraySize; + } + + @Override + protected SelfRenderingSqmSetReturningFunction generateSqmSetReturningFunctionExpression(List> arguments, QueryEngine queryEngine) { + return new SelfRenderingSqmSetReturningFunction<>( + this, + this, + arguments, + getArgumentsValidator(), + getSetReturningTypeResolver(), + queryEngine.getCriteriaBuilder(), + getName() + ) { + @Override + public TableGroup convertToSqlAst(NavigablePath navigablePath, String identifierVariable, boolean lateral, boolean canUseInnerJoins, boolean withOrdinality, SqmToSqlAstConverter walker) { + walker.registerQueryTransformer( new DB2JsonTableFunction.SeriesQueryTransformer( maximumArraySize ) ); + return super.convertToSqlAst( navigablePath, identifierVariable, lateral, canUseInnerJoins, withOrdinality, walker ); + } + }; + } + + @Override + protected void renderJsonTable( + SqlAppender sqlAppender, + Expression array, + BasicPluralType pluralType, + @Nullable SqlTypedMapping sqlTypedMapping, + AnonymousTupleTableGroupProducer tupleType, + String tableIdentifierVariable, + SqlAstTranslator walker) { + sqlAppender.appendSql( "lateral(select " ); + final ModelPart elementPart = tupleType.findSubPart( CollectionPart.Nature.ELEMENT.getName(), null ); + if ( elementPart == null ) { + sqlAppender.append( "t.*" ); + } + else { + final BasicValuedModelPart elementMapping = elementPart.asBasicValuedModelPart(); + final boolean isBoolean = elementMapping.getSingleJdbcMapping().getJdbcType().isBoolean(); + if ( isBoolean ) { + sqlAppender.appendSql( "decode(" ); + } + sqlAppender.appendSql( "json_value('{\"a\":'||" ); + array.accept( walker ); + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']'" ); + if ( isBoolean ) { + sqlAppender.appendSql( ')' ); + final JdbcMapping type = elementMapping.getSingleJdbcMapping(); + //noinspection unchecked + final JdbcLiteralFormatter jdbcLiteralFormatter = type.getJdbcLiteralFormatter(); + final SessionFactoryImplementor sessionFactory = walker.getSessionFactory(); + final Dialect dialect = sessionFactory.getJdbcServices().getDialect(); + final WrapperOptions wrapperOptions = sessionFactory.getWrapperOptions(); + final Object trueValue = type.convertToRelationalValue( true ); + final Object falseValue = type.convertToRelationalValue( false ); + sqlAppender.append( ",'true'," ); + jdbcLiteralFormatter.appendJdbcLiteral( sqlAppender, trueValue, dialect, wrapperOptions ); + sqlAppender.append( ",'false'," ); + jdbcLiteralFormatter.appendJdbcLiteral( sqlAppender, falseValue, dialect, wrapperOptions ); + sqlAppender.append( ") " ); + } + else { + sqlAppender.appendSql( " returning " ); + sqlAppender.append( getDdlType( elementMapping, walker ) ); + sqlAppender.append( ") " ); + } + + sqlAppender.append( elementMapping.getSelectionExpression() ); + } + final ModelPart indexPart = tupleType.findSubPart( CollectionPart.Nature.INDEX.getName(), null ); + if ( indexPart != null ) { + sqlAppender.appendSql( ",i.i " ); + sqlAppender.append( indexPart.asBasicValuedModelPart().getSelectionExpression() ); + } + + sqlAppender.appendSql( " from " ); + sqlAppender.appendSql( CteGenerateSeriesFunction.CteGenerateSeriesQueryTransformer.NAME ); + sqlAppender.appendSql( " i" ); + + if ( elementPart == null ) { + sqlAppender.appendSql( " join json_table(json_query('{\"a\":'||" ); + array.accept( walker ); + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']'),'strict $' columns(" ); + tupleType.forEachSelectable( 0, (selectionIndex, selectableMapping) -> { + if ( !CollectionPart.Nature.INDEX.getName().equals( selectableMapping.getSelectableName() ) ) { + if ( selectionIndex == 0 ) { + sqlAppender.append( ' ' ); + } + else { + sqlAppender.append( ',' ); + } + sqlAppender.append( selectableMapping.getSelectionExpression() ); + sqlAppender.append( ' ' ); + sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.appendSql( " path '$." ); + sqlAppender.append( selectableMapping.getSelectableName() ); + sqlAppender.appendSql( '\'' ); + } + } ); + sqlAppender.appendSql( ") error on error) t on json_exists('{\"a\":'||" ); + array.accept( walker ); + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']'))" ); + } + else { + sqlAppender.appendSql( " where json_exists('{\"a\":'||" ); + array.accept( walker ); + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']'))" ); + } + + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/DB2JsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/DB2JsonTableFunction.java index 0135028cdc20..b952aaaf0eff 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/DB2JsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/DB2JsonTableFunction.java @@ -6,9 +6,18 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.QueryException; +import org.hibernate.dialect.function.CteGenerateSeriesFunction; import org.hibernate.query.derived.AnonymousTupleTableGroupProducer; +import org.hibernate.query.spi.QueryEngine; +import org.hibernate.query.sqm.function.SelfRenderingSqmSetReturningFunction; +import org.hibernate.query.sqm.sql.SqmToSqlAstConverter; +import org.hibernate.query.sqm.tree.SqmTypedNode; +import org.hibernate.query.sqm.tree.expression.SqmExpression; +import org.hibernate.query.sqm.tree.expression.SqmJsonTableFunction; +import org.hibernate.spi.NavigablePath; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.cte.CteContainer; import org.hibernate.sql.ast.tree.expression.CastTarget; import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.expression.JsonExistsErrorBehavior; @@ -23,23 +32,60 @@ import org.hibernate.sql.ast.tree.expression.JsonTableValueColumnDefinition; import org.hibernate.sql.ast.tree.expression.JsonValueEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonValueErrorBehavior; +import org.hibernate.sql.ast.tree.expression.Literal; +import org.hibernate.sql.ast.tree.expression.QueryTransformer; +import org.hibernate.sql.ast.tree.from.FunctionTableGroup; +import org.hibernate.sql.ast.tree.from.TableGroup; +import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.type.SqlTypes; import org.hibernate.type.spi.TypeConfiguration; +import java.util.List; + /** * DB2 json_table function. * This implementation/emulation goes to great lengths to ensure Hibernate ORM can provide the same {@code json_table()} * experience that other dialects provide also on DB2. * The most notable limitation of the DB2 function is that it doesn't support JSON arrays, - * so this emulation uses a series CTE called {@code gen_} with 10_000 rows to join + * so this emulation uses a series CTE called {@code max_series} with 10_000 rows to join * each array element queried with {@code json_query()} at the respective index via {@code json_table()} separately. * Another notable limitation of the DB2 function is that it doesn't support nested column paths, * which requires emulation by joining each nesting with a separate {@code json_table()}. */ public class DB2JsonTableFunction extends JsonTableFunction { - public DB2JsonTableFunction(TypeConfiguration typeConfiguration) { + private final int maximumSeriesSize; + + public DB2JsonTableFunction(int maximumSeriesSize, TypeConfiguration typeConfiguration) { super( typeConfiguration ); + this.maximumSeriesSize = maximumSeriesSize; + } + + @Override + protected SelfRenderingSqmSetReturningFunction generateSqmSetReturningFunctionExpression(List> sqmArguments, QueryEngine queryEngine) { + //noinspection unchecked + return new SqmJsonTableFunction<>( + this, + this, + getArgumentsValidator(), + getSetReturningTypeResolver(), + queryEngine.getCriteriaBuilder(), + (SqmExpression) sqmArguments.get( 0 ), + sqmArguments.size() > 1 ? (SqmExpression) sqmArguments.get( 1 ) : null + ) { + @Override + public TableGroup convertToSqlAst(NavigablePath navigablePath, String identifierVariable, boolean lateral, boolean canUseInnerJoins, boolean withOrdinality, SqmToSqlAstConverter walker) { + final FunctionTableGroup tableGroup = (FunctionTableGroup) super.convertToSqlAst( navigablePath, identifierVariable, lateral, canUseInnerJoins, withOrdinality, walker ); + final JsonTableArguments arguments = JsonTableArguments.extract( tableGroup.getPrimaryTableReference().getFunctionExpression().getArguments() ); + final Expression jsonPath = arguments.jsonPath(); + final boolean isArray = !(jsonPath instanceof Literal literal) + || isArrayAccess( (String) literal.getLiteralValue() ); + if ( isArray || hasNestedArray( arguments.columnsClause() ) ) { + walker.registerQueryTransformer( new SeriesQueryTransformer( maximumSeriesSize ) ); + } + return tableGroup; + } + }; } @Override @@ -55,22 +101,13 @@ protected void renderJsonTable( final Expression jsonDocument = arguments.jsonDocument(); final Expression jsonPath = arguments.jsonPath(); final boolean isArray = isArrayAccess( jsonPath, walker ); - sqlAppender.appendSql( "lateral(" ); - - if ( isArray || hasNestedArray( arguments.columnsClause() ) ) { - // DB2 doesn't support arrays in json_table(), so a series table to join individual elements is needed - sqlAppender.appendSql( "with gen_(v) as(select 0 from (values (0)) union all " ); - sqlAppender.appendSql( "select i.v+1 from gen_ i where i.v<10000)" ); - } - - sqlAppender.appendSql( "select" ); + sqlAppender.appendSql( "lateral(select" ); renderColumnSelects( sqlAppender, arguments.columnsClause(), 0, isArray ); + sqlAppender.appendSql( " from " ); if ( isArray ) { - sqlAppender.appendSql( " from gen_ i join " ); - } - else { - sqlAppender.appendSql( " from " ); + sqlAppender.appendSql( CteGenerateSeriesFunction.CteGenerateSeriesQueryTransformer.NAME ); + sqlAppender.appendSql( " i join " ); } sqlAppender.appendSql( "json_table(" ); // DB2 json functions only work when passing object documents, @@ -87,8 +124,33 @@ protected void renderJsonTable( sqlAppender.appendSql( " error on error) t0" ); if ( isArray ) { sqlAppender.appendSql( " on json_exists('{\"a\":'||" ); - appendJsonDocument( sqlAppender, jsonPath, jsonDocument, arguments.passingClause(), isArray, walker ); - sqlAppender.appendSql( "||'}','$.a['||i.v||']')" ); + if ( jsonPath != null ) { + final String jsonPathString; + if ( arguments.passingClause() != null ) { + jsonPathString = JsonPathHelper.inlinedJsonPathIncludingPassingClause( jsonPath, arguments.passingClause(), walker ); + } + else { + jsonPathString = walker.getLiteralValue( jsonPath ); + } + if ( jsonPathString.endsWith( "[*]" ) ) { + jsonDocument.accept( walker ); + sqlAppender.appendSql( "||'}'," ); + final String adaptedJsonPath = jsonPathString.substring( 0, jsonPathString.length() - 3 ); + sqlAppender.appendSingleQuoteEscapedString( adaptedJsonPath.replace( "$", "$.a" ) ); + sqlAppender.appendSql( "||'['||(i.i-1)||']')" ); + } + else { + sqlAppender.appendSql( "json_query('{\"a\":'||" ); + jsonDocument.accept( walker ); + sqlAppender.appendSql( "||'}'," ); + sqlAppender.appendSingleQuoteEscapedString( jsonPathString.replace( "$", "$.a" ) ); + sqlAppender.appendSql( " with wrapper)||'}','$.a['||(i.i-1)||']')" ); + } + } + else { + jsonDocument.accept( walker ); + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']')" ); + } } renderNestedColumnJoins( sqlAppender, arguments.columnsClause(), 0, walker ); sqlAppender.appendSql( ')' ); @@ -97,27 +159,57 @@ protected void renderJsonTable( private static void appendJsonDocument(SqlAppender sqlAppender, Expression jsonPath, Expression jsonDocument, JsonPathPassingClause passingClause, boolean isArray, SqlAstTranslator walker) { if ( jsonPath != null ) { sqlAppender.appendSql( "json_query(" ); - jsonDocument.accept( walker ); - sqlAppender.appendSql( ',' ); - if ( passingClause != null ) { - JsonPathHelper.appendInlinedJsonPathIncludingPassingClause( - sqlAppender, - "", - jsonPath, - passingClause, - walker - ); + if ( isArray ) { + final String jsonPathString; + if ( passingClause != null ) { + jsonPathString = JsonPathHelper.inlinedJsonPathIncludingPassingClause( jsonPath, passingClause, walker ); + } + else { + jsonPathString = walker.getLiteralValue( jsonPath ); + } + if ( jsonPathString.endsWith( "[*]" ) ) { + sqlAppender.appendSql( "'{\"a\":'||" ); + jsonDocument.accept( walker ); + sqlAppender.appendSql( "||'}'," ); + final String adaptedJsonPath = jsonPathString.substring( 0, jsonPathString.length() - 3 ); + sqlAppender.appendSingleQuoteEscapedString( adaptedJsonPath.replace( "$", "$.a" ) ); + sqlAppender.appendSql( "||'['||(i.i-1)||']'" ); + } + else { + sqlAppender.appendSql( "'{\"a\":'||" ); + sqlAppender.appendSql( "json_query('{\"a\":'||" ); + jsonDocument.accept( walker ); + sqlAppender.appendSql( "||'}'," ); + sqlAppender.appendSingleQuoteEscapedString( jsonPathString.replace( "$", "$.a" ) ); + sqlAppender.appendSql( " with wrapper)||'}','$.a['||(i.i-1)||']'" ); + } } else { - jsonPath.accept( walker ); - } - if ( isArray ) { - sqlAppender.appendSql( " with wrapper" ); + jsonDocument.accept( walker ); + sqlAppender.appendSql( ',' ); + if ( passingClause != null ) { + JsonPathHelper.appendInlinedJsonPathIncludingPassingClause( + sqlAppender, + "", + jsonPath, + passingClause, + walker + ); + } + else { + jsonPath.accept( walker ); + } } sqlAppender.appendSql( ')' ); } else { + if ( isArray ) { + sqlAppender.appendSql( "json_query('{\"a\":'||" ); + } jsonDocument.accept( walker ); + if ( isArray ) { + sqlAppender.appendSql( "||'}','$.a['||(i.i-1)||']')" ); + } } } @@ -161,32 +253,33 @@ private int renderNestedColumnJoins(SqlAppender sqlAppender, JsonTableColumnsCla sqlAppender.appendSql( " left join lateral (select" ); renderColumnSelects( sqlAppender, nestedColumnDefinition.columns(), nextClauseLevel, isArray ); - sqlAppender.appendSql( " from" ); + sqlAppender.appendSql( " from " ); if ( isArray ) { // When the JSON path indicates that the document is an array, - // join the `gen_` CTE to be able to use the respective array element in json_table(). + // join the `max_series` CTE to be able to use the respective array element in json_table(). // DB2 json functions only work when passing object documents, // which is why results are packed in shell object `{"a":...}` - sqlAppender.appendSql( " gen_ i join json_table('{\"a\":'||json_query('{\"a\":'||t" ); + sqlAppender.appendSql( CteGenerateSeriesFunction.CteGenerateSeriesQueryTransformer.NAME ); + sqlAppender.appendSql( " i join json_table('{\"a\":'||json_query('{\"a\":'||t" ); sqlAppender.appendSql( clauseLevel ); sqlAppender.appendSql( ".nested_" ); sqlAppender.appendSql( nextClauseLevel ); - sqlAppender.appendSql( "_||'}','$.a['||i.v||']')||'}','strict $'" ); + sqlAppender.appendSql( "_||'}','$.a['||(i.i-1)||']')||'}','strict $'" ); // Since the query results are packed in a shell object `{"a":...}`, // the JSON path for columns need to be prefixed with `$.a` renderColumns( sqlAppender, nestedColumnDefinition.columns(), nextClauseLevel, "$.a", walker ); sqlAppender.appendSql( " error on error) t" ); sqlAppender.appendSql( nextClauseLevel ); - // Emulation of arrays via `gen_` sequence requires a join condition to check if an array element exists + // Emulation of arrays via `max_series` sequence requires a join condition to check if an array element exists sqlAppender.appendSql( " on json_exists('{\"a\":'||t" ); sqlAppender.appendSql( clauseLevel ); sqlAppender.appendSql( ".nested_" ); sqlAppender.appendSql( nextClauseLevel ); - sqlAppender.appendSql( "_||'}','$.a['||i.v||']')" ); + sqlAppender.appendSql( "_||'}','$.a['||(i.i-1)||']')" ); } else { - sqlAppender.appendSql( " json_table(t" ); + sqlAppender.appendSql( "json_table(t" ); sqlAppender.appendSql( clauseLevel ); sqlAppender.appendSql( ".nested_" ); sqlAppender.appendSql( nextClauseLevel ); @@ -237,8 +330,7 @@ else if ( columnDefinition instanceof JsonTableOrdinalityColumnDefinition ordina // DB2 doesn't support the for ordinality syntax in json_table() since it has no support for array either if ( isArray ) { // If the document is an array, a series table with alias `i` is joined to emulate array support. - // Since the value of the series is 0 based, we add 1 to obtain the ordinality value - sqlAppender.appendSql( "i.v+1 " ); + sqlAppender.appendSql( "i.i " ); } else { // The ordinality for non-array documents always is trivially 1 @@ -435,4 +527,21 @@ private void renderJsonExistsColumnDefinition(SqlAppender sqlAppender, JsonTable sqlAppender.appendSql( definition.name() ); sqlAppender.appendSql( " clob format json path '$'" ); } + + public static class SeriesQueryTransformer implements QueryTransformer { + + private final int maxSeriesSize; + + public SeriesQueryTransformer(int maxSeriesSize) { + this.maxSeriesSize = maxSeriesSize; + } + + @Override + public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmToSqlAstConverter converter) { + if ( cteContainer.getCteStatement( CteGenerateSeriesFunction.CteGenerateSeriesQueryTransformer.NAME ) == null ) { + cteContainer.addCteStatement( CteGenerateSeriesFunction.CteGenerateSeriesQueryTransformer.createSeriesCte( maxSeriesSize, converter ) ); + } + return querySpec; + } + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableFunction.java index 4a2245069856..f252e24e616f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableFunction.java @@ -62,7 +62,7 @@ protected JsonTableFunction(SetReturningFunctionTypeResolver setReturningFunctio "json_table", new ArgumentTypesValidator( StandardArgumentsValidators.between( 1, 2 ), - FunctionParameterType.JSON, + FunctionParameterType.IMPLICIT_JSON, FunctionParameterType.STRING ), setReturningFunctionTypeResolver, diff --git a/hibernate-core/src/main/java/org/hibernate/mapping/BasicValue.java b/hibernate-core/src/main/java/org/hibernate/mapping/BasicValue.java index 6a06eee0cb15..d90eb61e7337 100644 --- a/hibernate-core/src/main/java/org/hibernate/mapping/BasicValue.java +++ b/hibernate-core/src/main/java/org/hibernate/mapping/BasicValue.java @@ -971,7 +971,7 @@ public int resolveJdbcTypeCode(int jdbcTypeCode) { return aggregateColumn == null ? jdbcTypeCode : getDialect().getAggregateSupport() - .aggregateComponentSqlTypeCode( aggregateColumn.getSqlTypeCode( getMetadata() ), jdbcTypeCode ); + .aggregateComponentSqlTypeCode( aggregateColumn.getType().getJdbcType().getDefaultSqlTypeCode(), jdbcTypeCode ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcDateJavaType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcDateJavaType.java index 0581cf98dc56..4859d2c16b26 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcDateJavaType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcDateJavaType.java @@ -238,7 +238,7 @@ public Date fromEncodedString(CharSequence charSequence, int start, int end) { return java.sql.Date.valueOf( accessor.query( LocalDate::from ) ); } catch ( DateTimeParseException pe) { - throw new HibernateException( "could not parse time string " + charSequence, pe ); + throw new HibernateException( "could not parse time string " + subSequence( charSequence, start, end ), pe ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimeJavaType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimeJavaType.java index 6b4e5b0a7282..279c39c9af02 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimeJavaType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimeJavaType.java @@ -234,7 +234,7 @@ public Date fromEncodedString(CharSequence charSequence, int start, int end) { return java.sql.Time.valueOf( accessor.query( LocalTime::from ) ); } catch ( DateTimeParseException pe) { - throw new HibernateException( "could not parse time string " + charSequence, pe ); + throw new HibernateException( "could not parse time string " + subSequence( charSequence, start, end ), pe ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimestampJavaType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimestampJavaType.java index 1e33d20a460a..edbecdeaf42e 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimestampJavaType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/JdbcTimestampJavaType.java @@ -236,7 +236,7 @@ public Date fromEncodedString(CharSequence charSequence, int start, int end) { return timestamp; } catch ( DateTimeParseException pe) { - throw new HibernateException( "could not parse timestamp string " + charSequence, pe ); + throw new HibernateException( "could not parse timestamp string " + subSequence( charSequence, start, end ), pe ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/OffsetDateTimeJavaType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/OffsetDateTimeJavaType.java index bfb5748d2ab1..bccfd17dea9e 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/OffsetDateTimeJavaType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/OffsetDateTimeJavaType.java @@ -13,15 +13,16 @@ import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; +import java.time.format.DateTimeParseException; import java.time.temporal.ChronoField; import java.time.temporal.TemporalAccessor; import java.util.Calendar; import java.util.Date; import java.util.GregorianCalendar; +import org.hibernate.HibernateException; import org.hibernate.dialect.Dialect; import org.hibernate.engine.spi.SharedSessionContractImplementor; -import org.hibernate.internal.util.CharSequenceHelper; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.jdbc.JdbcType; @@ -32,6 +33,7 @@ import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME; import static java.time.format.DateTimeFormatter.ISO_OFFSET_DATE_TIME; +import static org.hibernate.internal.util.CharSequenceHelper.subSequence; /** * Java type descriptor for the {@link OffsetDateTime} type. @@ -95,16 +97,19 @@ public OffsetDateTime fromString(CharSequence string) { } @Override - public OffsetDateTime fromEncodedString(CharSequence string, int start, int end) { - final TemporalAccessor temporalAccessor = PARSE_FORMATTER.parse( - CharSequenceHelper.subSequence( string, start, end ) - ); - if ( temporalAccessor.isSupported( ChronoField.OFFSET_SECONDS ) ) { - return OffsetDateTime.from( temporalAccessor ); - } - else { - // For databases that don't have timezone support, we encode timestamps at UTC, so allow parsing that as well - return LocalDateTime.from( temporalAccessor ).atOffset( ZoneOffset.UTC ); + public OffsetDateTime fromEncodedString(CharSequence charSequence, int start, int end) { + try { + final TemporalAccessor temporalAccessor = PARSE_FORMATTER.parse( subSequence( charSequence, start, end ) ); + if ( temporalAccessor.isSupported( ChronoField.OFFSET_SECONDS ) ) { + return OffsetDateTime.from( temporalAccessor ); + } + else { + // For databases that don't have timezone support, we encode timestamps at UTC, so allow parsing that as well + return LocalDateTime.from( temporalAccessor ).atOffset( ZoneOffset.UTC ); + } + } + catch ( DateTimeParseException pe) { + throw new HibernateException( "could not parse timestamp string " + subSequence( charSequence, start, end ), pe ); } } From 1bb356e5c742928fb653c25777a54bcbc63cfabf Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Mon, 11 Nov 2024 17:35:25 +0100 Subject: [PATCH 5/7] HHH-18797 Add JSON aggregate support for HANA --- .../community/dialect/HANALegacyDialect.java | 7 + .../org/hibernate/dialect/HANADialect.java | 7 + .../aggregate/HANAAggregateSupport.java | 507 ++++++++++++++++++ .../function/array/DB2UnnestFunction.java | 7 +- .../function/array/HANAUnnestFunction.java | 46 +- .../array/SQLServerUnnestFunction.java | 13 +- .../array/SybaseASEUnnestFunction.java | 5 +- .../function/array/UnnestFunction.java | 29 +- .../json/HANAJsonObjectAggFunction.java | 9 +- .../function/json/HANAJsonValueFunction.java | 28 +- .../function/xml/HANAXmlTableFunction.java | 9 + 11 files changed, 620 insertions(+), 47 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/aggregate/HANAAggregateSupport.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/HANALegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/HANALegacyDialect.java index 8ce765bf103d..2e3a859b237c 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/HANALegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/HANALegacyDialect.java @@ -52,6 +52,8 @@ import org.hibernate.dialect.NullOrdering; import org.hibernate.dialect.OracleDialect; import org.hibernate.dialect.RowLockStrategy; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.HANAAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.IntegralTimestampaddFunction; import org.hibernate.dialect.identity.HANAIdentityColumnSupport; @@ -534,6 +536,11 @@ protected SqlAstTranslator buildTranslator( }; } + @Override + public AggregateSupport getAggregateSupport() { + return HANAAggregateSupport.valueOf( this ); + } + /** * HANA has no extract() function, but we can emulate * it using the appropriate named functions instead of diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/HANADialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/HANADialect.java index 76022a70ce7a..bf9bd39fc843 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/HANADialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/HANADialect.java @@ -13,6 +13,8 @@ import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.naming.Identifier; import org.hibernate.boot.model.relational.SqlStringGenerationContext; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.HANAAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.IntegralTimestampaddFunction; import org.hibernate.dialect.identity.HANAIdentityColumnSupport; @@ -536,6 +538,11 @@ protected SqlAstTranslator buildTranslator( }; } + @Override + public AggregateSupport getAggregateSupport() { + return HANAAggregateSupport.valueOf( this ); + } + /** * HANA has no extract() function, but we can emulate * it using the appropriate named functions instead of diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/HANAAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/HANAAggregateSupport.java new file mode 100644 index 000000000000..527d687bf0a3 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/HANAAggregateSupport.java @@ -0,0 +1,507 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.aggregate; + +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.function.json.HANAJsonValueFunction; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.mapping.AggregateColumn; +import org.hibernate.mapping.Column; +import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SelectableMapping; +import org.hibernate.metamodel.mapping.SelectablePath; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.hibernate.type.SqlTypes.ARRAY; +import static org.hibernate.type.SqlTypes.BIGINT; +import static org.hibernate.type.SqlTypes.BINARY; +import static org.hibernate.type.SqlTypes.BLOB; +import static org.hibernate.type.SqlTypes.BOOLEAN; +import static org.hibernate.type.SqlTypes.DATE; +import static org.hibernate.type.SqlTypes.DECIMAL; +import static org.hibernate.type.SqlTypes.DOUBLE; +import static org.hibernate.type.SqlTypes.FLOAT; +import static org.hibernate.type.SqlTypes.INTEGER; +import static org.hibernate.type.SqlTypes.JSON; +import static org.hibernate.type.SqlTypes.JSON_ARRAY; +import static org.hibernate.type.SqlTypes.LONG32VARBINARY; +import static org.hibernate.type.SqlTypes.NUMERIC; +import static org.hibernate.type.SqlTypes.REAL; +import static org.hibernate.type.SqlTypes.SMALLINT; +import static org.hibernate.type.SqlTypes.TIME; +import static org.hibernate.type.SqlTypes.TIMESTAMP; +import static org.hibernate.type.SqlTypes.TIMESTAMP_UTC; +import static org.hibernate.type.SqlTypes.TINYINT; +import static org.hibernate.type.SqlTypes.UUID; +import static org.hibernate.type.SqlTypes.VARBINARY; + +public class HANAAggregateSupport extends AggregateSupportImpl { + + private static final AggregateSupport INSTANCE = new HANAAggregateSupport(); + + private static final String JSON_QUERY_START = "json_query("; + private static final String JSON_QUERY_JSON_END = "' error on error)"; + + private HANAAggregateSupport() { + } + + public static AggregateSupport valueOf(Dialect dialect) { + return dialect.getVersion().isSameOrAfter( 2, 0, 40 ) ? INSTANCE : AggregateSupportImpl.INSTANCE; + } + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + int aggregateColumnTypeCode, + SqlTypedMapping column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + final String parentPartExpression = determineParentPartExpression( aggregateParentReadExpression ); + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case BOOLEAN: + if ( SqlTypes.isNumericType( column.getJdbcMapping().getJdbcType().getDdlTypeCode() ) ) { + return template.replace( + placeholder, + "case json_value(" + parentPartExpression + columnExpression + "') when 'true' then 1 when 'false' then 0 end" + ); + } + else { + return template.replace( + placeholder, + "case json_value(" + parentPartExpression + columnExpression + "') when 'true' then true when 'false' then false end" + ); + } + case DATE: + case TIME: + case TIMESTAMP: + case TIMESTAMP_UTC: + return template.replace( + placeholder, + "cast(json_value(" + parentPartExpression + columnExpression + "') as " + column.getColumnDefinition() + ")" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex, so we have to decode here + return template.replace( + placeholder, + "hextobin(json_value(" + parentPartExpression + columnExpression + "' error on error))" + ); + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + "json_query(" + parentPartExpression + columnExpression + "' error on error)" + ); + case UUID: + if ( SqlTypes.isBinaryType( column.getJdbcMapping().getJdbcType().getDdlTypeCode() ) ) { + return template.replace( + placeholder, + "hextobin(json_value(" + parentPartExpression + columnExpression + "'))" + ); + } + // Fall-through intended + default: + return template.replace( + placeholder, + "json_value(" + parentPartExpression + columnExpression + "' returning " + HANAJsonValueFunction.jsonValueReturningType( + column ) + " error on error)" + ); + } + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private static String determineParentPartExpression(String aggregateParentReadExpression) { + final String parentPartExpression; + if ( aggregateParentReadExpression.startsWith( JSON_QUERY_START ) && aggregateParentReadExpression.endsWith( JSON_QUERY_JSON_END ) ) { + parentPartExpression = aggregateParentReadExpression.substring( JSON_QUERY_START.length(), aggregateParentReadExpression.length() - JSON_QUERY_JSON_END.length() ) + "."; + } + else { + parentPartExpression = aggregateParentReadExpression + ",'$."; + } + return parentPartExpression; + } + + private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) { + final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode ) { + case UUID: + if ( !SqlTypes.isBinaryType( jdbcMapping.getJdbcType().getDdlTypeCode() ) ) { + return customWriteExpression; + } + // Fall-through intended + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex + return "bintohex(" + customWriteExpression + ")"; + case TIMESTAMP: + return "to_varchar(" + customWriteExpression + ",'YYYY-MM-DD\"T\"HH24:MI:SS.FF9')"; + case TIMESTAMP_UTC: + return "to_varchar(" + customWriteExpression + ",'YYYY-MM-DD\"T\"HH24:MI:SS.FF9\"Z\"')"; + default: + return customWriteExpression; + } + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + int aggregateColumnTypeCode, + Column column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + @Override + public String aggregateCustomWriteExpression( + AggregateColumn aggregateColumn, + List aggregatedColumns) { + // We need to know what array this is STRUCT_ARRAY/JSON_ARRAY/XML_ARRAY, + // which we can easily get from the type code of the aggregate column + final int sqlTypeCode = aggregateColumn.getType().getJdbcType().getDefaultSqlTypeCode(); + switch ( sqlTypeCode == SqlTypes.ARRAY ? aggregateColumn.getTypeCode() : sqlTypeCode ) { + case JSON: + case JSON_ARRAY: + return null; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumn.getTypeCode() ); + } + + @Override + public int aggregateComponentSqlTypeCode(int aggregateColumnSqlTypeCode, int columnSqlTypeCode) { + if ( aggregateColumnSqlTypeCode == JSON ) { + return columnSqlTypeCode == ARRAY ? JSON_ARRAY : columnSqlTypeCode; + } + else { + return columnSqlTypeCode; + } + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + return aggregateSqlTypeCode == JSON; + } + + @Override + public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( + SelectableMapping aggregateColumn, + SelectableMapping[] columnsToUpdate, + TypeConfiguration typeConfiguration) { + final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + switch ( aggregateSqlTypeCode ) { + case JSON: + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate ); + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); + } + + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns) { + return new RootJsonWriteExpression( aggregateColumn, columns ); + } + + interface JsonWriteExpression { + boolean isAggregate(); + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String columnDefinition; + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + + private AggregateJsonWriteExpression(SelectableMapping selectableMapping, String columnDefinition) { + this.selectableMapping = selectableMapping; + this.columnDefinition = columnDefinition; + } + + @Override + public boolean isAggregate() { + return true; + } + + protected void initializeSubExpressions(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + for ( int i = 1; i < parts.length - 1; i++ ) { + final AggregateJdbcType aggregateJdbcType = (AggregateJdbcType) currentAggregate.selectableMapping.getJdbcMapping().getJdbcType(); + final EmbeddableMappingType embeddableMappingType = aggregateJdbcType.getEmbeddableMappingType(); + final int selectableIndex = embeddableMappingType.getSelectableIndex( parts[i].getSelectableName() ); + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression( embeddableMappingType.getSelectable( selectableIndex ), columnDefinition ) + ); + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() ) + ) + ); + } + passThroughUnsetSubExpressions( aggregateColumn ); + } + + protected void passThroughUnsetSubExpressions(SelectableMapping aggregateColumn) { + final AggregateJdbcType aggregateJdbcType = (AggregateJdbcType) aggregateColumn.getJdbcMapping().getJdbcType(); + final EmbeddableMappingType embeddableMappingType = aggregateJdbcType.getEmbeddableMappingType(); + final int jdbcValueCount = embeddableMappingType.getJdbcValueCount(); + for ( int i = 0; i < jdbcValueCount; i++ ) { + final SelectableMapping selectableMapping = embeddableMappingType.getJdbcValueSelectable( i ); + + final JsonWriteExpression jsonWriteExpression = subExpressions.get( selectableMapping.getSelectableName() ); + if ( jsonWriteExpression == null ) { + subExpressions.put( + selectableMapping.getSelectableName(), + new PassThroughExpression( selectableMapping ) + ); + } + else if ( jsonWriteExpression instanceof AggregateJsonWriteExpression writeExpression ) { + writeExpression.passThroughUnsetSubExpressions( selectableMapping ); + } + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + final int aggregateCount = determineAggregateCount(); + if ( aggregateCount != 0 ) { + sb.append( "(trim(trailing '}' from " ); + } + + sb.append( "(select" ); + if ( aggregateCount != subExpressions.size() ) { + char separator = ' '; + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + if ( !value.isAggregate() ) { + sb.append( separator ); + value.append( sb, path, translator, expression ); + sb.append( ' ' ); + sb.appendDoubleQuoteEscapedString( column ); + separator = ','; + } + } + sb.append( " from sys.dummy for json('arraywrap'='no','omitnull'='no')" ); + sb.append( " returns " ); + sb.append( columnDefinition ); + } + else { + sb.append( " cast('{}' as " ); + sb.append( columnDefinition ); + sb.append( ") jsonresult from sys.dummy" ); + } + sb.append( ')' ); + if ( aggregateCount != 0 ) { + sb.append( ')' ); + final String parentPartExpression = determineParentPartExpression( path ); + String separator = aggregateCount == subExpressions.size() ? " " : ","; + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + if ( value.isAggregate() ) { + sb.append( "||'" ); + sb.append( separator ); + sb.append( '"' ); + sb.append( column ); + sb.append( "\":'||" ); + if ( value instanceof AggregateJsonWriteExpression ) { + final String subPath = "json_query(" + parentPartExpression + column + "' error on error)"; + value.append( sb, subPath, translator, expression ); + } + else { + sb.append( "coalesce(" ); + value.append( sb, path, translator, expression ); + sb.append( ",'null')" ); + } + separator = ","; + } + } + sb.append( "||'}')" ); + } + } + + private int determineAggregateCount() { + int count = 0; + for ( Map.Entry entry : subExpressions.entrySet() ) { + if ( entry.getValue().isAggregate() ) { + count++; + } + } + return count; + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final String path; + + RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) { + super( aggregateColumn, aggregateColumn.getColumnDefinition() ); + path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( aggregateColumn, columns ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + append( sqlAppender, basePath, translator, aggregateColumnWriteExpression ); + } + } + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public boolean isAggregate() { + return selectableMapping.getJdbcMapping().getJdbcType().isJson(); + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + + private static class PassThroughExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + + PassThroughExpression(SelectableMapping selectableMapping) { + this.selectableMapping = selectableMapping; + } + + @Override + public boolean isAggregate() { + return selectableMapping.getJdbcMapping().getJdbcType().isJson(); + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + final String parentPartExpression = determineParentPartExpression( path ); + switch ( selectableMapping.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case BOOLEAN: + sb.append( "case json_value(" ); + sb.append( parentPartExpression ); + sb.append( selectableMapping.getSelectableName() ); + if ( SqlTypes.isNumericType( selectableMapping.getJdbcMapping().getJdbcType().getDdlTypeCode() ) ) { + sb.append( "') when 'true' then 1 when 'false' then 0 end" ); + } + else { + sb.append( "') when 'true' then true when 'false' then false end" ); + } + break; + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case FLOAT: + case REAL: + case DOUBLE: + case DECIMAL: + case NUMERIC: + sb.append( "json_value(" ); + sb.append( parentPartExpression ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "' returning " ); + sb.append( HANAJsonValueFunction.jsonValueReturningType( selectableMapping ) ); + sb.append( " error on error)" ); + break; + case JSON: + case JSON_ARRAY: + sb.append( "json_query(" ); + sb.append( parentPartExpression ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "' error on error)" ); + break; + default: + sb.append( "json_value(" ); + sb.append( parentPartExpression ); + sb.append( selectableMapping.getSelectableName() ); + sb.append( "' error on error)" ); + break; + } + } + } + +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java index 9ccc8817192c..af2af3130e92 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/DB2UnnestFunction.java @@ -25,6 +25,7 @@ import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.sql.ast.tree.from.TableGroup; import org.hibernate.type.BasicPluralType; +import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; @@ -106,7 +107,7 @@ protected void renderJsonTable( } else { sqlAppender.appendSql( " returning " ); - sqlAppender.append( getDdlType( elementMapping, walker ) ); + sqlAppender.append( getDdlType( elementMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.append( ") " ); } @@ -136,10 +137,10 @@ protected void renderJsonTable( } sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.appendSql( " path '$." ); sqlAppender.append( selectableMapping.getSelectableName() ); - sqlAppender.appendSql( '\'' ); + sqlAppender.appendSql( "' error on error" ); } } ); sqlAppender.appendSql( ") error on error) t on json_exists('{\"a\":'||" ); diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/HANAUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/HANAUnnestFunction.java index 4573b35bf561..753a561b4f72 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/HANAUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/HANAUnnestFunction.java @@ -11,7 +11,7 @@ import org.hibernate.QueryException; import org.hibernate.dialect.XmlHelper; import org.hibernate.dialect.function.json.ExpressionTypeHelper; -import org.hibernate.engine.jdbc.Size; +import org.hibernate.dialect.function.json.HANAJsonValueFunction; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.CollectionPart; import org.hibernate.metamodel.mapping.EmbeddableValuedModelPart; @@ -54,7 +54,7 @@ import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.results.internal.SqlSelectionImpl; import org.hibernate.type.BasicPluralType; -import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; import org.hibernate.type.Type; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; @@ -74,7 +74,6 @@ public HANAUnnestFunction() { protected SelfRenderingSqmSetReturningFunction generateSqmSetReturningFunctionExpression( List> arguments, QueryEngine queryEngine) { - //noinspection unchecked return new SelfRenderingSqmSetReturningFunction<>( this, this, @@ -357,7 +356,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( selectableMapping.getSelectableName() ); sqlAppender.appendSql( "'" ); @@ -378,7 +377,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( "." ); sqlAppender.appendSql( "'" ); @@ -445,6 +444,15 @@ public JdbcMappingContainer getExpressionType() { } } + @Override + protected String getDdlType(SqlTypedMapping sqlTypedMapping, int containerSqlTypeCode, SqlAstTranslator translator) { + final String ddlType = super.getDdlType( sqlTypedMapping, containerSqlTypeCode, translator ); + if ( containerSqlTypeCode == SqlTypes.JSON_ARRAY ) { + return HANAJsonValueFunction.jsonValueReturningType( ddlType ); + } + return ddlType; + } + @Override protected void renderJsonTable( SqlAppender sqlAppender, @@ -454,12 +462,6 @@ protected void renderJsonTable( AnonymousTupleTableGroupProducer tupleType, String tableIdentifierVariable, SqlAstTranslator walker) { - final BasicType elementType = pluralType.getElementType(); - final String columnType = walker.getSessionFactory().getTypeConfiguration().getDdlTypeRegistry().getTypeName( - elementType.getJdbcType().getDdlTypeCode(), - sqlTypedMapping == null ? Size.nil() : sqlTypedMapping.toSize(), - elementType - ); sqlAppender.appendSql( "json_table(" ); array.accept( walker ); @@ -474,18 +476,14 @@ protected void renderJsonTable( sqlAppender.appendSql( "'," ); } - sqlAppender.appendSql( "nested path '$.v' columns (" ); - sqlAppender.append( tupleType.getColumnNames().get( 0 ) ); - sqlAppender.appendSql( ' ' ); - sqlAppender.append( columnType ); - sqlAppender.appendSql( " path '$')))" ); + sqlAppender.appendSql( "nested path '$.v' columns" ); + renderJsonTableColumns( sqlAppender, tupleType, walker, true ); + sqlAppender.appendSql( "))" ); } else { - sqlAppender.appendSql( ",'$[*]' columns(" ); - sqlAppender.append( tupleType.getColumnNames().get( 0 ) ); - sqlAppender.appendSql( ' ' ); - sqlAppender.append( columnType ); - sqlAppender.appendSql( " path '$'))" ); + sqlAppender.appendSql( ",'$[*]' columns" ); + renderJsonTableColumns( sqlAppender, tupleType, walker, true ); + sqlAppender.appendSql( ")" ); } } @@ -519,9 +517,11 @@ public void renderToSql( separator = ','; } sqlAppender.appendSql( " from sys.dummy for json('arraywrap'='no')))||" ); - sqlAppender.appendSql( "',\"v\":'||" ); + sqlAppender.appendSql( "',\"v\":'||case when " ); + argument.accept( walker ); + sqlAppender.appendSql( " not like '[]' then " ); argument.accept( walker ); - sqlAppender.appendSql( "||'}'" ); + sqlAppender.appendSql( " end||'}'" ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java index 1e908e4eb2eb..fbd0d6bdadd5 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java @@ -13,6 +13,7 @@ import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.type.BasicPluralType; +import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.checkerframework.checker.nullness.qual.Nullable; @@ -60,7 +61,7 @@ protected void renderJsonTable( } sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.appendSql( " path '$." ); sqlAppender.append( selectableMapping.getSelectableName() ); sqlAppender.appendSql( '\'' ); @@ -79,7 +80,7 @@ protected void renderJsonTable( } sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.appendSql( " path '$'" ); } } ); @@ -120,7 +121,7 @@ protected void renderXmlTable( sqlAppender.appendSql( "t.v.value('count(for $a in . return $a/../" ); sqlAppender.appendSql( collectionTags.elementName() ); sqlAppender.appendSql( "[.<<$a])+1','" ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( "') " ); sqlAppender.appendSql( selectableMapping.getSelectionExpression() ); } @@ -128,7 +129,7 @@ protected void renderXmlTable( sqlAppender.appendSql( "t.v.value('"); sqlAppender.appendSql( selectableMapping.getSelectableName() ); sqlAppender.appendSql( "/text()[1]','" ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( "') " ); sqlAppender.appendSql( selectableMapping.getSelectionExpression() ); } @@ -146,13 +147,13 @@ protected void renderXmlTable( sqlAppender.appendSql( "t.v.value('count(for $a in . return $a/../" ); sqlAppender.appendSql( collectionTags.elementName() ); sqlAppender.appendSql( "[.<<$a])+1','" ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( "') " ); sqlAppender.appendSql( selectableMapping.getSelectionExpression() ); } else { sqlAppender.appendSql( "t.v.value('text()[1]','" ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( "') " ); sqlAppender.appendSql( selectableMapping.getSelectionExpression() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SybaseASEUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SybaseASEUnnestFunction.java index ff7da74a94d2..d549c3b5f919 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SybaseASEUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SybaseASEUnnestFunction.java @@ -12,6 +12,7 @@ import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.Expression; import org.hibernate.type.BasicPluralType; +import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.checkerframework.checker.nullness.qual.Nullable; @@ -59,7 +60,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( selectableMapping.getSelectableName() ); sqlAppender.appendSql( "'" ); @@ -80,7 +81,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( "." ); sqlAppender.appendSql( "'" ); diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/UnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/UnnestFunction.java index 9dfc06570acd..f70e4bf6f566 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/UnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/UnnestFunction.java @@ -66,7 +66,7 @@ else if ( ddlTypeCode == SqlTypes.XML_ARRAY ) { } } - protected String getDdlType(SqlTypedMapping sqlTypedMapping, SqlAstTranslator translator) { + protected String getDdlType(SqlTypedMapping sqlTypedMapping, int containerSqlTypeCode, SqlAstTranslator translator) { final String columnDefinition = sqlTypedMapping.getColumnDefinition(); if ( columnDefinition != null ) { return columnDefinition; @@ -88,11 +88,16 @@ protected void renderJsonTable( SqlAstTranslator walker) { sqlAppender.appendSql( "json_table(" ); array.accept( walker ); - sqlAppender.appendSql( ",'$[*]' columns(" ); + sqlAppender.appendSql( ",'$[*]' columns" ); + renderJsonTableColumns( sqlAppender, tupleType, walker, false ); + sqlAppender.appendSql( ')' ); + } + + protected void renderJsonTableColumns(SqlAppender sqlAppender, AnonymousTupleTableGroupProducer tupleType, SqlAstTranslator walker, boolean errorOnError) { if ( tupleType.findSubPart( CollectionPart.Nature.ELEMENT.getName(), null ) == null ) { tupleType.forEachSelectable( 0, (selectionIndex, selectableMapping) -> { if ( selectionIndex == 0 ) { - sqlAppender.append( ' ' ); + sqlAppender.append( '(' ); } else { sqlAppender.append( ',' ); @@ -103,17 +108,20 @@ protected void renderJsonTable( sqlAppender.append( " for ordinality" ); } else { - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.appendSql( " path '$." ); sqlAppender.append( selectableMapping.getSelectableName() ); sqlAppender.appendSql( '\'' ); + if ( errorOnError ) { + sqlAppender.appendSql( " error on error" ); + } } } ); } else { tupleType.forEachSelectable( 0, (selectionIndex, selectableMapping) -> { if ( selectionIndex == 0 ) { - sqlAppender.append( ' ' ); + sqlAppender.append( '(' ); } else { sqlAppender.append( ',' ); @@ -124,12 +132,15 @@ protected void renderJsonTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); sqlAppender.appendSql( " path '$'" ); + if ( errorOnError ) { + sqlAppender.appendSql( " error on error" ); + } } } ); } - sqlAppender.appendSql( "))" ); + sqlAppender.appendSql( ')' ); } protected void renderXmlTable( @@ -165,7 +176,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( selectableMapping.getSelectableName() ); sqlAppender.appendSql( "/text()" ); @@ -187,7 +198,7 @@ protected void renderXmlTable( } else { sqlAppender.append( ' ' ); - sqlAppender.append( getDdlType( selectableMapping, walker ) ); + sqlAppender.append( getDdlType( selectableMapping, SqlTypes.XML_ARRAY, walker ) ); sqlAppender.appendSql( " path '" ); sqlAppender.appendSql( "text()" ); sqlAppender.appendSql( "'" ); diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonObjectAggFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonObjectAggFunction.java index 3dba7b8f0bf5..d56d44580139 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonObjectAggFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonObjectAggFunction.java @@ -37,7 +37,7 @@ protected void render( throw new QueryException( "Can't emulate json_objectagg 'with unique keys' clause." ); } sqlAppender.appendSql( "'{'||string_agg(" ); - renderArgument( sqlAppender, arguments.key(), arguments.nullBehavior(), translator ); + renderArgument( sqlAppender, arguments.key(), JsonNullBehavior.NULL, translator ); sqlAppender.appendSql( "||':'||" ); if ( caseWrapper ) { if ( arguments.nullBehavior() != JsonNullBehavior.ABSENT ) { @@ -76,8 +76,11 @@ protected void renderArgument( } sqlAppender.appendSql( "json_query((select " ); arg.accept( translator ); - sqlAppender.appendSql( - " V from sys.dummy for json('arraywrap'='no','omitnull'='no') returns nvarchar(" + Integer.MAX_VALUE + ")),'$.V')" ); + sqlAppender.appendSql( " V from sys.dummy for json('arraywrap'='no'" ); + if ( nullBehavior != JsonNullBehavior.NULL ) { + sqlAppender.appendSql( ",'omitnull'='no'" ); + } + sqlAppender.appendSql( ") returns nvarchar(" + Integer.MAX_VALUE + ")),'$.V')" ); if ( nullBehavior != JsonNullBehavior.NULL ) { sqlAppender.appendSql( ",'null')" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonValueFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonValueFunction.java index aaaa77a54862..2fb7f3114d7f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonValueFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/HANAJsonValueFunction.java @@ -7,6 +7,7 @@ import org.hibernate.dialect.Dialect; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SqlTypedMapping; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; @@ -14,6 +15,8 @@ import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; import org.hibernate.type.spi.TypeConfiguration; +import static org.hibernate.sql.ast.spi.AbstractSqlAstTranslator.getCastTypeName; + /** * HANA json_value function. */ @@ -52,11 +55,34 @@ protected void render( } } + public static String jsonValueReturningType(SqlTypedMapping column) { + final String columnDefinition = column.getColumnDefinition(); + assert columnDefinition != null; + return jsonValueReturningType( columnDefinition ); + } + + public static String jsonValueReturningType(String columnDefinition) { + final int parenthesisIndex = columnDefinition.indexOf( '(' ); + final String baseName = parenthesisIndex == -1 + ? columnDefinition + : columnDefinition.substring( 0, parenthesisIndex ); + return switch ( baseName ) { + case "real", "float", "double", "decimal" -> "decimal"; + case "tinyint", "smallint" -> "integer"; + case "clob" -> "varchar(5000)"; + case "nclob" -> "nvarchar(5000)"; + default -> columnDefinition; + }; + } + @Override protected void renderReturningClause(SqlAppender sqlAppender, JsonValueArguments arguments, SqlAstTranslator walker) { // No return type for booleans, this is handled via decode if ( arguments.returningType() != null && !isEncodedBoolean( arguments.returningType().getJdbcMapping() ) ) { - super.renderReturningClause( sqlAppender, arguments, walker ); + sqlAppender.appendSql( " returning " ); + sqlAppender.appendSql( jsonValueReturningType( + getCastTypeName( arguments.returningType(), walker.getSessionFactory().getTypeConfiguration() ) + ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java index 4dfe899dcac1..fe5810fd2a62 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java @@ -4,6 +4,7 @@ */ package org.hibernate.dialect.function.xml; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.QueryException; import org.hibernate.dialect.Dialect; import org.hibernate.engine.spi.SessionFactoryImplementor; @@ -29,6 +30,7 @@ import org.hibernate.query.sqm.tree.expression.SqmXmlTableFunction; import org.hibernate.spi.NavigablePath; import org.hibernate.sql.Template; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.internal.ColumnQualifierCollectorSqlAstWalker; import org.hibernate.sql.ast.spi.FromClauseAccess; @@ -408,6 +410,13 @@ protected void renderXmlQueryColumnDefinition(SqlAppender sqlAppender, XmlTableQ renderDefaultExpression( definition.defaultExpression(), sqlAppender, walker ); } + protected void renderDefaultExpression(@Nullable Expression expression, SqlAppender sqlAppender, SqlAstTranslator walker) { + if ( expression != null ) { + sqlAppender.appendSql( " default " ); + sqlAppender.appendSingleQuoteEscapedString( walker.getLiteralValue( expression ) ); + } + } + static boolean isBoolean(JdbcMapping type) { return type.getJdbcType().isBoolean(); } From d7b7abbbb1119e547110ad90bda87b400c61bf9f Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Mon, 11 Nov 2024 22:02:07 +0100 Subject: [PATCH 6/7] HHH-18798 Add JSON aggregate support for SQL Server --- .../dialect/SQLServerLegacyDialect.java | 7 + .../hibernate/dialect/SQLServerDialect.java | 7 + .../aggregate/SQLServerAggregateSupport.java | 401 ++++++++++++++++++ .../array/SQLServerUnnestFunction.java | 8 +- .../json/SQLServerJsonTableFunction.java | 29 +- 5 files changed, 432 insertions(+), 20 deletions(-) create mode 100644 hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java index 2d98bb6aacc1..978cc6a747fe 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/SQLServerLegacyDialect.java @@ -19,6 +19,8 @@ import org.hibernate.dialect.SQLServerCastingXmlArrayJdbcTypeConstructor; import org.hibernate.dialect.SQLServerCastingXmlJdbcType; import org.hibernate.dialect.TimeZoneSupport; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.SQLServerAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.CountFunction; import org.hibernate.dialect.function.SQLServerFormatEmulation; @@ -504,6 +506,11 @@ protected SqlAstTranslator buildTranslator( }; } + @Override + public AggregateSupport getAggregateSupport() { + return SQLServerAggregateSupport.valueOf( this ); + } + @Override public SizeStrategy getSizeStrategy() { return sizeStrategy; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java index 16019bb4e84e..8d4a950b621d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/SQLServerDialect.java @@ -25,6 +25,8 @@ import org.hibernate.boot.model.relational.QualifiedSequenceName; import org.hibernate.boot.model.relational.Sequence; import org.hibernate.boot.model.relational.SqlStringGenerationContext; +import org.hibernate.dialect.aggregate.AggregateSupport; +import org.hibernate.dialect.aggregate.SQLServerAggregateSupport; import org.hibernate.dialect.function.CommonFunctionFactory; import org.hibernate.dialect.function.CountFunction; import org.hibernate.dialect.function.SQLServerFormatEmulation; @@ -511,6 +513,11 @@ protected SqlAstTranslator buildTranslator( }; } + @Override + public AggregateSupport getAggregateSupport() { + return SQLServerAggregateSupport.valueOf( this ); + } + @Override public SizeStrategy getSizeStrategy() { return sizeStrategy; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java new file mode 100644 index 000000000000..ab065459ecec --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/SQLServerAggregateSupport.java @@ -0,0 +1,401 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.aggregate; + +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.mapping.Column; +import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.SelectableMapping; +import org.hibernate.metamodel.mapping.SelectablePath; +import org.hibernate.metamodel.mapping.SqlTypedMapping; +import org.hibernate.sql.ast.SqlAstNodeRenderingMode; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.BasicPluralType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; +import org.hibernate.type.descriptor.sql.DdlType; +import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.hibernate.type.SqlTypes.*; + +public class SQLServerAggregateSupport extends AggregateSupportImpl { + + private static final AggregateSupport INSTANCE = new SQLServerAggregateSupport(); + + private static final String JSON_QUERY_START = "json_query("; + private static final String JSON_QUERY_JSON_END = "')"; + private static final int JSON_VALUE_MAX_LENGTH = 4000; + + private SQLServerAggregateSupport() { + } + + public static AggregateSupport valueOf(Dialect dialect) { + return dialect.getVersion().isSameOrAfter( 13 ) + ? SQLServerAggregateSupport.INSTANCE + : AggregateSupportImpl.INSTANCE; + } + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + int aggregateColumnTypeCode, + SqlTypedMapping column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + final String parentPartExpression; + if ( aggregateParentReadExpression.startsWith( JSON_QUERY_START ) + && aggregateParentReadExpression.endsWith( JSON_QUERY_JSON_END ) ) { + parentPartExpression = aggregateParentReadExpression.substring( JSON_QUERY_START.length(), aggregateParentReadExpression.length() - JSON_QUERY_JSON_END.length() ) + "."; + } + else { + parentPartExpression = aggregateParentReadExpression + ",'$."; + } + switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case JSON: + case JSON_ARRAY: + return template.replace( + placeholder, + "json_query(" + parentPartExpression + columnExpression + "')" + ); + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + // We encode binary data as hex, so we have to decode here + if ( determineLength( column ) * 2 > JSON_VALUE_MAX_LENGTH ) { + // Since data is HEX encoded, multiply the max length by 2 since we need 2 hex chars per byte + return template.replace( + placeholder, + "(select convert(" + column.getColumnDefinition() + ",v,2) from openjson(" + aggregateParentReadExpression + ") with (v varchar(max) '$." + columnExpression + "'))" + ); + } + else { + return template.replace( + placeholder, + "convert(" + column.getColumnDefinition() + ",json_value(" + parentPartExpression + columnExpression + "'),2)" + ); + } + case CHAR: + case NCHAR: + case VARCHAR: + case NVARCHAR: + case LONG32VARCHAR: + case LONG32NVARCHAR: + case CLOB: + case NCLOB: + if ( determineLength( column ) > JSON_VALUE_MAX_LENGTH ) { + return template.replace( + placeholder, + "(select * from openjson(" + aggregateParentReadExpression + ") with (v " + column.getColumnDefinition() + " '$." + columnExpression + "'))" + ); + } + // Fall-through intended + case BIT: + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case REAL: + case FLOAT: + case DOUBLE: + case NUMERIC: + case DECIMAL: + case TIME: + case TIME_UTC: + case TIME_WITH_TIMEZONE: + case DATE: + case TIMESTAMP: + case TIMESTAMP_UTC: + case TIMESTAMP_WITH_TIMEZONE: + return template.replace( + placeholder, + "cast(json_value(" + parentPartExpression + columnExpression + "') as " + column.getColumnDefinition() + ")" + ); + default: + return template.replace( + placeholder, + "(select * from openjson(" + aggregateParentReadExpression + ") with (v " + column.getColumnDefinition() + " '$." + columnExpression + "'))" + ); + } + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private static Long determineLength(SqlTypedMapping column) { + final Long length = column.getLength(); + if ( length != null ) { + return length; + } + else { + final String columnDefinition = column.getColumnDefinition(); + assert columnDefinition != null; + final int parenthesisIndex = columnDefinition.indexOf( '(' ); + if ( parenthesisIndex != -1 ) { + int end; + for ( end = parenthesisIndex + 1; end < columnDefinition.length(); end++ ) { + if ( !Character.isDigit( columnDefinition.charAt( end ) ) ) { + break; + } + } + return Long.parseLong( columnDefinition.substring( parenthesisIndex + 1, end ) ); + } + // Default to the max varchar length + return 8000L; + } + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + int aggregateColumnTypeCode, + Column column) { + switch ( aggregateColumnTypeCode ) { + case JSON: + case JSON_ARRAY: + // For JSON we always have to replace the whole object + return aggregateParentAssignmentExpression; + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode ); + } + + private String jsonCustomWriteExpression( + String customWriteExpression, + JdbcMapping jdbcMapping, + SelectableMapping column, + TypeConfiguration typeConfiguration) { + switch ( jdbcMapping.getJdbcType().getDefaultSqlTypeCode() ) { + case BINARY: + case VARBINARY: + case LONG32VARBINARY: + case BLOB: + return "convert(nvarchar(max)," + customWriteExpression + ",2)"; + case TIME: + return "left(" + customWriteExpression + ",8)"; + case DATE: + return "format(" + customWriteExpression + ",'yyyy-MM-dd')"; + case TIMESTAMP: + return "format(" + customWriteExpression + ",'yyyy-MM-ddTHH:mm:ss.fffffff')"; + case TIMESTAMP_UTC: + case TIMESTAMP_WITH_TIMEZONE: + return "format(" + customWriteExpression + ",'yyyy-MM-ddTHH:mm:ss.fffffffzzz')"; + case UUID: + return "cast(" + customWriteExpression + " as nvarchar(36))"; + case JSON: + case JSON_ARRAY: + return "json_query(" + customWriteExpression + ")"; + default: + return customWriteExpression; + } + } + + private static String determineElementTypeName( + Size castTargetSize, + BasicPluralType pluralType, + TypeConfiguration typeConfiguration) { + final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry(); + final BasicType expressionType = pluralType.getElementType(); + DdlType ddlType = ddlTypeRegistry.getDescriptor( expressionType.getJdbcType().getDdlTypeCode() ); + if ( ddlType == null ) { + // this may happen when selecting a null value like `SELECT null from ...` + // some dbs need the value to be cast so not knowing the real type we fall back to INTEGER + ddlType = ddlTypeRegistry.getDescriptor( SqlTypes.INTEGER ); + } + + return ddlType.getTypeName( castTargetSize, expressionType, ddlTypeRegistry ); + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + return aggregateSqlTypeCode == JSON; + } + + @Override + public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer( + SelectableMapping aggregateColumn, + SelectableMapping[] columnsToUpdate, + TypeConfiguration typeConfiguration) { + final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + switch ( aggregateSqlTypeCode ) { + case JSON: + return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate, typeConfiguration ); + } + throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode ); + } + + private WriteExpressionRenderer jsonAggregateColumnWriter( + SelectableMapping aggregateColumn, + SelectableMapping[] columns, + TypeConfiguration typeConfiguration) { + return new RootJsonWriteExpression( aggregateColumn, columns, this, typeConfiguration ); + } + + interface JsonWriteExpression { + void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression); + } + private static class AggregateJsonWriteExpression implements JsonWriteExpression { + private final LinkedHashMap subExpressions = new LinkedHashMap<>(); + protected final EmbeddableMappingType embeddableMappingType; + + public AggregateJsonWriteExpression(SelectableMapping selectableMapping, SQLServerAggregateSupport aggregateSupport) { + this.embeddableMappingType = ( (AggregateJdbcType) selectableMapping.getJdbcMapping().getJdbcType() ) + .getEmbeddableMappingType(); + } + + protected void initializeSubExpressions( + SelectableMapping[] columns, + SQLServerAggregateSupport aggregateSupport, + TypeConfiguration typeConfiguration) { + for ( SelectableMapping column : columns ) { + final SelectablePath selectablePath = column.getSelectablePath(); + final SelectablePath[] parts = selectablePath.getParts(); + AggregateJsonWriteExpression currentAggregate = this; + EmbeddableMappingType currentMappingType = embeddableMappingType; + for ( int i = 1; i < parts.length - 1; i++ ) { + final SelectableMapping selectableMapping = currentMappingType.getJdbcValueSelectable( + currentMappingType.getSelectableIndex( parts[i].getSelectableName() ) + ); + currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent( + parts[i].getSelectableName(), + k -> new AggregateJsonWriteExpression( selectableMapping, aggregateSupport ) + ); + currentMappingType = currentAggregate.embeddableMappingType; + } + final String customWriteExpression = column.getWriteExpression(); + currentAggregate.subExpressions.put( + parts[parts.length - 1].getSelectableName(), + new BasicJsonWriteExpression( + column, + aggregateSupport.jsonCustomWriteExpression( + customWriteExpression, + column.getJdbcMapping(), + column, + typeConfiguration + ) + ) + ); + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + for ( int i = 0; i < subExpressions.size() - 1; i++ ) { + sb.append( "json_modify(" ); + } + sb.append( "json_modify(" ); + sb.append( path ); + for ( Map.Entry entry : subExpressions.entrySet() ) { + final String column = entry.getKey(); + final JsonWriteExpression value = entry.getValue(); + final String subPath = "json_query(" + path + ",'$." + column + "')"; + sb.append( ",'$." ); + sb.append( column ); + sb.append( "'," ); + if ( value instanceof AggregateJsonWriteExpression ) { + value.append( sb, subPath, translator, expression ); + } + else { + value.append( sb, subPath, translator, expression ); + } + sb.append( ')' ); + } + } + } + + private static class RootJsonWriteExpression extends AggregateJsonWriteExpression + implements WriteExpressionRenderer { + private final boolean nullable; + private final String path; + + RootJsonWriteExpression( + SelectableMapping aggregateColumn, + SelectableMapping[] columns, + SQLServerAggregateSupport aggregateSupport, + TypeConfiguration typeConfiguration) { + super( aggregateColumn, aggregateSupport ); + this.nullable = aggregateColumn.isNullable(); + this.path = aggregateColumn.getSelectionExpression(); + initializeSubExpressions( columns, aggregateSupport, typeConfiguration ); + } + + @Override + public void render( + SqlAppender sqlAppender, + SqlAstTranslator translator, + AggregateColumnWriteExpression aggregateColumnWriteExpression, + String qualifier) { + final String basePath; + if ( qualifier == null || qualifier.isBlank() ) { + basePath = path; + } + else { + basePath = qualifier + "." + path; + } + append( + sqlAppender, + nullable ? "coalesce(" + basePath + ",'{}')" : basePath, + translator, + aggregateColumnWriteExpression + ); + } + } + + private static class BasicJsonWriteExpression implements JsonWriteExpression { + + private final SelectableMapping selectableMapping; + private final String customWriteExpressionStart; + private final String customWriteExpressionEnd; + + BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) { + this.selectableMapping = selectableMapping; + if ( customWriteExpression.equals( "?" ) ) { + this.customWriteExpressionStart = ""; + this.customWriteExpressionEnd = ""; + } + else { + final String[] parts = StringHelper.split( "?", customWriteExpression ); + assert parts.length == 2; + this.customWriteExpressionStart = parts[0]; + this.customWriteExpressionEnd = parts[1]; + } + } + + @Override + public void append( + SqlAppender sb, + String path, + SqlAstTranslator translator, + AggregateColumnWriteExpression expression) { + sb.append( customWriteExpressionStart ); + // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, + // since we don't know how the custom write expression looks like where this is embedded, + // so we have to be pessimistic and avoid ambiguities + translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + sb.append( customWriteExpressionEnd ); + } + } + +} diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java index fbd0d6bdadd5..989d88ac9d07 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/SQLServerUnnestFunction.java @@ -39,14 +39,14 @@ protected void renderJsonTable( final ModelPart ordinalityPart = tupleType.findSubPart( CollectionPart.Nature.INDEX.getName(), null ); if ( ordinalityPart != null ) { sqlAppender.appendSql( "(select t.*,row_number() over (order by (select null)) " ); - sqlAppender.appendSql( ordinalityPart.asBasicValuedModelPart().getSelectableName() ); + sqlAppender.appendSql( ordinalityPart.asBasicValuedModelPart().getSelectionExpression() ); sqlAppender.appendSql( " from openjson(" ); } else { sqlAppender.appendSql( "openjson(" ); } array.accept( walker ); - sqlAppender.appendSql( ",'$[*]') with (" ); + sqlAppender.appendSql( ") with (" ); boolean[] comma = new boolean[1]; if ( tupleType.findSubPart( CollectionPart.Nature.ELEMENT.getName(), null ) == null ) { @@ -62,7 +62,7 @@ protected void renderJsonTable( sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); - sqlAppender.appendSql( " path '$." ); + sqlAppender.appendSql( " '$." ); sqlAppender.append( selectableMapping.getSelectableName() ); sqlAppender.appendSql( '\'' ); } @@ -81,7 +81,7 @@ protected void renderJsonTable( sqlAppender.append( selectableMapping.getSelectionExpression() ); sqlAppender.append( ' ' ); sqlAppender.append( getDdlType( selectableMapping, SqlTypes.JSON_ARRAY, walker ) ); - sqlAppender.appendSql( " path '$'" ); + sqlAppender.appendSql( " '$'" ); } } ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java index 9ab5f39111bf..afd06c189e60 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java @@ -10,7 +10,6 @@ import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.JsonExistsErrorBehavior; -import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonQueryEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryErrorBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryWrapMode; @@ -44,23 +43,21 @@ protected void renderJsonTable(SqlAppender sqlAppender, JsonTableArguments argum arguments.jsonDocument().accept( walker ); if ( arguments.jsonPath() != null ) { sqlAppender.appendSql( ',' ); - final JsonPathPassingClause passingClause = arguments.passingClause(); - if ( passingClause != null ) { - JsonPathHelper.appendInlinedJsonPathIncludingPassingClause( - sqlAppender, - // Default behavior is NULL ON ERROR - arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ? "strict " : "", - arguments.jsonPath(), - passingClause, - walker - ); + // Default behavior is NULL ON ERROR + final String prefix = arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ? "strict " : ""; + final String jsonPathString; + if ( arguments.passingClause() != null ) { + jsonPathString = prefix + JsonPathHelper.inlinedJsonPathIncludingPassingClause( arguments.jsonPath(), + arguments.passingClause(), walker ); } else { - if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) { - // Default behavior is NULL ON ERROR - sqlAppender.appendSql( "'strict '+" ); - } - arguments.jsonPath().accept( walker ); + jsonPathString = prefix + walker.getLiteralValue( arguments.jsonPath() ); + } + if ( jsonPathString.endsWith( "[*]" ) ) { + sqlAppender.appendSingleQuoteEscapedString( jsonPathString.substring( 0, jsonPathString.length() - 3 ) ); + } + else { + sqlAppender.appendSingleQuoteEscapedString( jsonPathString ); } } else if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) { From 7737ceb7067b37c3419168b4110bd9bb384e8a4d Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Thu, 7 Nov 2024 16:52:25 +0100 Subject: [PATCH 7/7] HHH-16159 Fix some JSON related issues that came up --- docker_db.sh | 39 ++++-- .../aggregate/OracleAggregateSupport.java | 11 +- .../function/H2GenerateSeriesFunction.java | 4 +- .../json/CockroachDBJsonExistsFunction.java | 4 +- .../json/CockroachDBJsonQueryFunction.java | 4 +- .../json/CockroachDBJsonRemoveFunction.java | 4 +- .../json/CockroachDBJsonValueFunction.java | 4 +- .../function/json/H2JsonTableFunction.java | 113 ++++++++++++++---- .../json/PostgreSQLJsonQueryFunction.java | 4 +- .../json/PostgreSQLJsonRemoveFunction.java | 4 +- .../json/PostgreSQLJsonReplaceFunction.java | 4 +- .../json/PostgreSQLJsonSetFunction.java | 4 +- .../json/PostgreSQLJsonTableFunction.java | 4 +- .../json/PostgreSQLJsonValueFunction.java | 4 +- .../json/SQLServerJsonTableFunction.java | 27 +++-- 15 files changed, 168 insertions(+), 66 deletions(-) diff --git a/docker_db.sh b/docker_db.sh index db60b873e818..f6fcce3a3c29 100755 --- a/docker_db.sh +++ b/docker_db.sh @@ -1,6 +1,10 @@ #! /bin/bash -if command -v podman > /dev/null; then +if command -v docker > /dev/null; then + CONTAINER_CLI=$(command -v docker) + HEALTCHECK_PATH="{{.State.Health.Status}}" + PRIVILEGED_CLI="" +else CONTAINER_CLI=$(command -v podman) HEALTCHECK_PATH="{{.State.Healthcheck.Status}}" # Only use sudo for podman @@ -9,10 +13,6 @@ if command -v podman > /dev/null; then else PRIVILEGED_CLI="" fi -else - CONTAINER_CLI=$(command -v docker) - HEALTCHECK_PATH="{{.State.Health.Status}}" - PRIVILEGED_CLI="" fi mysql() { @@ -489,7 +489,7 @@ oracle_setup() { echo "Waiting for Oracle to start..." sleep 5; # On WSL, health-checks intervals don't work for Podman, so run them manually - if command -v podman > /dev/null; then + if ! command -v docker > /dev/null; then $PRIVILEGED_CLI $CONTAINER_CLI healthcheck run oracle > /dev/null fi HEALTHSTATUS="`$PRIVILEGED_CLI $CONTAINER_CLI inspect -f $HEALTCHECK_PATH oracle`" @@ -569,7 +569,7 @@ oracle_free_setup() { echo "Waiting for Oracle Free to start..." sleep 5; # On WSL, health-checks intervals don't work for Podman, so run them manually - if command -v podman > /dev/null; then + if ! command -v docker > /dev/null; then $PRIVILEGED_CLI $CONTAINER_CLI healthcheck run oracle > /dev/null fi HEALTHSTATUS="`$PRIVILEGED_CLI $CONTAINER_CLI inspect -f $HEALTCHECK_PATH oracle`" @@ -658,9 +658,13 @@ disable_userland_proxy() { echo "Stopping docker..." sudo service docker stop echo "Updating /etc/docker/daemon.json..." - sudo bash -c 'echo "${docker_daemon_json/\}/,}\"userland-proxy\": false}" > /etc/docker/daemon.json' + sudo bash -c "export docker_daemon_json='$docker_daemon_json'; echo \"\${docker_daemon_json/\}/,}\\\"userland-proxy\\\": false}\" > /etc/docker/daemon.json" + echo "New docker daemon config:" + cat /etc/docker/daemon.json echo "Starting docker..." sudo service docker start + echo "Service status:" + sudo journalctl -xeu docker.service echo "Docker successfully started with userland proxies disabled" fi fi @@ -733,6 +737,21 @@ oracle() { oracle_23 } +oracle_18() { + $PRIVILEGED_CLI $CONTAINER_CLI rm -f oracle || true + disable_userland_proxy + # We need to use the defaults + # SYSTEM/Oracle18 + $PRIVILEGED_CLI $CONTAINER_CLI run --name oracle -d -p 1521:1521 -e ORACLE_PASSWORD=Oracle18 \ + --cap-add cap_net_raw \ + --health-cmd healthcheck.sh \ + --health-interval 5s \ + --health-timeout 5s \ + --health-retries 10 \ + ${DB_IMAGE_ORACLE_21:-docker.io/gvenzl/oracle-xe:18.4.0} + oracle_setup +} + oracle_21() { $PRIVILEGED_CLI $CONTAINER_CLI rm -f oracle || true disable_userland_proxy @@ -765,7 +784,7 @@ oracle_23() { hana() { temp_dir=$(mktemp -d) echo '{"master_password" : "H1bernate_test"}' >$temp_dir/password.json - chmod 777 -R $temp_dir + chmod -R 777 $temp_dir $PRIVILEGED_CLI $CONTAINER_CLI rm -f hana || true $PRIVILEGED_CLI $CONTAINER_CLI run -d --name hana -p 39013:39013 -p 39017:39017 -p 39041-39045:39041-39045 -p 1128-1129:1128-1129 -p 59013-59014:59013-59014 \ --memory=8g \ @@ -775,7 +794,7 @@ hana() { --sysctl kernel.shmmni=4096 \ --sysctl kernel.shmall=8388608 \ -v $temp_dir:/config:Z \ - ${DB_IMAGE_HANA:-docker.io/saplabs/hanaexpress:2.00.072.00.20231123.1} \ + ${DB_IMAGE_HANA:-docker.io/saplabs/hanaexpress:2.00.076.00.20240701.1} \ --passwords-url file:///config/password.json \ --agree-to-sap-license # Give the container some time to start diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java index 172b1511bad2..6a3111edbe66 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/aggregate/OracleAggregateSupport.java @@ -29,6 +29,8 @@ import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.sql.ast.tree.expression.Literal; import org.hibernate.type.BasicPluralType; import org.hibernate.type.BasicType; import org.hibernate.type.SqlTypes; @@ -589,7 +591,14 @@ public void append( // We use NO_UNTYPED here so that expressions which require type inference are casted explicitly, // since we don't know how the custom write expression looks like where this is embedded, // so we have to be pessimistic and avoid ambiguities - translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED ); + final Expression valueExpression = expression.getValueExpression( selectableMapping ); + if ( valueExpression instanceof Literal literal && literal.getLiteralValue() == null ) { + // Except for the null literal. That is just rendered as-is + sb.append( "null" ); + } + else { + translator.render( valueExpression, SqlAstNodeRenderingMode.NO_UNTYPED ); + } sb.append( customWriteExpressionEnd ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/H2GenerateSeriesFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/H2GenerateSeriesFunction.java index 8623318902b0..8686fa3ecde1 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/H2GenerateSeriesFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/H2GenerateSeriesFunction.java @@ -19,11 +19,11 @@ import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.spi.NavigablePath; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.ColumnReference; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.Literal; import org.hibernate.sql.ast.tree.from.FunctionTableGroup; import org.hibernate.sql.ast.tree.from.TableGroup; @@ -232,6 +232,6 @@ protected void renderGenerateSeries( } private static boolean needsEmulation(Expression expression) { - return !( expression instanceof Literal || expression instanceof JdbcParameter); + return !( expression instanceof Literal || AbstractSqlAstTranslator.isParameter( expression ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonExistsFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonExistsFunction.java index 36f283f8f191..5756aaf194cc 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonExistsFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonExistsFunction.java @@ -11,9 +11,9 @@ import org.hibernate.dialect.Dialect; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonExistsErrorBehavior; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.type.spi.TypeConfiguration; @@ -61,7 +61,7 @@ static void appendJsonExists( boolean isJsonType, @Nullable JsonPathPassingClause jsonPathPassingClause, SqlAstTranslator walker) { - final boolean needsCast = !isJsonType && jsonDocument instanceof JdbcParameter; + final boolean needsCast = !isJsonType && AbstractSqlAstTranslator.isParameter( jsonDocument ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonQueryFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonQueryFunction.java index b0d537ff736a..512059172a81 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonQueryFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonQueryFunction.java @@ -11,9 +11,9 @@ import org.hibernate.dialect.Dialect; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonQueryEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryErrorBehavior; @@ -76,7 +76,7 @@ static void appendJsonQuery( boolean isJsonType, @Nullable JsonPathPassingClause jsonPathPassingClause, SqlAstTranslator walker) { - final boolean needsCast = !isJsonType && jsonDocumentExpression instanceof JdbcParameter; + final boolean needsCast = !isJsonType && AbstractSqlAstTranslator.isParameter( jsonDocumentExpression ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonRemoveFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonRemoveFunction.java index 561212ddc3ea..a900f6b25c05 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonRemoveFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonRemoveFunction.java @@ -10,10 +10,10 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.type.spi.TypeConfiguration; /** @@ -34,7 +34,7 @@ public void render( final Expression json = (Expression) arguments.get( 0 ); final Expression jsonPath = (Expression) arguments.get( 1 ); sqlAppender.appendSql( "json_remove_path(" ); - final boolean needsCast = !isJsonType( json ) && json instanceof JdbcParameter; + final boolean needsCast = !isJsonType( json ) && AbstractSqlAstTranslator.isParameter( json ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonValueFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonValueFunction.java index d5736db070fb..a90db6cb367a 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonValueFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/CockroachDBJsonValueFunction.java @@ -10,10 +10,10 @@ import org.hibernate.dialect.Dialect; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.expression.CastTarget; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonValueEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonValueErrorBehavior; @@ -63,7 +63,7 @@ static void appendJsonValue(SqlAppender sqlAppender, Expression jsonDocument, Li if ( castTarget != null ) { sqlAppender.appendSql( "cast(" ); } - final boolean needsCast = !isJsonType && jsonDocument instanceof JdbcParameter; + final boolean needsCast = !isJsonType && AbstractSqlAstTranslator.isParameter( jsonDocument ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java index 8eadaff5812f..0d66bfb030b8 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java @@ -53,6 +53,7 @@ import org.hibernate.sql.ast.tree.from.TableGroupJoin; import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate; import org.hibernate.sql.ast.tree.predicate.Predicate; +import org.hibernate.sql.ast.tree.predicate.PredicateContainer; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.type.BasicType; import org.hibernate.type.SqlTypes; @@ -153,12 +154,31 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo final TableGroup parentTableGroup = querySpec.getFromClause().queryTableGroups( tg -> tg.findTableGroupJoin( functionTableGroup ) == null ? null : tg ); - final TableGroupJoin join = parentTableGroup.findTableGroupJoin( functionTableGroup ); + final PredicateContainer predicateContainer; + if ( parentTableGroup != null ) { + predicateContainer = parentTableGroup.findTableGroupJoin( functionTableGroup ); + } + else { + predicateContainer = querySpec; + } final BasicType integerType = converter.getCreationContext() .getSessionFactory() .getNodeBuilder() .getIntegerType(); - final Expression lhs = new ArrayLengthExpression( arguments.jsonDocument(), integerType ); + final Expression jsonDocument; + if ( arguments.jsonDocument().getColumnReference() == null ) { + jsonDocument = new ColumnReference( + functionTableGroup.getPrimaryTableReference().getIdentificationVariable() + "_", + "d", + false, + null, + arguments.jsonDocument().getExpressionType().getSingleJdbcMapping() + ); + } + else { + jsonDocument = arguments.jsonDocument(); + } + final Expression lhs = new ArrayLengthExpression( jsonDocument, integerType ); final Expression rhs = new ColumnReference( functionTableGroup.getPrimaryTableReference().getIdentificationVariable(), // The default column name for the system_range function @@ -167,7 +187,7 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo null, integerType ); - join.applyPredicate( + predicateContainer.applyPredicate( new ComparisonPredicate( lhs, ComparisonOperator.GREATER_THAN_OR_EQUAL, rhs ) ); } final int lastArrayIndex = getLastArrayIndex( arguments.columnsClause(), 0 ); @@ -176,6 +196,19 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo // for every nested path for arrays final String tableIdentifierVariable = functionTableGroup.getPrimaryTableReference() .getIdentificationVariable(); + final Expression jsonDocument; + if ( arguments.jsonDocument().getColumnReference() == null ) { + jsonDocument = new ColumnReference( + tableIdentifierVariable + "_", + "d", + false, + null, + arguments.jsonDocument().getExpressionType().getSingleJdbcMapping() + ); + } + else { + jsonDocument = arguments.jsonDocument(); + } final TableGroup tableGroup = new FunctionTableGroup( functionTableGroup.getNavigablePath().append( "{synthetic}" ), null, @@ -184,6 +217,7 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo new NestedPathFunctionRenderer( tableIdentifierVariable, arguments, + jsonDocument, maximumArraySize, lastArrayIndex ), @@ -207,7 +241,7 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo // The join predicate compares the length of the last array expression against system_range() index. // Since a table function expression can't render its own `on` clause, this split of logic is necessary final Expression lhs = new ArrayLengthExpression( - determineLastArrayExpression( tableIdentifierVariable, arguments ), + determineLastArrayExpression( tableIdentifierVariable, arguments, jsonDocument ), integerType ); final Expression rhs = new ColumnReference( @@ -226,10 +260,10 @@ public QuerySpec transform(CteContainer cteContainer, QuerySpec querySpec, SqmTo return querySpec; } - private static Expression determineLastArrayExpression(String tableIdentifierVariable, JsonTableArguments arguments) { + private static Expression determineLastArrayExpression(String tableIdentifierVariable, JsonTableArguments arguments, Expression jsonDocument) { final ArrayExpressionEntry arrayExpressionEntry = determineLastArrayExpression( tableIdentifierVariable, - determineJsonElement( tableIdentifierVariable, arguments ), + determineJsonElement( tableIdentifierVariable, arguments, jsonDocument ), arguments.columnsClause(), new ArrayExpressionEntry( 0, null ) ); @@ -253,7 +287,7 @@ private static ArrayExpressionEntry determineLastArrayExpression(String tableIde final ArrayExpressionEntry nextArrayExpression; if ( isArray ) { final int nextArrayIndex = currentArrayEntry.arrayIndex() + 1; - jsonElement = new ArrayAccessExpression( jsonQueryResult, tableIdentifierVariable + "_" + nextArrayIndex + "_.x" ); + jsonElement = new ArrayAccessExpression( jsonQueryResult, ordinalityExpression( tableIdentifierVariable, nextArrayIndex ) ); nextArrayExpression = new ArrayExpressionEntry( nextArrayIndex, jsonQueryResult ); } else { @@ -271,10 +305,9 @@ private static ArrayExpressionEntry determineLastArrayExpression(String tableIde return currentArrayEntry; } - private static Expression determineJsonElement(String tableIdentifierVariable, JsonTableArguments arguments) { + private static Expression determineJsonElement(String tableIdentifierVariable, JsonTableArguments arguments, Expression jsonDocument) { // Applies the json path and array index access to obtain the "current" processing element - final Expression jsonDocument = arguments.jsonDocument(); final boolean isArray; final Expression jsonQueryResult; if ( arguments.jsonPath() != null ) { @@ -309,19 +342,21 @@ private static Expression determineJsonElement(String tableIdentifierVariable, J private static class NestedPathFunctionRenderer implements FunctionRenderer { private final String tableIdentifierVariable; private final JsonTableArguments arguments; + private final Expression jsonDocument; private final int maximumArraySize; private final int lastArrayIndex; - public NestedPathFunctionRenderer(String tableIdentifierVariable, JsonTableArguments arguments, int maximumArraySize, int lastArrayIndex) { + public NestedPathFunctionRenderer(String tableIdentifierVariable, JsonTableArguments arguments, Expression jsonDocument, int maximumArraySize, int lastArrayIndex) { this.tableIdentifierVariable = tableIdentifierVariable; this.arguments = arguments; + this.jsonDocument = jsonDocument; this.maximumArraySize = maximumArraySize; this.lastArrayIndex = lastArrayIndex; } @Override public void render(SqlAppender sqlAppender, List sqlAstArguments, ReturnableType returnType, SqlAstTranslator walker) { - final Expression jsonElement = determineJsonElement( tableIdentifierVariable, arguments ); + final Expression jsonElement = determineJsonElement( tableIdentifierVariable, arguments, jsonDocument ); renderNestedColumnJoins( sqlAppender, tableIdentifierVariable, jsonElement, arguments.columnsClause(), 0, lastArrayIndex, walker ); } @@ -352,17 +387,15 @@ private int renderNestedColumnJoins(SqlAppender sqlAppender, String tableIdentif sqlAppender.appendSql( nextArrayIndex ); sqlAppender.appendSql( '_' ); + final String ordinalityExpression = ordinalityExpression( tableIdentifierVariable, nextArrayIndex ); // The join condition for the last array will be rendered via TableGroupJoin if ( nextArrayIndex != lastArrayIndex ) { sqlAppender.appendSql( " on coalesce(array_length(" ); jsonQueryResult.accept( walker ); sqlAppender.append( "),0)>=" ); - sqlAppender.appendSql( tableIdentifierVariable ); - sqlAppender.appendSql( '_' ); - sqlAppender.appendSql( nextArrayIndex ); - sqlAppender.appendSql( "_.x" ); + sqlAppender.appendSql( ordinalityExpression ); } - jsonElement = new ArrayAccessExpression( jsonQueryResult, tableIdentifierVariable + "_" + nextArrayIndex + "_.x" ); + jsonElement = new ArrayAccessExpression( jsonQueryResult, ordinalityExpression ); } else { jsonElement = jsonQueryResult; @@ -383,6 +416,12 @@ private int renderNestedColumnJoins(SqlAppender sqlAppender, String tableIdentif } } + @Override + public boolean rendersIdentifierVariable(List arguments, SessionFactoryImplementor sessionFactory) { + // To make our lives simpler when supporting non-column JSON document arguments + return true; + } + @Override protected void renderJsonTable( SqlAppender sqlAppender, @@ -397,13 +436,27 @@ protected void renderJsonTable( final Expression jsonPathExpression = arguments.jsonPath(); final boolean isArray = isArrayAccess( jsonPathExpression, walker ); + if ( arguments.jsonDocument().getColumnReference() == null ) { + sqlAppender.append( '(' ); + } if ( isArray ) { sqlAppender.append( "system_range(1," ); sqlAppender.append( Integer.toString( maximumArraySize ) ); - sqlAppender.append( ")" ); + sqlAppender.append( ") " ); } else { - sqlAppender.append( "system_range(1,1)" ); + sqlAppender.append( "system_range(1,1) " ); + } + sqlAppender.append( tableIdentifierVariable ); + if ( arguments.jsonDocument().getColumnReference() == null ) { + sqlAppender.append( " join (values (" ); + arguments.jsonDocument().accept( walker ); + if ( !arguments.isJsonType() ) { + sqlAppender.append( " format json" ); + } + sqlAppender.append( ")) " ); + sqlAppender.append( tableIdentifierVariable ); + sqlAppender.append( "_(d) on 1=1)" ); } } @@ -526,6 +579,13 @@ public JdbcMappingContainer getExpressionType() { } } + private static String ordinalityExpression(String tableIdentifierVariable, int clauseLevel) { + if ( clauseLevel == 0 ) { + return tableIdentifierVariable + ".x"; + } + return tableIdentifierVariable + "_" + clauseLevel + "_.x"; + } + /** * This type resolver essentially implements all the JSON path handling and casting via column read expressions * instead of rendering to the {@code from} clause like other {@code json_table()} implementations. @@ -545,10 +605,15 @@ public SelectableMapping[] resolveFunctionReturnType( boolean withOrdinality, SqmToSqlAstConverter converter) { final JsonTableArguments arguments = JsonTableArguments.extract( sqlAstNodes ); - final ColumnReference columnReference = arguments.jsonDocument().getColumnReference(); - assert columnReference != null; - - final String documentPath = columnReference.getExpressionText(); + final Expression jsonDocument = arguments.jsonDocument(); + final String documentPath; + final ColumnReference columnReference = jsonDocument.getColumnReference(); + if ( columnReference != null ) { + documentPath = columnReference.getExpressionText(); + } + else { + documentPath = tableIdentifierVariable + "_." + "d"; + } final String parentPath; final boolean isArray; @@ -620,7 +685,7 @@ protected int addSelectableMappings(List selectableMappings, final String readExpression; if ( isArray ) { nextClauseLevel = clauseLevel + 1; - readExpression = "array_get(" + parentPath + "," + tableIdentifierVariable + "_" + nextClauseLevel + "_.x)"; + readExpression = "array_get(" + parentPath + "," + ordinalityExpression( tableIdentifierVariable, nextClauseLevel ) + ")"; } else { nextClauseLevel = clauseLevel; @@ -633,7 +698,7 @@ protected void addSelectableMappings(List selectableMappings, addSelectableMapping( selectableMappings, definition.name(), - tableIdentifierVariable + "_" + clauseLevel + "_.x", + ordinalityExpression( tableIdentifierVariable, clauseLevel ), converter.getCreationContext().getTypeConfiguration().getBasicTypeForJavaType( Long.class ) ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonQueryFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonQueryFunction.java index 59237bd48418..1df0d6c26779 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonQueryFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonQueryFunction.java @@ -10,10 +10,10 @@ import org.hibernate.QueryException; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonQueryEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonQueryErrorBehavior; @@ -65,7 +65,7 @@ else if ( wrapMode == JsonQueryWrapMode.WITH_CONDITIONAL_WRAPPER ) { else { sqlAppender.appendSql( "(select t.v from jsonb_path_query(" ); } - final boolean needsCast = !isJsonType && jsonDocument instanceof JdbcParameter; + final boolean needsCast = !isJsonType && AbstractSqlAstTranslator.isParameter( jsonDocument ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonRemoveFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonRemoveFunction.java index 9bb175d5e6c3..b1ce489798e8 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonRemoveFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonRemoveFunction.java @@ -10,10 +10,10 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.type.spi.TypeConfiguration; /** @@ -33,7 +33,7 @@ public void render( SqlAstTranslator translator) { final Expression json = (Expression) arguments.get( 0 ); final Expression jsonPath = (Expression) arguments.get( 1 ); - final boolean needsCast = !isJsonType( json ) && json instanceof JdbcParameter; + final boolean needsCast = !isJsonType( json ) && AbstractSqlAstTranslator.isParameter( json ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonReplaceFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonReplaceFunction.java index 6aba8ff7b82e..651ec2f28eaf 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonReplaceFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonReplaceFunction.java @@ -10,10 +10,10 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.Literal; import org.hibernate.type.spi.TypeConfiguration; @@ -36,7 +36,7 @@ public void render( final Expression jsonPath = (Expression) arguments.get( 1 ); final SqlAstNode value = arguments.get( 2 ); sqlAppender.appendSql( "jsonb_set(" ); - final boolean needsCast = !isJsonType( json ) && json instanceof JdbcParameter; + final boolean needsCast = !isJsonType( json ) && AbstractSqlAstTranslator.isParameter( json ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonSetFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonSetFunction.java index e5f02d9aa3f0..03728540f979 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonSetFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonSetFunction.java @@ -10,10 +10,10 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.Literal; import org.hibernate.type.spi.TypeConfiguration; @@ -36,7 +36,7 @@ public void render( final Expression jsonPath = (Expression) arguments.get( 1 ); final SqlAstNode value = arguments.get( 2 ); sqlAppender.appendSql( "jsonb_set(" ); - final boolean needsCast = !isJsonType( json ) && json instanceof JdbcParameter; + final boolean needsCast = !isJsonType( json ) && AbstractSqlAstTranslator.isParameter( json ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonTableFunction.java index 95555792a13c..4a7d5b102f32 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonTableFunction.java @@ -9,10 +9,10 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.derived.AnonymousTupleTableGroupProducer; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonExistsErrorBehavior; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonQueryEmptyBehavior; @@ -59,7 +59,7 @@ protected void renderJsonTable( sqlAppender.appendSql( " from jsonb_path_query(" ); - final boolean needsCast = !arguments.isJsonType() && arguments.jsonDocument() instanceof JdbcParameter; + final boolean needsCast = !arguments.isJsonType() && AbstractSqlAstTranslator.isParameter( arguments.jsonDocument() ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonValueFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonValueFunction.java index 592f43daa06d..bdedd46ef503 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonValueFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/PostgreSQLJsonValueFunction.java @@ -10,11 +10,11 @@ import org.hibernate.QueryException; import org.hibernate.query.ReturnableType; import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.tree.SqlAstNode; import org.hibernate.sql.ast.tree.expression.CastTarget; import org.hibernate.sql.ast.tree.expression.Expression; -import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.expression.JsonPathPassingClause; import org.hibernate.sql.ast.tree.expression.JsonValueEmptyBehavior; import org.hibernate.sql.ast.tree.expression.JsonValueErrorBehavior; @@ -60,7 +60,7 @@ static void appendJsonValue(SqlAppender sqlAppender, Expression jsonDocument, Sq sqlAppender.appendSql( "cast(" ); } sqlAppender.appendSql( "jsonb_path_query_first(" ); - final boolean needsCast = !isJsonType && jsonDocument instanceof JdbcParameter; + final boolean needsCast = !isJsonType && AbstractSqlAstTranslator.isParameter( jsonDocument ); if ( needsCast ) { sqlAppender.appendSql( "cast(" ); } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java index afd06c189e60..55f4c12957b3 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/SQLServerJsonTableFunction.java @@ -43,22 +43,31 @@ protected void renderJsonTable(SqlAppender sqlAppender, JsonTableArguments argum arguments.jsonDocument().accept( walker ); if ( arguments.jsonPath() != null ) { sqlAppender.appendSql( ',' ); - // Default behavior is NULL ON ERROR - final String prefix = arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ? "strict " : ""; - final String jsonPathString; + final String rawJsonPath; if ( arguments.passingClause() != null ) { - jsonPathString = prefix + JsonPathHelper.inlinedJsonPathIncludingPassingClause( arguments.jsonPath(), - arguments.passingClause(), walker ); + rawJsonPath = JsonPathHelper.inlinedJsonPathIncludingPassingClause( + arguments.jsonPath(), + arguments.passingClause(), + walker + ); } else { - jsonPathString = prefix + walker.getLiteralValue( arguments.jsonPath() ); + rawJsonPath = walker.getLiteralValue( arguments.jsonPath() ); } - if ( jsonPathString.endsWith( "[*]" ) ) { - sqlAppender.appendSingleQuoteEscapedString( jsonPathString.substring( 0, jsonPathString.length() - 3 ) ); + final String jsonPath; + if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) { + // Default behavior is NULL ON ERROR + jsonPath = "strict " + rawJsonPath; } else { - sqlAppender.appendSingleQuoteEscapedString( jsonPathString ); + jsonPath = rawJsonPath; } + sqlAppender.appendSingleQuoteEscapedString( + // openjson unwraps arrays automatically and doesn't support this syntax, so remove it + jsonPath.endsWith( "[*]" ) + ? jsonPath.substring( 0, jsonPath.length() - 3 ) + : jsonPath + ); } else if ( arguments.errorBehavior() == JsonTableErrorBehavior.ERROR ) { // Default behavior is NULL ON ERROR