HHH-17357 Add hibernate-types module with pgvector support
beikov committed Nov 7, 2023
1 parent 0291006 commit eebb305
Showing 22 changed files with 1,178 additions and 89 deletions.
4 changes: 4 additions & 0 deletions docker_db.sh
@@ -147,22 +147,26 @@ postgresql() {
postgresql_12() {
$CONTAINER_CLI rm -f postgres || true
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:12-3.4
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-12-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}

postgresql_13() {
$CONTAINER_CLI rm -f postgres || true
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:13-3.1
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-13-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}

postgresql_14() {
$CONTAINER_CLI rm -f postgres || true
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:14-3.3
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-14-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}

postgresql_15() {
$CONTAINER_CLI rm -f postgres || true
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 --tmpfs /pgtmpfs:size=131072k -d docker.io/postgis/postgis:15-3.3 \
-c fsync=off -c synchronous_commit=off -c full_page_writes=off -c shared_buffers=256MB -c maintenance_work_mem=256MB -c max_wal_size=1GB -c checkpoint_timeout=1d
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-15-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}

edb() {
@@ -34,6 +34,7 @@ include::chapters/query/hql/QueryLanguage.adoc[]
include::chapters/query/criteria/Criteria.adoc[]
include::chapters/query/native/Native.adoc[]
include::chapters/query/spatial/Spatial.adoc[]
include::chapters/query/types/TypesModule.adoc[]
include::chapters/multitenancy/MultiTenancy.adoc[]
include::chapters/envers/Envers.adoc[]
include::chapters/beans/Beans.adoc[]
@@ -0,0 +1,179 @@
[[types-module]]
== Hibernate Types module
:root-project-dir: ../../../../../../../..
:types-project-dir: {root-project-dir}/hibernate-types
:example-dir-types: {types-project-dir}/src/test/java/org/hibernate/types
:extrasdir: extras

[[types-module-overview]]
=== Overview

The Hibernate ORM core module tries to be as minimal as possible and only model functionality
that is somewhat "standard" in the SQL space or can only be modeled as part of the core module.
To avoid growing that module further unnecessarily, support for certain special SQL types or functions
is separated out into the Hibernate ORM types module.

[[types-module-setup]]
=== Setup

You need to include the `hibernate-types` dependency in your build environment.
For Maven, you need to add the following dependency:

[[types-module-setup-maven-example]]
.Maven dependency
====
[source,xml]
----
<dependency>
    <groupId>org.hibernate.orm</groupId>
    <artifactId>hibernate-types</artifactId>
    <version>${hibernate.version}</version>
</dependency>
----
====

The module contains service implementations that are picked up by the Java `ServiceLoader` automatically,
so no further configuration is necessary to make the features available.

[[types-module-vector]]
=== Vector type support

The Hibernate ORM types module comes with support for a special `vector` data type that essentially represents an array of floats.

So far, only the PostgreSQL extension `pgvector` is supported, but in theory,
the vector specific functions could be implemented to work with every database that supports arrays.

For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation].

[[types-module-vector-usage]]
==== Usage

Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`.

[[types-module-vector-usage-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=usage-example]
----
====

To cast the string representation of a vector to the vector data type, use an HQL cast, e.g. `cast('[1,2,3]' as vector)`.

[[types-module-vector-functions]]
==== Functions

Expressions of the vector type can be used with various vector functions.

[[types-module-vector-functions-overview]]
|===
| Function | Purpose

| `cosine_distance()` | Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors. Maps to the `<``=``>` operator
| `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator
| `l2_distance()` | Alias for `euclidean_distance()`
| `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors
| `l1_distance()` | Alias for `taxicab_distance()`
| `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors
| `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator
| `vector_dims()` | Determines the number of dimensions of a vector
| `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector
|===

In addition to these special vector functions, vectors can also be used with the following built-in operators and aggregate functions:

`<vector1> + <vector2> = <vector3>`:: Element-wise addition of vectors.
`<vector1> - <vector2> = <vector3>`:: Element-wise subtraction of vectors.
`<vector1> * <vector2> = <vector3>`:: Element-wise multiplication of vectors.
`sum(<vector1>) = <vector2>`:: Aggregate function support for element-wise summation of vectors.
`avg(<vector1>) = <vector2>`:: Aggregate function support for element-wise average of vectors.
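
To make "element-wise" concrete, the following is a minimal plain-Java sketch of what addition and averaging of vectors compute. It is an illustration only: in HQL these operations are evaluated by the database, and the class and method names here (`VectorOperatorsSketch`, `add`, `avg`) are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration only; not part of the hibernate-types module.
final class VectorOperatorsSketch {

    // Element-wise addition: result[i] = a[i] + b[i]
    static float[] add(float[] a, float[] b) {
        if ( a.length != b.length ) {
            throw new IllegalArgumentException( "Vectors must have the same length" );
        }
        float[] result = new float[a.length];
        for ( int i = 0; i < a.length; i++ ) {
            result[i] = a[i] + b[i];
        }
        return result;
    }

    // Element-wise average over a non-empty group of equally sized vectors
    static float[] avg(List<float[]> vectors) {
        if ( vectors.isEmpty() ) {
            throw new IllegalArgumentException( "At least one vector is required" );
        }
        float[] sum = new float[vectors.get( 0 ).length];
        for ( float[] v : vectors ) {
            sum = add( sum, v );
        }
        for ( int i = 0; i < sum.length; i++ ) {
            sum[i] /= vectors.size();
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] a = { 1f, 2f, 3f };
        float[] b = { 4f, 5f, 6f };
        System.out.println( Arrays.toString( add( a, b ) ) );                  // [5.0, 7.0, 9.0]
        System.out.println( Arrays.toString( avg( Arrays.asList( a, b ) ) ) ); // [2.5, 3.5, 4.5]
    }
}
```

Subtraction and multiplication follow the same pattern with `-` and `*` in place of `+`.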

[[types-module-vector-functions-cosine-distance]]
===== `cosine_distance()`

Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors,
which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) )`. Maps to the `<``=``>` pgvector operator.

[[types-module-vector-functions-cosine-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=cosine-distance-example]
----
====
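
The cosine distance formula can also be sketched in plain Java to show what the function computes. This is an illustration under the formula stated here, not the pgvector or Hibernate implementation; the class and method names are hypothetical.

```java
// Plain-Java illustration of the formula; not the pgvector or Hibernate implementation.
final class CosineDistanceSketch {

    // sum( a_i * b_i )
    static double innerProduct(float[] a, float[] b) {
        if ( a.length != b.length ) {
            throw new IllegalArgumentException( "Vectors must have the same length" );
        }
        double sum = 0.0;
        for ( int i = 0; i < a.length; i++ ) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    // sqrt( sum( v_i^2 ) )
    static double norm(float[] v) {
        return Math.sqrt( innerProduct( v, v ) );
    }

    // 1 - inner_product( a, b ) / ( vector_norm( a ) * vector_norm( b ) )
    static double cosineDistance(float[] a, float[] b) {
        return 1.0 - innerProduct( a, b ) / ( norm( a ) * norm( b ) );
    }

    public static void main(String[] args) {
        // Orthogonal vectors: the inner product is 0, so the distance is 1
        System.out.println( cosineDistance( new float[] { 1f, 0f }, new float[] { 0f, 1f } ) );
        // Parallel vectors: the distance is approximately 0 (up to rounding)
        System.out.println( cosineDistance( new float[] { 1f, 2f, 3f }, new float[] { 2f, 4f, 6f } ) );
    }
}
```

The result depends only on the angle between the vectors, not on their lengths, which is why the two parallel vectors above have a distance of roughly zero.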

[[types-module-vector-functions-euclidean-distance]]
===== `euclidean_distance()` and `l2_distance()`

Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors,
which is `sqrt( sum( (v1_i - v2_i)^2 ) )`. Maps to the `<``-``>` pgvector operator.
The `l2_distance()` function is an alias.

[[types-module-vector-functions-euclidean-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=euclidean-distance-example]
----
====
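
As a worked illustration of the formula, a minimal plain-Java sketch follows. It only mirrors the arithmetic; the class and method names are hypothetical and it is not how the database evaluates the operator.

```java
// Plain-Java illustration of the formula; not the pgvector or Hibernate implementation.
final class EuclideanDistanceSketch {

    // sqrt( sum( (a_i - b_i)^2 ) )
    static double euclideanDistance(float[] a, float[] b) {
        if ( a.length != b.length ) {
            throw new IllegalArgumentException( "Vectors must have the same length" );
        }
        double sum = 0.0;
        for ( int i = 0; i < a.length; i++ ) {
            double diff = (double) a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt( sum );
    }

    public static void main(String[] args) {
        // Differences are 3, 4 and 0, so the distance is sqrt( 9 + 16 + 0 )
        System.out.println( euclideanDistance( new float[] { 1f, 2f, 3f }, new float[] { 4f, 6f, 3f } ) ); // 5.0
    }
}
```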

[[types-module-vector-functions-taxicab-distance]]
===== `taxicab_distance()` and `l1_distance()`

Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors,
which is `sum( abs( v1_i - v2_i ) )`.
The `l1_distance()` function is an alias.

[[types-module-vector-functions-taxicab-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=taxicab-distance-example]
----
====
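
A minimal plain-Java sketch of the taxicab (L1) distance, computed as the sum of the absolute element differences as in the linked definition. The class and method names are hypothetical, and the code is an illustration rather than the database implementation.

```java
// Plain-Java illustration of the L1 distance; not the pgvector or Hibernate implementation.
final class TaxicabDistanceSketch {

    // sum( abs( a_i - b_i ) )
    static double taxicabDistance(float[] a, float[] b) {
        if ( a.length != b.length ) {
            throw new IllegalArgumentException( "Vectors must have the same length" );
        }
        double sum = 0.0;
        for ( int i = 0; i < a.length; i++ ) {
            sum += Math.abs( (double) a[i] - b[i] );
        }
        return sum;
    }

    public static void main(String[] args) {
        // Absolute differences are 3, 4 and 0
        System.out.println( taxicabDistance( new float[] { 1f, 2f, 3f }, new float[] { 4f, 6f, 3f } ) ); // 7.0
    }
}
```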

[[types-module-vector-functions-inner-product]]
===== `inner_product()` and `negative_inner_product()`

Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors,
which is `sum( v1_i * v2_i )`. The `negative_inner_product()` function maps to the `<``#``>` pgvector operator.
The `inner_product()` function uses the same operator, but multiplies the result by `-1`.

[[types-module-vector-functions-inner-product-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=inner-product-example]
----
====
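
The relationship between the two functions can be sketched in plain Java. This is an illustration of the arithmetic only; the class and method names are hypothetical.

```java
// Plain-Java illustration; not the pgvector or Hibernate implementation.
final class InnerProductSketch {

    // sum( a_i * b_i )
    static double innerProduct(float[] a, float[] b) {
        if ( a.length != b.length ) {
            throw new IllegalArgumentException( "Vectors must have the same length" );
        }
        double sum = 0.0;
        for ( int i = 0; i < a.length; i++ ) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    // The value that the pgvector operator yields according to the text above
    static double negativeInnerProduct(float[] a, float[] b) {
        return -innerProduct( a, b );
    }

    public static void main(String[] args) {
        float[] a = { 1f, 2f, 3f };
        float[] b = { 4f, 5f, 6f };
        System.out.println( innerProduct( a, b ) );         // 32.0
        System.out.println( negativeInnerProduct( a, b ) ); // -32.0
    }
}
```

With the negated value, sorting in ascending order puts the vectors with the largest inner product first.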

[[types-module-vector-functions-vector-dims]]
===== `vector_dims()`

Determines the number of dimensions of a vector.

[[types-module-vector-functions-vector-dims-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-dims-example]
----
====

[[types-module-vector-functions-vector-norm]]
===== `vector_norm()`

Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector,
which is `sqrt( sum( v_i^2 ) )`.

[[types-module-vector-functions-vector-norm-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-norm-example]
----
====




@@ -8,16 +8,23 @@

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.function.Supplier;

import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.sql.ast.Clause;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
@@ -27,8 +34,14 @@
import org.hibernate.sql.ast.tree.expression.Distinct;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.BasicType;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
import org.hibernate.type.spi.TypeConfiguration;

import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
@@ -49,10 +62,8 @@ public AvgFunction(
 		super(
 				"avg",
 				FunctionKind.AGGREGATE,
-				new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
-				StandardFunctionReturnTypeResolvers.invariant(
-						typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
-				),
+				new Validator(),
+				new ReturnTypeResolver( typeConfiguration ),
 				StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, NUMERIC )
 		);
 		this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
@@ -131,4 +142,116 @@ public String getArgumentListSignature() {
return "(NUMERIC arg)";
}

	public static class Validator implements ArgumentsValidator {

		public static final ArgumentsValidator INSTANCE = new Validator();

		@Override
		public void validate(
				List<? extends SqmTypedNode<?>> arguments,
				String functionName,
				TypeConfiguration typeConfiguration) {
			if ( arguments.size() != 1 ) {
				throw new FunctionArgumentException(
						String.format(
								Locale.ROOT,
								"Function %s() has %d parameters, but %d arguments given",
								functionName,
								1,
								arguments.size()
						)
				);
			}
			final SqmTypedNode<?> argument = arguments.get( 0 );
			final SqmExpressible<?> expressible = argument.getExpressible();
			final DomainType<?> domainType;
			if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
				final JdbcType jdbcType = getJdbcType( domainType, typeConfiguration );
				if ( !isNumeric( jdbcType ) ) {
					throw new FunctionArgumentException(
							String.format(
									"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
									1,
									functionName,
									NUMERIC,
									domainType.getTypeName()
							)
					);
				}
			}
		}

		private static boolean isNumeric(JdbcType jdbcType) {
			final int sqlTypeCode = jdbcType.getDefaultSqlTypeCode();
			if ( SqlTypes.isNumericType( sqlTypeCode ) ) {
				return true;
			}
			if ( jdbcType instanceof ArrayJdbcType ) {
				return isNumeric( ( (ArrayJdbcType) jdbcType ).getElementJdbcType() );
			}
			return false;
		}

		private static JdbcType getJdbcType(DomainType<?> domainType, TypeConfiguration typeConfiguration) {
			if ( domainType instanceof JdbcMapping ) {
				return ( (JdbcMapping) domainType ).getJdbcType();
			}
			else {
				final JavaType<?> javaType = domainType.getExpressibleJavaType();
				if ( javaType.getJavaTypeClass().isEnum() ) {
					// we can't tell if the enum is mapped STRING or ORDINAL
					return ObjectJdbcType.INSTANCE;
				}
				else {
					return javaType.getRecommendedJdbcType( typeConfiguration.getCurrentBaseSqlTypeIndicators() );
				}
			}
		}

		@Override
		public String getSignature() {
			return "(arg)";
		}
	}

	public static class ReturnTypeResolver implements FunctionReturnTypeResolver {

		private final BasicType<Double> doubleType;

		public ReturnTypeResolver(TypeConfiguration typeConfiguration) {
			this.doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
		}

		@Override
		public BasicValuedMapping resolveFunctionReturnType(
				Supplier<BasicValuedMapping> impliedTypeAccess,
				List<? extends SqlAstNode> arguments) {
			final BasicValuedMapping impliedType = impliedTypeAccess.get();
			if ( impliedType != null ) {
				return impliedType;
			}
			final JdbcMapping jdbcMapping = ( (Expression) arguments.get( 0 ) ).getExpressionType().getSingleJdbcMapping();
			if ( jdbcMapping instanceof BasicPluralType<?, ?> ) {
				return (BasicValuedMapping) jdbcMapping;
			}
			return doubleType;
		}

		@Override
		public ReturnableType<?> resolveFunctionReturnType(
				ReturnableType<?> impliedType,
				Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
				List<? extends SqmTypedNode<?>> arguments,
				TypeConfiguration typeConfiguration) {
			final SqmExpressible<?> expressible = arguments.get( 0 ).getExpressible();
			final DomainType<?> domainType;
			if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
				if ( domainType instanceof BasicPluralType<?, ?> ) {
					return (ReturnableType<?>) domainType;
				}
			}
			return typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
		}
	}

}
