Skip to content

Commit

Permalink
HHH-17355 Smoothen some rough edges with parameter typing and PG12 su…
Browse files Browse the repository at this point in the history
…pport
  • Loading branch information
beikov committed Nov 6, 2023
1 parent d7bdb5c commit c700dcd
Show file tree
Hide file tree
Showing 41 changed files with 374 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionFactory.arrayTrim_trim_array();
functionFactory.arrayFill_postgresql();
functionFactory.arrayTrim_unnest();
functionFactory.arrayFill_cockroachdb();
functionFactory.arrayToString_postgresql();

functionContributions.getFunctionRegistry().register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,12 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionFactory.arrayTrim_trim_array();
if ( getVersion().isSameOrAfter( 14 ) ) {
functionFactory.arrayTrim_trim_array();
}
else {
functionFactory.arrayTrim_unnest();
}
functionFactory.arrayFill_postgresql();
functionFactory.arrayToString_postgresql();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionFactory.arrayTrim_trim_array();
functionFactory.arrayFill_postgresql();
functionFactory.arrayTrim_unnest();
functionFactory.arrayFill_cockroachdb();
functionFactory.arrayToString_postgresql();

functionContributions.getFunctionRegistry().register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static String getSqlType(CastTarget castTarget, Dialect dialect) {
}

public static String getSqlType(CastTarget castTarget, SessionFactoryImplementor factory) {
final String sqlType = getCastTypeName( castTarget, factory );
final String sqlType = getCastTypeName( castTarget, factory.getTypeConfiguration() );
return getSqlType( castTarget, sqlType, factory.getJdbcServices().getDialect() );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,12 @@ public void initializeFunctionRegistry(FunctionContributions functionContributio
functionFactory.arrayRemoveIndex_unnest( true );
functionFactory.arraySlice_operator();
functionFactory.arrayReplace();
functionFactory.arrayTrim_trim_array();
if ( getVersion().isSameOrAfter( 14 ) ) {
functionFactory.arrayTrim_trim_array();
}
else {
functionFactory.arrayTrim_unnest();
}
functionFactory.arrayFill_postgresql();
functionFactory.arrayToString_postgresql();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.hibernate.dialect.function.array.ArraySliceUnnestFunction;
import org.hibernate.dialect.function.array.ArrayToStringFunction;
import org.hibernate.dialect.function.array.ArrayViaArgumentReturnTypeResolver;
import org.hibernate.dialect.function.array.CockroachArrayFillFunction;
import org.hibernate.dialect.function.array.ElementViaArrayArgumentReturnTypeResolver;
import org.hibernate.dialect.function.array.H2ArrayContainsFunction;
import org.hibernate.dialect.function.array.H2ArrayFillFunction;
Expand Down Expand Up @@ -72,6 +73,7 @@
import org.hibernate.dialect.function.array.OracleArrayConstructorFunction;
import org.hibernate.dialect.function.array.OracleArrayContainsFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayPositionsFunction;
import org.hibernate.dialect.function.array.PostgreSQLArrayTrimEmulation;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
Expand Down Expand Up @@ -3166,6 +3168,13 @@ public void arrayTrim_trim_array() {
.register();
}

/**
* PostgreSQL array_trim() emulation for versions before 14
*/
public void arrayTrim_unnest() {
functionRegistry.register( "array_trim", new PostgreSQLArrayTrimEmulation() );
}

/**
* Oracle array_trim() function
*/
Expand Down Expand Up @@ -3197,6 +3206,14 @@ public void arrayFill_postgresql() {
functionRegistry.register( "array_fill_list", new PostgreSQLArrayFillFunction( true ) );
}

/**
* Cockroach array_fill() function
*/
public void arrayFill_cockroachdb() {
functionRegistry.register( "array_fill", new CockroachArrayFillFunction( false ) );
functionRegistry.register( "array_fill_list", new CockroachArrayFillFunction( true ) );
}

/**
* Oracle array_fill() function
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;

import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;

import static org.hibernate.query.sqm.produce.function.FunctionParameterType.ANY;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;

/**
* Encapsulates the validator, return type and argument type resolvers for the array_remove functions.
* Subclasses only have to implement the rendering.
*/
public abstract class AbstractArrayTrimFunction extends AbstractSqmSelfRenderingFunctionDescriptor {

public AbstractArrayTrimFunction() {
super(
"array_trim",
StandardArgumentsValidators.composite(
new ArgumentTypesValidator( null, ANY, INTEGER ),
ArrayArgumentValidator.DEFAULT_INSTANCE
),
ArrayViaArgumentReturnTypeResolver.DEFAULT_INSTANCE,
StandardFunctionArgumentTypeResolvers.composite(
StandardFunctionArgumentTypeResolvers.invariant( ANY, INTEGER ),
StandardFunctionArgumentTypeResolvers.IMPLIED_RESULT_TYPE
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void render(
sqlAppender.append( "] as " );
sqlAppender.append( DdlTypeHelper.getCastTypeName(
haystackExpression.getExpressionType(),
walker
walker.getSessionFactory().getTypeConfiguration()
) );
sqlAppender.append( ')' );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ public void render(
sqlAppender.append( "),");
if ( castEmptyArrayLiteral ) {
sqlAppender.append( "cast(array[] as " );
sqlAppender.append( DdlTypeHelper.getCastTypeName( returnType, walker ) );
sqlAppender.append( DdlTypeHelper.getCastTypeName(
returnType,
walker.getSessionFactory().getTypeConfiguration()
) );
sqlAppender.append( ')' );
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ public void render(
sqlAppender.append( "),");
if ( castEmptyArrayLiteral ) {
sqlAppender.append( "cast(array[] as " );
sqlAppender.append( DdlTypeHelper.getCastTypeName( returnType, walker ) );
sqlAppender.append( DdlTypeHelper.getCastTypeName(
returnType,
walker.getSessionFactory().getTypeConfiguration()
) );
sqlAppender.append( ')' );
}
else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function.array;

import java.util.List;

import org.hibernate.query.ReturnableType;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.expression.Literal;

/**
* Implement the array fill function by using {@code generate_series}.
*/
public class CockroachArrayFillFunction extends AbstractArrayFillFunction {

public CockroachArrayFillFunction(boolean list) {
super( list );
}

@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> sqlAstArguments,
ReturnableType<?> returnType,
SqlAstTranslator<?> walker) {
sqlAppender.append( "coalesce(case when " );
sqlAstArguments.get( 1 ).accept( walker );
sqlAppender.append( "<>0 then (select array_agg(" );
final String elementCastType;
final Expression elementExpression = (Expression) sqlAstArguments.get( 0 );
if ( needsElementCasting( elementExpression ) ) {
elementCastType = DdlTypeHelper.getCastTypeName(
elementExpression.getExpressionType(),
walker.getSessionFactory().getTypeConfiguration()
);
sqlAppender.append( "cast(" );
}
else {
elementCastType = null;
}
sqlAstArguments.get( 0 ).accept( walker );
if ( elementCastType != null ) {
sqlAppender.append( " as " );
sqlAppender.append( elementCastType );
sqlAppender.append( ')' );
}
sqlAppender.append( ") from generate_series(1," );
sqlAstArguments.get( 1 ).accept( walker );
sqlAppender.append( ",1))) end,array[])" );
}

private static boolean needsElementCasting(Expression elementExpression) {
// PostgreSQL needs casting of null and string literal expressions
return elementExpression instanceof Literal && (
elementExpression.getExpressionType().getSingleJdbcMapping().getJdbcType().isString()
|| ( (Literal) elementExpression ).getLiteralValue() == null
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.hibernate.metamodel.mapping.SqlTypedMapping;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.ReturnableType;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator;
import org.hibernate.type.BasicType;
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
Expand Down Expand Up @@ -60,71 +59,91 @@ public static BasicType<?> resolveListType(DomainType<?> elementType, TypeConfig
);
}

public static String getTypeName(BasicType<?> type, SqlAstTranslator<?> walker) {
return getTypeName( (JdbcMappingContainer) type, walker );
public static String getTypeName(BasicType<?> type, TypeConfiguration typeConfiguration) {
return getTypeName( (JdbcMappingContainer) type, typeConfiguration );
}

public static String getTypeName(JdbcMappingContainer type, SqlAstTranslator<?> walker) {
public static String getTypeName(BasicType<?> type, Size size, TypeConfiguration typeConfiguration) {
return getTypeName( (JdbcMappingContainer) type, size, typeConfiguration );
}

public static String getTypeName(JdbcMappingContainer type, TypeConfiguration typeConfiguration) {
return getTypeName( type, Size.nil(), typeConfiguration );
}

public static String getTypeName(JdbcMappingContainer type, Size size, TypeConfiguration typeConfiguration) {
if ( type instanceof SqlTypedMapping ) {
return AbstractSqlAstTranslator.getSqlTypeName( (SqlTypedMapping) type, walker.getSessionFactory() );
return AbstractSqlAstTranslator.getSqlTypeName( (SqlTypedMapping) type, typeConfiguration );
}
else {
final BasicType<?> basicType = (BasicType<?>) type.getSingleJdbcMapping();
final TypeConfiguration typeConfiguration = walker.getSessionFactory().getTypeConfiguration();
final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry();
final DdlType ddlType = ddlTypeRegistry.getDescriptor(
basicType.getJdbcType().getDdlTypeCode()
);
return ddlType.getTypeName( Size.nil(), basicType, ddlTypeRegistry );
return ddlType.getTypeName( size, basicType, ddlTypeRegistry );
}
}

public static String getTypeName(ReturnableType<?> type, SqlAstTranslator<?> walker) {
public static String getTypeName(ReturnableType<?> type, TypeConfiguration typeConfiguration) {
return getTypeName( type, Size.nil(), typeConfiguration );
}

public static String getTypeName(ReturnableType<?> type, Size size, TypeConfiguration typeConfiguration) {
if ( type instanceof SqlTypedMapping ) {
return AbstractSqlAstTranslator.getSqlTypeName( (SqlTypedMapping) type, walker.getSessionFactory() );
return AbstractSqlAstTranslator.getSqlTypeName( (SqlTypedMapping) type, typeConfiguration );
}
else {
final BasicType<?> basicType = (BasicType<?>) ( (JdbcMappingContainer) type ).getSingleJdbcMapping();
final TypeConfiguration typeConfiguration = walker.getSessionFactory().getTypeConfiguration();
final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry();
final DdlType ddlType = ddlTypeRegistry.getDescriptor(
basicType.getJdbcType().getDdlTypeCode()
);
return ddlType.getTypeName( Size.nil(), basicType, ddlTypeRegistry );
return ddlType.getTypeName( size, basicType, ddlTypeRegistry );
}
}

public static String getCastTypeName(BasicType<?> type, SqlAstTranslator<?> walker) {
return getCastTypeName( (JdbcMappingContainer) type, walker );
public static String getCastTypeName(BasicType<?> type, TypeConfiguration typeConfiguration) {
return getCastTypeName( (JdbcMappingContainer) type, typeConfiguration );
}

public static String getCastTypeName(JdbcMappingContainer type, SqlAstTranslator<?> walker) {
public static String getCastTypeName(BasicType<?> type, Size size, TypeConfiguration typeConfiguration) {
return getCastTypeName( (JdbcMappingContainer) type, size, typeConfiguration );
}

public static String getCastTypeName(JdbcMappingContainer type, TypeConfiguration typeConfiguration) {
return getCastTypeName( type, Size.nil(), typeConfiguration );
}

public static String getCastTypeName(JdbcMappingContainer type, Size size, TypeConfiguration typeConfiguration) {
if ( type instanceof SqlTypedMapping ) {
return AbstractSqlAstTranslator.getCastTypeName( (SqlTypedMapping) type, walker.getSessionFactory() );
return AbstractSqlAstTranslator.getCastTypeName( (SqlTypedMapping) type, typeConfiguration );
}
else {
final BasicType<?> basicType = (BasicType<?>) type.getSingleJdbcMapping();
final TypeConfiguration typeConfiguration = walker.getSessionFactory().getTypeConfiguration();
final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry();
final DdlType ddlType = ddlTypeRegistry.getDescriptor(
basicType.getJdbcType().getDdlTypeCode()
);
return ddlType.getCastTypeName( Size.nil(), basicType, ddlTypeRegistry );
return ddlType.getCastTypeName( size, basicType, ddlTypeRegistry );
}
}

public static String getCastTypeName(ReturnableType<?> type, SqlAstTranslator<?> walker) {
public static String getCastTypeName(ReturnableType<?> type, TypeConfiguration typeConfiguration) {
return getCastTypeName( type, Size.nil(), typeConfiguration );
}

public static String getCastTypeName(ReturnableType<?> type, Size size, TypeConfiguration typeConfiguration) {
if ( type instanceof SqlTypedMapping ) {
return AbstractSqlAstTranslator.getCastTypeName( (SqlTypedMapping) type, walker.getSessionFactory() );
return AbstractSqlAstTranslator.getCastTypeName( (SqlTypedMapping) type, typeConfiguration );
}
else {
final BasicType<?> basicType = (BasicType<?>) ( (JdbcMappingContainer) type ).getSingleJdbcMapping();
final TypeConfiguration typeConfiguration = walker.getSessionFactory().getTypeConfiguration();
final DdlTypeRegistry ddlTypeRegistry = typeConfiguration.getDdlTypeRegistry();
final DdlType ddlType = ddlTypeRegistry.getDescriptor(
basicType.getJdbcType().getDdlTypeCode()
);
return ddlType.getCastTypeName( Size.nil(), basicType, ddlTypeRegistry );
return ddlType.getCastTypeName( size, basicType, ddlTypeRegistry );
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ public void render(
SqlAstTranslator<?> walker) {
final String castTypeName;
if ( returnType != null && hasOnlyBottomArguments( arguments ) ) {
castTypeName = DdlTypeHelper.getCastTypeName( returnType, walker );
castTypeName = DdlTypeHelper.getCastTypeName(
returnType,
walker.getSessionFactory().getTypeConfiguration()
);
sqlAppender.append( "cast(" );
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void render(
final String arrayTypeName = DdlTypeHelper.getTypeName(
prepend ? secondArgument.getExpressionType()
: firstArgument.getExpressionType(),
walker
walker.getSessionFactory().getTypeConfiguration()
);
sqlAppender.append( arrayTypeName );
sqlAppender.append( "_concat(" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ public void render(
List<? extends SqlAstNode> sqlAstArguments,
ReturnableType<?> returnType,
SqlAstTranslator<?> walker) {
final String arrayTypeName = DdlTypeHelper.getTypeName( (JdbcMappingContainer) returnType, walker );
final String arrayTypeName = DdlTypeHelper.getTypeName(
(JdbcMappingContainer) returnType,
walker.getSessionFactory().getTypeConfiguration()
);
sqlAppender.append( arrayTypeName );
sqlAppender.append( "_concat" );
super.render( sqlAppender, sqlAstArguments, returnType, walker );
Expand Down

0 comments on commit c700dcd

Please sign in to comment.