Skip to content

Commit

Permalink
Fix upsert points operation
Browse files Browse the repository at this point in the history
  • Loading branch information
msmygit committed May 8, 2024
1 parent c64fd45 commit fc7cf1a
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .run/qdrant_upsert_points_glove_25.run.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<option name="useCurrentConnection" value="false" />
</extension>
<option name="JAR_PATH" value="$PROJECT_DIR$/nb5/target/nb5.jar" />
<option name="PROGRAM_PARAMETERS" value="qdrant_vectors_live qdrant_vectors.rampup dimensions=25 testsize=10000 trainsize=1183514 train_threads=AUTO train_cycles=5..10 dataset=glove-25-angular filetype=hdf5 collection=glove_25 similarity_function=1 qdranthost=ded78a51-8370-47d8-adb0-6147f0fcbba2.us-east4-0.gcp.cloud.qdrant.io token_file=./apikey grpc_port=6334 --progress console:1s -v --add-labels &quot;dimensions:25,dataset=glove-25&quot; --show-stacktraces --logs-max 5" />
<option name="PROGRAM_PARAMETERS" value="qdrant_vectors_live qdrant_vectors.rampup dimensions=25 testsize=10000 trainsize=1183514 dataset=glove-25-angular filetype=hdf5 collection=glove_25 similarity_function=1 qdranthost=ded78a51-8370-47d8-adb0-6147f0fcbba2.us-east4-0.gcp.cloud.qdrant.io token_file=./apikey grpc_port=6334 --progress console:1s -v --add-labels &quot;dimensions:25,dataset=glove-25&quot; --show-stacktraces --logs-max 5" />
<option name="WORKING_DIRECTORY" value="$ProjectFileDir$/local/qdrant" />
<option name="ALTERNATIVE_JRE_PATH" value="jdk21" />
<method v="2" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.nosqlbench.adapter.qdrant.ops.QdrantUpsertPointsOp;
import io.nosqlbench.adapters.api.activityimpl.OpDispenser;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import io.nosqlbench.nb.api.errors.OpConfigError;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.ValueFactory;
import io.qdrant.client.VectorFactory;
Expand All @@ -30,14 +31,13 @@
import io.qdrant.client.grpc.JsonWithInt.NullValue;
import io.qdrant.client.grpc.JsonWithInt.Struct;
import io.qdrant.client.grpc.JsonWithInt.Value;
import io.qdrant.client.grpc.Points.Vector;
import io.qdrant.client.grpc.Points.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.LongFunction;

public class QdrantUpsertPointsOpDispenser extends QdrantBaseOpDispenser<UpsertPoints> {
Expand Down Expand Up @@ -65,92 +65,119 @@ public LongFunction<UpsertPoints> getParamFunc(

// set wait and ordering query params
ebF = op.enhanceFuncOptionally(ebF, "wait", Boolean.class, UpsertPoints.Builder::setWait);
WriteOrdering.Builder writeOrdering = WriteOrdering.newBuilder();
op.getOptionalStaticValue("ordering", Number.class)
.ifPresent((Number ordering) -> {
writeOrdering.setType(WriteOrderingType.forNumber(ordering.intValue()));
});
final LongFunction<UpsertPoints.Builder> orderingF = ebF;
ebF = l -> orderingF.apply(l).setOrdering(writeOrdering);
ebF = op.enhanceFuncOptionally(ebF, "ordering", Number.class, (UpsertPoints.Builder b, Number n) ->
b.setOrdering(WriteOrdering.newBuilder().setType(WriteOrderingType.forNumber(n.intValue()))));

// request body begins here
ShardKeySelector.Builder shardKeySelector = ShardKeySelector.newBuilder();
op.getOptionalStaticValue("shard_key", Number.class)
.ifPresent((Number value) -> {
shardKeySelector.setShardKeys(0, Collections.ShardKey.newBuilder().setNumber(value.longValue()));
});

List<PointStruct> allPoints = buildPointsStructWithNamedVectors(op);
ebF = op.enhanceFuncOptionally(ebF, "shard_key", Number.class, (UpsertPoints.Builder b, Number n) ->
b.setShardKeySelector(
ShardKeySelector.newBuilder().setShardKeys(
0, Collections.ShardKey.newBuilder().setNumber(n.longValue()))));
LongFunction<List<PointStruct>> pointsF = constructVectorPointsFunc(op);
final LongFunction<UpsertPoints.Builder> pointsOfNamedVectorsF = ebF;
ebF = l -> pointsOfNamedVectorsF.apply(l).addAllPoints(allPoints);
ebF = l -> pointsOfNamedVectorsF.apply(l).addAllPoints(pointsF.apply(l));

final LongFunction<UpsertPoints.Builder> lastF = ebF;
return l -> lastF.apply(l).build();
}

private List<PointStruct> buildPointsStructWithNamedVectors(ParsedOp op) {
List<PointStruct> allPoints = new ArrayList<>();
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
/**
 * Builds the per-cycle function that produces the {@link PointStruct} list added to a Qdrant
 * {@link UpsertPoints} request.
 * <p>
 * The {@code points} op field is resolved to a function yielding a list of maps. On every cycle each map is
 * turned into one point from its {@code id} (a number or a UUID string), its optional {@code payload} and its
 * {@code vector} section (named dense and/or sparse vectors).
 *
 * @param op the {@link ParsedOp} from which the points will be built
 * @return a function mapping a cycle number to the list of points for that cycle; never {@code null}
 * @throws OpConfigError if the op template does not define the {@code points} field
 */
@SuppressWarnings({"rawtypes", "unchecked"})
private LongFunction<List<PointStruct>> constructVectorPointsFunc(ParsedOp op) {
    Optional<LongFunction<List>> baseFunc =
        op.getAsOptionalFunction("points", List.class);
    // The caller applies the returned function unconditionally, so a missing 'points' field must be
    // reported here with a clear message rather than surfacing later as a NullPointerException.
    LongFunction<List> rawPointsFunc = baseFunc.orElseThrow(() ->
        new OpConfigError("The 'points' field must be specified for the upsert points operation"));

    return cycle -> {
        List<Map<String, Object>> rawPoints = rawPointsFunc.apply(cycle);
        List<PointStruct> builtPoints = new ArrayList<>(rawPoints.size());
        for (Map<String, Object> rawPoint : rawPoints) {
            PointStruct.Builder pointBuilder = PointStruct.newBuilder();

            // 'id' field is mandatory, if not present, server will throw an exception
            Object rawId = rawPoint.get("id");
            PointId.Builder pointId = PointId.newBuilder();
            if (rawId instanceof Number numericId) {
                pointId.setNum(numericId.longValue());
            } else if (rawId instanceof String uuidId) {
                pointId.setUuid(uuidId);
            } else {
                // rawId may be null here (field omitted), so do not dereference it blindly.
                logger.warn("Unsupported 'id' value type [{}] specified for 'points'. Ignoring.",
                    rawId == null ? "null" : rawId.getClass().getSimpleName());
            }
            pointBuilder.setId(pointId);

            // The payload is optional; only convert it when the point actually carries one.
            Object rawPayload = rawPoint.get("payload");
            if (rawPayload != null) {
                pointBuilder.putAllPayload(getPayloadValues(rawPayload));
            }

            pointBuilder.setVectors(VectorsFactory.namedVectors(getNamedVectorMap(rawPoint.get("vector"))));
            builtPoints.add(pointBuilder.build());
        }
        return builtPoints;
    };
}

PointId.Builder pointId = PointId.newBuilder();
// id is mandatory
Object idObject = op.getAsRequiredFunction("id", Object.class).apply(0L);
if (idObject instanceof Number) {
pointId.setNum(((Number) idObject).longValue());
} else if (idObject instanceof String) {
pointId.setUuid((String) idObject);
/**
 * Converts the raw {@code vector} section of a point into a map of named Qdrant vectors.
 * <p>
 * Each entry of the section is keyed by the vector name. A {@link List} value is treated as a dense vector;
 * a {@link Map} value is treated as a sparse vector whose {@code values} and {@code indices} entries are
 * collected, with any other entry logged and ignored.
 *
 * @param rawVectorValues the raw {@code vector} value of a point; must be a {@link Map}
 * @return the named vectors built from {@code rawVectorValues}
 * @throws OpConfigError    if {@code rawVectorValues} is not a {@link Map} (including when it is {@code null})
 * @throws RuntimeException if a named vector is neither a {@link Map} nor a {@link List}
 */
@SuppressWarnings("unchecked")
private Map<String, Vector> getNamedVectorMap(Object rawVectorValues) {
    if (!(rawVectorValues instanceof Map)) {
        // A missing section arrives as null; name it explicitly instead of failing on getClass().
        String actualType = rawVectorValues == null ? "null" : rawVectorValues.getClass().getSimpleName();
        throw new OpConfigError("Invalid format of type" +
            " [" + actualType + "] specified for 'vector'");
    }

    Map<String, Object> rawNamedVectors = (Map<String, Object>) rawVectorValues;
    Map<String, Vector> namedVectorMapData = new HashMap<>(rawNamedVectors.size());
    rawNamedVectors.forEach((vectorName, rawVector) -> {
        if (rawVector instanceof Map) {
            // Deal with named sparse vectors here. The accumulators are created per vector so that the
            // values/indices of one named sparse vector never leak into the next one.
            List<Float> sparseValues = new ArrayList<>();
            List<Integer> sparseIndices = new ArrayList<>();
            ((Map<String, Object>) rawVector).forEach((svKey, svValue) -> {
                if ("values".equals(svKey)) {
                    sparseValues.addAll((List<Float>) svValue);
                } else if ("indices".equals(svKey)) {
                    sparseIndices.addAll((List<Integer>) svValue);
                } else {
                    logger.warn("Unrecognized sparse vector field [{}] provided. Ignoring.", svKey);
                }
            });
            namedVectorMapData.put(vectorName, VectorFactory.vector(sparseValues, sparseIndices));
        } else if (rawVector instanceof List) {
            // Deal with regular named dense vectors here
            namedVectorMapData.put(vectorName, VectorFactory.vector((List<Float>) rawVector));
        } else {
            String actualType = rawVector == null ? "null" : rawVector.getClass().getSimpleName();
            throw new RuntimeException("Unsupported 'vector' value type [" + actualType + " ]");
        }
    });
    return namedVectorMapData;
}

if (op.isDefined("payload")) {
LongFunction<Map> payloadMapF = op.getAsRequiredFunction("payload", Map.class);
Map<String, Value> payloadMapData = new HashMap<>();
payloadMapF.apply(0L).forEach((pKey, pVal) -> {
if(pVal instanceof Boolean) {
payloadMapData.put((String) pKey, ValueFactory.value((Boolean) pVal));
} else if(pVal instanceof Double) {
payloadMapData.put((String) pKey, ValueFactory.value((Double) pVal));
} else if(pVal instanceof Integer) {
payloadMapData.put((String) pKey, ValueFactory.value((Integer) pVal));
} else if(pVal instanceof String) {
payloadMapData.put((String) pKey, ValueFactory.value((String) pVal));
} else if(pVal instanceof ListValue) {
payloadMapData.put((String) pKey, ValueFactory.list((List<Value>) pVal));
} else if(pVal instanceof NullValue) {
payloadMapData.put((String) pKey, ValueFactory.nullValue());
} else if(pVal instanceof Struct) {
payloadMapData.put((String) pKey, Value.newBuilder().setStructValue((Struct) pVal).build());
} else {
logger.warn("Unknown payload type passed." +
private Map<String, Value> getPayloadValues(Object rawPayloadValues) {
if (rawPayloadValues instanceof Map) {
Map<String, Object> payloadMap = (Map<String, Object>) rawPayloadValues;
Map<String, Value> payloadMapData = new HashMap<>(payloadMap.size());
payloadMap.forEach((pKey, pVal) -> {
switch (pVal) {
case Boolean b -> payloadMapData.put(pKey, ValueFactory.value(b));
case Double v -> payloadMapData.put(pKey, ValueFactory.value(v));
case Integer i -> payloadMapData.put(pKey, ValueFactory.value(i));
case String s -> payloadMapData.put(pKey, ValueFactory.value(s));
case ListValue listValue -> payloadMapData.put(pKey, ValueFactory.list((List<Value>) pVal));
case NullValue nullValue -> payloadMapData.put(pKey, ValueFactory.nullValue());
case Struct struct -> payloadMapData.put(pKey, Value.newBuilder().setStructValue(struct).build());
default -> logger.warn("Unknown payload value type passed." +
" Only https://qdrant.tech/documentation/concepts/payload/#payload-types are supported." +
" {} will be inored.", pVal.toString());
" {} will be ignored.", pVal.toString());
}
});
pointBuilder.putAllPayload(payloadMapData);
return payloadMapData;
} else {
throw new RuntimeException("Invalid format of type" +
" [" + rawPayloadValues.getClass().getSimpleName() + "] specified for payload");
}

LongFunction<Map> namedVectorMapF = op.getAsRequiredFunction("vector", Map.class);
Map<String, Vector> namedVectorMapData = new HashMap<>();
List<Float> sparseVectors = new ArrayList<>();
List<Integer> sparseIndices = new ArrayList<>();
namedVectorMapF.apply(0L).forEach((nvKey, nvVal) -> {
if (nvVal instanceof Map) {
// we deal with named sparse vectors here
Map<String, Object> namedSparseVectorsMap = (Map<String, Object>) nvVal;
if (namedSparseVectorsMap.containsKey("indices") && namedSparseVectorsMap.containsKey("values")) {
sparseVectors.addAll((List<Float>) namedSparseVectorsMap.get("values"));
sparseIndices.addAll((List<Integer>) namedSparseVectorsMap.get("indices"));
}
namedVectorMapData.put((String) nvKey, VectorFactory.vector(sparseVectors, sparseIndices));
} else {
// Deal with regular named dense vectors here
namedVectorMapData.put((String) nvKey, VectorFactory.vector((List<Float>) nvVal));
}
});
pointBuilder.setVectors(VectorsFactory.namedVectors(namedVectorMapData));
allPoints.add(pointBuilder.build());

return allPoints;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public Object applyOp(long value) {
String responseStatus;
long responseOperationId;
try {
logger.debug("Cycle {} has Request: {}", value, request.toString());
response = client.upsertAsync(request).get();
responseStatus = response.getStatus().toString();
responseOperationId = response.getOperationId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ blocks:
# 0 - Weak, 1 - Medium, 2 - Strong
ordering: TEMPLATE(upsert_point_ordering,1)
#shard_key: "{row_key}"
id: "{id_val}"
payload:
key: "{row_key}"
vector:
# For dense vectors, use the below format
value: "{train_floatlist_TEMPLATE(filetype)}"
# For sparse vectors, use the below format
#value_sv:
# indices: your array of numbers
# values: your array of floats

points:
- id: "{id_val}"
payload:
key: "{row_key}"
vector:
# For dense vectors, use the below format
value: "{train_floatlist_TEMPLATE(filetype)}"
# For sparse vectors, use the below format
#value_sv:
# indices: your array of numbers
# values: your array of floats

search_points:
ops:
Expand Down

0 comments on commit fc7cf1a

Please sign in to comment.