Skip to content

Commit

Permalink
JAVA-3061 Re-introduce an improved CqlVector, add support for accessi…
Browse files Browse the repository at this point in the history
…ng vectors directly as float arrays (#1666)
  • Loading branch information
absurdfarce committed Jul 8, 2023
1 parent fc79bb7 commit e2fb42d
Show file tree
Hide file tree
Showing 18 changed files with 901 additions and 58 deletions.
65 changes: 64 additions & 1 deletion core/revapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -6887,7 +6887,70 @@
"code": "java.method.removed",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactoring in JAVA-3061"
}
},
{
"code": "java.class.removed",
"old": "class com.datastax.oss.driver.api.core.data.CqlVector.Builder<T>",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.removed",
"old": "method com.datastax.oss.driver.api.core.data.CqlVector.Builder com.datastax.oss.driver.api.core.data.CqlVector<T>::builder()",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.removed",
"old": "method java.lang.Iterable<T> com.datastax.oss.driver.api.core.data.CqlVector<T>::getValues()",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "class com.datastax.oss.driver.api.core.data.CqlVector<T>",
"new": "class com.datastax.oss.driver.api.core.data.CqlVector<T extends java.lang.Number>",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeChanged",
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.CqlVectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.VectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeParameterChanged",
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.returnTypeTypeParametersChanged",
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeParameterChanged",
"old": "parameter <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
"new": "parameter <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.returnTypeTypeParametersChanged",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactorings in PR 1666"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.data;

import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import com.datastax.oss.driver.shaded.guava.common.base.Predicates;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
import com.datastax.oss.driver.shaded.guava.common.collect.Streams;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Representation of a vector as defined in CQL.
 *
 * <p>A CQL vector is a fixed-length array of non-null numeric values. These properties don't map
 * cleanly to an existing class in the standard JDK Collections hierarchy so we provide this value
 * object instead. Like other value object collections returned by the driver, instances of this
 * class are not immutable; think of these value objects as a representation of a vector stored in
 * the database as an initial step in some additional computation.
 *
 * <p>While we don't implement any Collection APIs we do implement Iterable. We also attempt to play
 * nice with the Streams API in order to better facilitate integration with data pipelines. Finally,
 * where possible we've tried to make the API of this class similar to the equivalent methods on
 * {@link List}.
 *
 * @param <T> the (numeric) type of the elements held by this vector
 */
public class CqlVector<T extends Number> implements Iterable<T> {

  /**
   * Create a new CqlVector containing the specified values.
   *
   * <p>The values are wrapped with {@link Arrays#asList(Object[])}, so the returned vector is a
   * fixed-size view over the supplied array rather than a copy of it.
   *
   * @param vals the values to wrap.
   * @return a CqlVector wrapping those values
   * @throws IllegalArgumentException if any of the values is null
   */
  // Deliberately not annotated with @SafeVarargs: the varargs array is retained as the backing
  // storage of the returned vector and can be written to through set(int, Number).
  public static <V extends Number> CqlVector<V> newInstance(V... vals) {

    // Note that Arrays.asList() guarantees the return of a list which implements RandomAccess
    return new CqlVector<>(Arrays.asList(vals));
  }

  /**
   * Create a new CqlVector that "wraps" an existing List. Modifications to the passed List will
   * also be reflected in the returned CqlVector.
   *
   * @param list the collection of values to wrap.
   * @return a CqlVector wrapping those values
   * @throws IllegalArgumentException if the list is null or contains a null value
   */
  public static <V extends Number> CqlVector<V> newInstance(List<V> list) {
    Preconditions.checkArgument(list != null, "Input list should not be null");
    return new CqlVector<>(list);
  }

  /**
   * Create a new CqlVector instance from the specified string representation. Note that this method
   * is intended to mirror {@link #toString()}; passing this method the output from a <code>toString
   * </code> call on some CqlVector should return a CqlVector that is equal to the origin instance.
   *
   * <p>The first and last characters of the input are treated as the enclosing delimiters and are
   * dropped without being inspected; what remains is split on <code>", "</code> and every token is
   * handed to the subtype codec. An input with nothing between the delimiters (e.g. <code>"[]"
   * </code>, the <code>toString</code> output of an empty vector) yields an empty vector.
   *
   * @param str a String representation of a CqlVector
   * @param subtypeCodec the codec used to parse each individual element of the vector
   * @return a new CqlVector built from the String representation
   * @throws IllegalArgumentException if the string is null, empty or too short to contain both
   *     delimiters, or if the codec parses any element to null
   */
  public static <V extends Number> CqlVector<V> from(
      @NonNull String str, @NonNull TypeCodec<V> subtypeCodec) {
    Preconditions.checkArgument(str != null, "Cannot create CqlVector from null string");
    Preconditions.checkArgument(!str.isEmpty(), "Cannot create CqlVector from empty string");
    // Guard the substring() below: a single character cannot hold an opening and a closing
    // delimiter and would otherwise fail with a StringIndexOutOfBoundsException.
    Preconditions.checkArgument(
        str.length() >= 2, "Cannot create CqlVector from string without delimiters: %s", str);

    String body = str.substring(1, str.length() - 1);
    if (body.isEmpty()) {
      // Splitting an empty string yields a single empty token rather than no tokens, so handle
      // the empty vector explicitly in order to keep the round trip with toString() intact.
      return new CqlVector<>(new ArrayList<V>());
    }

    List<V> vals =
        Streams.stream(Splitter.on(", ").split(body))
            .map(subtypeCodec::parse)
            .collect(Collectors.toCollection(ArrayList::new));
    return new CqlVector<>(vals);
  }

  private final List<T> list;

  /**
   * Wrap the given list without copying it.
   *
   * @param list the backing list for this vector
   * @throws IllegalArgumentException if the list contains a null value
   */
  private CqlVector(@NonNull List<T> list) {

    Preconditions.checkArgument(
        Iterables.all(list, Predicates.notNull()), "CqlVectors cannot contain null values");
    this.list = list;
  }

  /**
   * Retrieve the value at the specified index. Modelled after {@link List#get(int)}
   *
   * @param idx the index to retrieve
   * @return the value at the specified index
   */
  public T get(int idx) {
    return list.get(idx);
  }

  /**
   * Update the value at the specified index. Modelled after {@link List#set(int, Object)}
   *
   * <p>Note that, unlike the factory methods, this method does not verify that the new value is
   * non-null.
   *
   * @param idx the index to set
   * @param val the new value for the specified index
   * @return the old value for the specified index
   */
  public T set(int idx, T val) {
    return list.set(idx, val);
  }

  /**
   * Return the size of this vector. Modelled after {@link List#size()}
   *
   * @return the vector size
   */
  public int size() {
    return this.list.size();
  }

  /**
   * Return a CqlVector consisting of the contents of a portion of this vector. Modelled after
   * {@link List#subList(int, int)}; the returned vector wraps the sublist produced by that method
   * rather than a copy.
   *
   * @param from the index to start from (inclusive)
   * @param to the index to end on (exclusive)
   * @return a new CqlVector wrapping the sublist
   */
  public CqlVector<T> subVector(int from, int to) {
    return new CqlVector<>(this.list.subList(from, to));
  }

  /**
   * Return a boolean indicating whether the vector is empty. Modelled after {@link List#isEmpty()}
   *
   * @return true if the list is empty, false otherwise
   */
  public boolean isEmpty() {
    return this.list.isEmpty();
  }

  /**
   * Create an {@link Iterator} for this vector
   *
   * @return the generated iterator
   */
  @Override
  public Iterator<T> iterator() {
    return this.list.iterator();
  }

  /**
   * Create a {@link Stream} of the values in this vector
   *
   * @return the Stream instance
   */
  public Stream<T> stream() {
    return this.list.stream();
  }

  /**
   * Two vectors are equal when their backing lists are equal, i.e. when they hold equal elements
   * in the same order.
   */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (!(o instanceof CqlVector)) {
      return false;
    }
    // The element type is irrelevant here: List.equals() compares element by element.
    CqlVector<?> that = (CqlVector<?>) o;
    return this.list.equals(that.list);
  }

  @Override
  public int hashCode() {
    return Objects.hash(list);
  }

  /**
   * Render this vector using {@link Iterables#toString(Iterable)}; {@link #from(String, TypeCodec)}
   * is intended to accept the output of this method.
   */
  @Override
  public String toString() {
    return Iterables.toString(this.list);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) {
* @throws IllegalArgumentException if the id is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull CqlIdentifier id, @NonNull Class<ElementT> elementsClass) {
return getVector(firstIndexOf(id), elementsClass);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,9 @@ default CqlDuration getCqlDuration(int i) {
* @throws IndexOutOfBoundsException if the index is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.listOf(elementsClass));
default <ElementT extends Number> CqlVector<ElementT> getVector(
int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ default CqlDuration getCqlDuration(@NonNull String name) {
* @throws IllegalArgumentException if the name is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull String name, @NonNull Class<ElementT> elementsClass) {
return getList(firstIndexOf(name), elementsClass);
return getVector(firstIndexOf(name), elementsClass);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v)
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
default <ElementT extends Number> SelfT setVector(
@NonNull CqlIdentifier id,
@Nullable List<ElementT> v,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(id)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
int i, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.listOf(elementsClass));
default <ElementT extends Number> SelfT setVector(
int i, @Nullable CqlVector<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,10 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
@NonNull String name, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
default <ElementT extends Number> SelfT setVector(
@NonNull String name,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(name)) {
result = (result == null ? this : result).setVector(i, v, elementsClass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
package com.datastax.oss.driver.api.core.type.codec;

import com.datastax.oss.driver.api.core.session.SessionBuilder;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
import com.datastax.oss.driver.internal.core.type.codec.SimpleBlobCodec;
import com.datastax.oss.driver.internal.core.type.codec.TimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.OptionalCodec;
Expand All @@ -36,6 +38,7 @@
import com.datastax.oss.driver.internal.core.type.codec.extras.time.PersistentZonedTimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.time.TimestampMillisCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.time.ZonedTimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -479,4 +482,9 @@ public static <T> TypeCodec<T> json(
@NonNull Class<T> javaType, @NonNull ObjectMapper objectMapper) {
return new JsonCodec<>(javaType, objectMapper);
}

/**
 * Builds a new codec that maps CQL float vectors of the specified size to an array of floats.
 *
 * @param dimensions the number of dimensions used to build the underlying vector type
 * @return a codec whose Java type is {@code float[]}
 */
public static TypeCodec<float[]> floatVectorToArray(int dimensions) {
  // The CQL side of the mapping: a vector type with FLOAT elements and the requested size.
  DefaultVectorType floatVectorType = new DefaultVectorType(DataTypes.FLOAT, dimensions);
  return new FloatVectorToArrayCodec(floatVectorType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.datastax.oss.driver.api.core.type.codec;

import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.CustomType;
Expand Down Expand Up @@ -207,12 +208,17 @@ public static TypeCodec<TupleValue> tupleOf(@NonNull TupleType cqlType) {
return new TupleCodec(cqlType);
}

public static <SubtypeT> TypeCodec<List<SubtypeT>> vectorOf(
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
@NonNull VectorType type, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
return new VectorCodec(
DataTypes.vectorOf(subtypeCodec.getCqlType(), type.getDimensions()), subtypeCodec);
}

/**
 * Builds a new codec that maps a CQL vector to the driver's {@link CqlVector}, given only the
 * number of dimensions and the codec for the elements.
 *
 * <p>The CQL vector type is derived from the CQL type of {@code subtypeCodec} and {@code
 * dimensions} via {@code DataTypes.vectorOf}.
 *
 * @param dimensions the number of dimensions passed to {@code DataTypes.vectorOf}
 * @param subtypeCodec the codec for the vector's elements; its CQL type becomes the vector's
 *     element type
 * @return a codec for vectors of the resulting CQL type
 */
// NOTE(review): VectorCodec is instantiated as a raw type here; if it is generic (its declaration
// is not visible from this file) the diamond form would avoid the unchecked conversion - confirm.
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
int dimensions, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
return new VectorCodec(DataTypes.vectorOf(subtypeCodec.getCqlType(), dimensions), subtypeCodec);
}

/**
* Builds a new codec that maps a CQL user defined type to the driver's {@link UdtValue}, for the
* given type definition.
Expand Down

0 comments on commit e2fb42d

Please sign in to comment.