Skip to content

Commit

Permalink
Iterator for V2 (#942)
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <yihua.mo@zilliz.com>
  • Loading branch information
yhmo committed Jun 26, 2024
1 parent 5c73727 commit 15d6e00
Show file tree
Hide file tree
Showing 22 changed files with 713 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ volumes/
*.iml

# Example files
examples/bulk_writer
examples/main/resources/tls/*
!examples/main/resources/tls/gen.sh
!examples/main/resources/tls/openssl.cnf
Expand Down
2 changes: 1 addition & 1 deletion examples/main/java/io/milvus/v1/BinaryVectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import java.util.*;

public class BinaryVectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_sparse";
private static final String COLLECTION_NAME = "java_sdk_example_binary_vector";
private static final String ID_FIELD = "id";
private static final String VECTOR_FIELD = "vector";

Expand Down
2 changes: 1 addition & 1 deletion examples/main/java/io/milvus/v1/Float16VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@


public class Float16VectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_float16";
private static final String COLLECTION_NAME = "java_sdk_example_float16_vector";
private static final String ID_FIELD = "id";
private static final String VECTOR_FIELD = "vector";
private static final Integer VECTOR_DIM = 128;
Expand Down
2 changes: 1 addition & 1 deletion examples/main/java/io/milvus/v1/IteratorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class IteratorExample {
milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
}

private static final String COLLECTION_NAME = "test_iterator";
private static final String COLLECTION_NAME = "java_sdk_example_iterator";
private static final String ID_FIELD = "userID";
private static final String VECTOR_FIELD = "userFace";
private static final Integer VECTOR_DIM = 8;
Expand Down
2 changes: 1 addition & 1 deletion examples/main/java/io/milvus/v1/SparseVectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


public class SparseVectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_sparse";
private static final String COLLECTION_NAME = "java_sdk_example_sparse_vector";
private static final String ID_FIELD = "id";
private static final String VECTOR_FIELD = "vector";

Expand Down
2 changes: 1 addition & 1 deletion examples/main/java/io/milvus/v2/Float16VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


public class Float16VectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_float16";
private static final String COLLECTION_NAME = "java_sdk_example_float16_vector";
private static final String ID_FIELD = "id";
private static final String FP16_VECTOR_FIELD = "fp16_vector";
private static final String BF16_VECTOR_FIELD = "bf16_vector";
Expand Down
161 changes: 161 additions & 0 deletions examples/main/java/io/milvus/v2/IteratorExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package io.milvus.v2;

import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import io.milvus.orm.iterator.QueryIterator;
import io.milvus.orm.iterator.SearchIterator;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.v1.CommonUtils;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.QueryIteratorReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.request.SearchIteratorReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.QueryResp;

import java.util.*;

/**
 * Example showing how to use the search iterator and the query iterator with {@code MilvusClientV2}.
 *
 * <p>The program creates a collection, inserts random rows, prints the row count, walks through a
 * search iterator and a query iterator batch by batch, and finally drops the collection.
 * It connects to the server at {@code http://localhost:19530}.
 */
public class IteratorExample {
    private static final String COLLECTION_NAME = "java_sdk_example_iterator";
    private static final String ID_FIELD = "userID";
    private static final String AGE_FIELD = "userAge";
    private static final String VECTOR_FIELD = "userFace";
    private static final Integer VECTOR_DIM = 128;

    /**
     * Runs the example end to end.
     *
     * @param args command line arguments, not used
     */
    public static void main(String[] args) {
        ConnectConfig config = ConnectConfig.builder()
                .uri("http://localhost:19530")
                .build();
        MilvusClientV2 client = new MilvusClientV2(config);

        // drop collection if exists, so that a leftover collection from an aborted run
        // does not affect this run (same approach as SimpleExample)
        client.dropCollection(DropCollectionReq.builder()
                .collectionName(COLLECTION_NAME)
                .build());

        // create collection
        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
                .build();
        collectionSchema.addField(AddFieldReq.builder()
                .fieldName(ID_FIELD)
                .dataType(DataType.Int64)
                .isPrimaryKey(Boolean.TRUE)
                .autoID(Boolean.FALSE)
                .build());
        collectionSchema.addField(AddFieldReq.builder()
                .fieldName(AGE_FIELD)
                .dataType(DataType.Int32)
                .build());
        collectionSchema.addField(AddFieldReq.builder()
                .fieldName(VECTOR_FIELD)
                .dataType(DataType.FloatVector)
                .dimension(VECTOR_DIM)
                .build());

        List<IndexParam> indexParams = new ArrayList<>();
        indexParams.add(IndexParam.builder()
                .fieldName(VECTOR_FIELD)
                .indexType(IndexParam.IndexType.FLAT)
                .metricType(IndexParam.MetricType.L2)
                .build());

        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
                .collectionName(COLLECTION_NAME)
                .collectionSchema(collectionSchema)
                .indexParams(indexParams)
                .build();
        client.createCollection(requestCreate);

        // insert rows: sequential ids, a random age in [0, 99) and a random vector
        long count = 10000;
        List<JsonObject> rowsData = new ArrayList<>();
        Random ran = new Random();
        Gson gson = new Gson();
        for (long i = 0L; i < count; ++i) {
            JsonObject row = new JsonObject();
            row.addProperty(ID_FIELD, i);
            row.addProperty(AGE_FIELD, ran.nextInt(99));
            row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));

            rowsData.add(row);
        }
        client.insert(InsertReq.builder()
                .collectionName(COLLECTION_NAME)
                .data(rowsData)
                .build());

        // check row count
        QueryResp queryResp = client.query(QueryReq.builder()
                .collectionName(COLLECTION_NAME)
                .filter("")
                .outputFields(Collections.singletonList("count(*)"))
                .consistencyLevel(ConsistencyLevel.STRONG)
                .build());
        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
        // The number of rows is the value of the "count(*)" field of the first result,
        // not the size of the result list.
        System.out.printf("Inserted row count: %d\n", queryResults.get(0).getEntity().get("count(*)"));

        // search iterator
        SearchIterator searchIterator = client.searchIterator(SearchIteratorReq.builder()
                .collectionName(COLLECTION_NAME)
                .outputFields(Lists.newArrayList(AGE_FIELD))
                .batchSize(50L)
                .vectorFieldName(VECTOR_FIELD)
                .vectors(Collections.singletonList(new FloatVec(CommonUtils.generateFloatVector(VECTOR_DIM))))
                .expr(String.format("%s > 50 && %s < 100", AGE_FIELD, AGE_FIELD))
                .params("{\"range_filter\": 15.0, \"radius\": 20.0}")
                .topK(300)
                .metricType(IndexParam.MetricType.L2)
                .consistencyLevel(ConsistencyLevel.BOUNDED)
                .build());

        int counter = 0;
        while (true) {
            // an empty batch signals that the iteration is exhausted
            List<QueryResultsWrapper.RowRecord> res = searchIterator.next();
            if (res.isEmpty()) {
                System.out.println("Search iteration finished, close");
                searchIterator.close();
                break;
            }

            for (QueryResultsWrapper.RowRecord record : res) {
                System.out.println(record);
                counter++;
            }
        }
        System.out.println(String.format("%d search results returned\n", counter));

        // query iterator
        QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
                .collectionName(COLLECTION_NAME)
                .expr(String.format("%s < 300", ID_FIELD))
                .outputFields(Lists.newArrayList(ID_FIELD, AGE_FIELD))
                .batchSize(50L)
                .offset(5)
                .limit(400)
                .consistencyLevel(ConsistencyLevel.BOUNDED)
                .build());

        counter = 0;
        while (true) {
            // an empty batch signals that the iteration is exhausted
            List<QueryResultsWrapper.RowRecord> res = queryIterator.next();
            if (res.isEmpty()) {
                System.out.println("query iteration finished, close");
                queryIterator.close();
                break;
            }

            for (QueryResultsWrapper.RowRecord record : res) {
                System.out.println(record);
                counter++;
            }
        }
        System.out.println(String.format("%d query results returned", counter));

        // clean up the example collection
        client.dropCollection(DropCollectionReq.builder().collectionName(COLLECTION_NAME).build());
    }
}
4 changes: 1 addition & 3 deletions examples/main/java/io/milvus/v2/SimpleExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.collection.request.GetCollectionStatsReq;
import io.milvus.v2.service.collection.response.GetCollectionStatsResp;
import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.*;
Expand All @@ -21,7 +19,7 @@ public static void main(String[] args) {
.build();
MilvusClientV2 client = new MilvusClientV2(config);

String collectionName = "simple_test";
String collectionName = "java_sdk_example_simple";
// drop collection if exists
client.dropCollection(DropCollectionReq.builder()
.collectionName(collectionName)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/io/milvus/common/utils/Float16Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public static ByteBuffer f16VectorToBuffer(List<Short> vector) {
/**
* Converts a ByteBuffer to an fp16/bf16 vector stored as a list of short values.
*/
public static List<Short> BufferToF16Vector(ByteBuffer buf) {
public static List<Short> bufferToF16Vector(ByteBuffer buf) {
buf.rewind(); // reset the read position
List<Short> vector = new ArrayList<>();
ShortBuffer sbuf = buf.asShortBuffer();
Expand Down
130 changes: 130 additions & 0 deletions src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package io.milvus.orm.iterator;

import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.exception.ParamException;
import io.milvus.grpc.DataType;
import io.milvus.grpc.PlaceholderType;
import io.milvus.param.MetricType;
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.SearchIteratorParam;
import io.milvus.param.dml.QueryIteratorParam;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.QueryIteratorReq;
import io.milvus.v2.service.vector.request.SearchIteratorReq;
import io.milvus.v2.service.vector.request.data.BaseVector;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;

public class IteratorAdapterV2 {
public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq) {
return QueryIteratorParam.newBuilder()
.withDatabaseName(queryIteratorReq.getDatabaseName())
.withCollectionName(queryIteratorReq.getCollectionName())
.withPartitionNames(queryIteratorReq.getPartitionNames())
.withExpr(queryIteratorReq.getExpr())
.withOutFields(queryIteratorReq.getOutputFields())
.withConsistencyLevel(ConsistencyLevelEnum.valueOf(queryIteratorReq.getConsistencyLevel().name()))
.withOffset(queryIteratorReq.getOffset())
.withLimit(queryIteratorReq.getLimit())
.withIgnoreGrowing(queryIteratorReq.isIgnoreGrowing())
.withBatchSize(queryIteratorReq.getBatchSize())
.build();
}
public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) {
MetricType metricType = MetricType.None;
if (searchIteratorReq.getMetricType() != IndexParam.MetricType.INVALID) {
metricType = MetricType.valueOf(searchIteratorReq.getMetricType().name());
}

SearchIteratorParam.Builder builder = SearchIteratorParam.newBuilder()
.withDatabaseName(searchIteratorReq.getDatabaseName())
.withCollectionName(searchIteratorReq.getCollectionName())
.withPartitionNames(searchIteratorReq.getPartitionNames())
.withVectorFieldName(searchIteratorReq.getVectorFieldName())
.withMetricType(metricType)
.withTopK(searchIteratorReq.getTopK())
.withExpr(searchIteratorReq.getExpr())
.withOutFields(searchIteratorReq.getOutputFields())
.withRoundDecimal(searchIteratorReq.getRoundDecimal())
.withParams(searchIteratorReq.getParams())
.withGroupByFieldName(searchIteratorReq.getGroupByFieldName())
.withIgnoreGrowing(searchIteratorReq.isIgnoreGrowing())
.withBatchSize(searchIteratorReq.getBatchSize());

if (searchIteratorReq.getConsistencyLevel() != null) {
builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(searchIteratorReq.getConsistencyLevel().name()));
}

List<BaseVector> vectors = searchIteratorReq.getVectors();
PlaceholderType plType = vectors.get(0).getPlaceholderType();
for (BaseVector vector : vectors) {
if (vector.getPlaceholderType() != plType) {
throw new ParamException("Different types of target vectors in a search request is not allowed.");
}
}

switch (plType) {
case FloatVector: {
List<List<Float>> data = new ArrayList<>();
vectors.forEach(vector->data.add((List<Float>)vector.getData()));
builder.withFloatVectors(data);
break;
}
case BinaryVector: {
List<ByteBuffer> data = new ArrayList<>();
vectors.forEach(vector->data.add((ByteBuffer)vector.getData()));
builder.withBinaryVectors(data);
break;
}
case Float16Vector: {
List<ByteBuffer> data = new ArrayList<>();
vectors.forEach(vector -> data.add((ByteBuffer)vector.getData()));
builder.withFloat16Vectors(data);
break;
}
case BFloat16Vector: {
List<ByteBuffer> data = new ArrayList<>();
vectors.forEach(vector -> data.add((ByteBuffer)vector.getData()));
builder.withBFloat16Vectors(data);
break;
}
case SparseFloatVector: {
List<SortedMap<Long, Float>> data = new ArrayList<>();
vectors.forEach(vector -> data.add((SortedMap<Long, Float>)vector.getData()));
builder.withSparseFloatVectors(data);
break;
}
default:
throw new ParamException("Unsupported vector type.");
}

return builder.build();
}

public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) {
FieldType.Builder builder = FieldType.newBuilder()
.withName(schema.getName())
.withDataType(DataType.valueOf(schema.getDataType().name()))
.withPrimaryKey(schema.getIsPrimaryKey())
.withAutoID(schema.getAutoID())
.withPartitionKey(schema.getIsPartitionKey());

if (schema.getDimension() != null) {
builder.withDimension(schema.getDimension());
}
if (schema.getMaxLength() != null) {
builder.withMaxLength(schema.getMaxLength());
}
if (schema.getMaxCapacity() != null) {
builder.withMaxCapacity(schema.getMaxLength());
}
if (schema.getElementType() != null) {
builder.withElementType(DataType.valueOf(schema.getElementType().name()));
}
return builder.build();
}
}
Loading

0 comments on commit 15d6e00

Please sign in to comment.