Skip to content

Commit

Permalink
Verify retrieve
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <yihua.mo@zilliz.com>
  • Loading branch information
yhmo committed Jun 4, 2024
1 parent 1554910 commit be258d0
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/main/java/io/milvus/v1/CommonUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ public static ByteBuffer generateFloat16Vector(int dimension, boolean bfloat16)
for (int i = 0; i < dimension; ++i) {
ByteDataBuffer bf;
if (bfloat16) {
TFloat16 tt = TFloat16.scalarOf((float)ran.nextInt(dimension));
TBfloat16 tt = TBfloat16.scalarOf((float)ran.nextInt(dimension));
bf = tt.asRawTensor().data();
} else {
TBfloat16 tt = TBfloat16.scalarOf((float)ran.nextInt(dimension));
TFloat16 tt = TFloat16.scalarOf((float)ran.nextInt(dimension));
bf = tt.asRawTensor().data();
}
vector.put(bf.getByte(0));
Expand Down
90 changes: 90 additions & 0 deletions examples/main/java/io/milvus/v1/Float16VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
import java.nio.ByteBuffer;
import java.util.*;

import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;


public class Float16VectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_float16";
Expand Down Expand Up @@ -202,6 +209,44 @@ private static void testFloat16(boolean bfloat16) {
}
System.out.println("Query result is correct");

// insert a single row
JsonObject row = new JsonObject();
row.addProperty(ID_FIELD, 9999999);
List<Float> newVector = CommonUtils.generateFloatVector(VECTOR_DIM);
ByteBuffer vector16Buf = encodeTF(newVector, bfloat16);
row.add(VECTOR_FIELD, gson.toJsonTree(vector16Buf.array()));
insertR = milvusClient.insert(InsertParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
.withRows(Collections.singletonList(row))
.build());
CommonUtils.handleResponseStatus(insertR);

// retrieve the single row
queryR = milvusClient.query(QueryParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
.withExpr("id == 9999999")
.addOutField(VECTOR_FIELD)
.build());
CommonUtils.handleResponseStatus(queryR);
queryWrapper = new QueryResultsWrapper(queryR.getData());
field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
r = field.getFieldData();
if (r.isEmpty()) {
throw new RuntimeException("The retrieve result is empty");
} else {
ByteBuffer outBuf = (ByteBuffer) r.get(0);
List<Float> outVector = decodeTF(outBuf, bfloat16);
if (outVector.size() != newVector.size()) {
throw new RuntimeException("The retrieve result is incorrect");
}
for (int i = 0; i < outVector.size(); i++) {
if (!isFloat16Eauql(outVector.get(i), newVector.get(i), bfloat16)) {
throw new RuntimeException("The retrieve result is incorrect");
}
}
}
System.out.println("Retrieve result is correct");

// drop the collection if you don't need the collection anymore
milvusClient.dropCollection(DropCollectionParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
Expand All @@ -211,6 +256,51 @@ private static void testFloat16(boolean bfloat16) {
milvusClient.close();
}

/**
 * Encodes a float vector into a 2-bytes-per-element buffer by round-tripping each value
 * through a scalar TensorFlow half-precision tensor.
 *
 * <p>The returned buffer is heap-backed (created with {@link ByteBuffer#allocate(int)}) and is
 * filled with relative puts, so its position is left at the end of the written data. Callers
 * that want the raw bytes should use {@code array()}, as the insert code in this example does.
 *
 * @param vector   the values to encode; each element is unboxed, so it must not be null
 * @param bfloat16 {@code true} to encode with {@code TBfloat16}, {@code false} for {@code TFloat16}
 * @return a buffer of {@code vector.size() * 2} bytes holding the encoded values in input order
 */
private static ByteBuffer encodeTF(List<Float> vector, boolean bfloat16) {
    ByteBuffer buf = ByteBuffer.allocate(vector.size() * 2);
    for (Float value : vector) {
        // Each scalar tensor owns native memory; close it as soon as its two bytes are copied
        // instead of leaving one unreleased tensor per vector element.
        try (Tensor tt = bfloat16 ? TBfloat16.scalarOf(value) : TFloat16.scalarOf(value)) {
            ByteDataBuffer bf = tt.asRawTensor().data();
            // The bytes are copied in the order the tensor's raw buffer exposes them.
            // NOTE(review): presumably this is the byte order Milvus expects - confirm.
            buf.put(bf.getByte(0));
            buf.put(bf.getByte(1));
        }
    }
    return buf;
}

/**
 * Decodes a buffer of 2-byte half-precision values back into a list of floats by loading the
 * bytes into a rank-1 TensorFlow tensor and reading each element as a float.
 *
 * <p>The element count is {@code buf.limit() / 2}; the bytes themselves are taken from
 * {@code buf.array()}, so the buffer must be array-backed.
 * NOTE(review): this assumes the backing array starts at offset 0 and holds exactly the
 * encoded vector - confirm for buffers returned by the query wrapper.
 *
 * @param buf      the encoded vector, two bytes per element
 * @param bfloat16 {@code true} to decode as {@code TBfloat16}, {@code false} for {@code TFloat16}
 * @return the decoded values, in buffer order
 */
private static List<Float> decodeTF(ByteBuffer buf, boolean bfloat16) {
    int dim = buf.limit() / 2;
    ByteDataBuffer bf = DataBuffers.of(buf.array());
    List<Float> vec = new ArrayList<>(dim);
    // The tensor owns native memory, so release it deterministically once the values are
    // copied out into the Java list.
    if (bfloat16) {
        try (TBfloat16 tf = Tensor.of(TBfloat16.class, Shape.of(dim), bf)) {
            for (long k = 0; k < tf.size(); k++) {
                vec.add(tf.getFloat(k));
            }
        }
    } else {
        try (TFloat16 tf = Tensor.of(TFloat16.class, Shape.of(dim), bf)) {
            for (long k = 0; k < tf.size(); k++) {
                vec.add(tf.getFloat(k));
            }
        }
    }

    return vec;
}

/**
 * Reports whether two values are equal within the absolute tolerance used for a
 * half-precision round trip: 0.01 when comparing bfloat16 data, 0.001 for float16 data.
 *
 * @param a        first value; unboxed, so it must not be null
 * @param b        second value; unboxed, so it must not be null
 * @param bfloat16 {@code true} to apply the bfloat16 tolerance, {@code false} for float16
 * @return {@code true} if {@code |a - b|} does not exceed the selected tolerance
 */
private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
    final float tolerance = bfloat16 ? 0.01f : 0.001f;
    final float delta = Math.abs(a - b);
    return delta <= tolerance;
}


public static void main(String[] args) {
testFloat16(true);
testFloat16(false);
Expand Down
81 changes: 77 additions & 4 deletions src/test/java/io/milvus/client/MilvusClientDockerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,27 @@ void testFloatVectors() {
ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
System.out.println("Partition info: " + infoPart.toString());

// query
Long fetchID = ids.get(0);
List<Float> fetchVector = vectors.get(0);
R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
.withCollectionName(randomCollectionName)
.withExpr(String.format("%s == %d", field1Name, fetchID))
.addOutField(field2Name)
.build());
Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
Assertions.assertEquals(1L, fetchField.getRowCount());
List<?> fetchObj = fetchField.getFieldData();
Assertions.assertEquals(1, fetchObj.size());
Assertions.assertInstanceOf(List.class, fetchObj.get(0));
List<Float> fetchResult = (List<Float>) fetchObj.get(0);
Assertions.assertEquals(fetchVector.size(), fetchResult.size());
for (int i = 0; i < fetchResult.size(); i++) {
Assertions.assertEquals(fetchVector.get(i), fetchResult.get(i));
}

// query vectors to verify
List<Long> queryIDs = new ArrayList<>();
List<Double> compareWeights = new ArrayList<>();
Expand Down Expand Up @@ -450,6 +471,7 @@ void testFloatVectors() {
.withVectorFieldName(field2Name)
.withParams("{\"ef\":64}")
.addOutField(field4Name)
.addOutField(field2Name)
.build();

R<SearchResults> searchR = client.search(searchParam);
Expand All @@ -462,7 +484,15 @@ void testFloatVectors() {
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
System.out.println(scores);
Assertions.assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
Assertions.assertEquals(targetVectorIDs.get(i), scores.get(0).getLongID());

Object obj = scores.get(0).get(field2Name);
Assertions.assertInstanceOf(List.class, obj);
List<Float> outputVec = (List<Float>)obj;
Assertions.assertEquals(targetVectors.get(i).size(), outputVec.size());
for (int k = 0; k < outputVec.size(); k++) {
Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
}
}

List<?> fieldData = results.getFieldData(field4Name, 0);
Expand Down Expand Up @@ -597,6 +627,24 @@ void testBinaryVectors() {
.build());
Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());

// query
Long fetchID = ids1.get(0);
ByteBuffer fetchVector = vectors.get(0);
R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
.withCollectionName(randomCollectionName)
.withExpr(String.format("%s == %d", field1Name, fetchID))
.addOutField(field2Name)
.build());
Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
Assertions.assertEquals(1L, fetchField.getRowCount());
List<?> fetchObj = fetchField.getFieldData();
Assertions.assertEquals(1, fetchObj.size());
Assertions.assertInstanceOf(ByteBuffer.class, fetchObj.get(0));
ByteBuffer fetchBuffer = (ByteBuffer) fetchObj.get(0);
Assertions.assertArrayEquals(fetchVector.array(), fetchBuffer.array());

// search with BIN_FLAT index
int searchTarget = 99;
List<ByteBuffer> oneVector = new ArrayList<>();
Expand All @@ -623,9 +671,7 @@ void testBinaryVectors() {
List<?> items = oneResult.getFieldData(field2Name, 0);
Assertions.assertEquals(items.size(), 5);
ByteBuffer firstItem = (ByteBuffer) items.get(0);
for (int i = 0; i < firstItem.limit(); ++i) {
Assertions.assertEquals(firstItem.get(i), vectors.get(searchTarget).get(i));
}
Assertions.assertArrayEquals(vectors.get(searchTarget).array(), firstItem.array());

// release collection
ReleaseCollectionParam releaseCollectionParam = ReleaseCollectionParam.newBuilder()
Expand Down Expand Up @@ -773,6 +819,28 @@ void testSparseVector() {
.build());
Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());

// query
Long fetchID = ids.get(0);
SortedMap<Long, Float> fetchVector = vectors.get(0);
R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
.withCollectionName(randomCollectionName)
.withExpr(String.format("%s == %d", field1Name, fetchID))
.addOutField(field2Name)
.build());
Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
Assertions.assertEquals(1L, fetchField.getRowCount());
List<?> fetchObj = fetchField.getFieldData();
Assertions.assertEquals(1, fetchObj.size());
Assertions.assertInstanceOf(SortedMap.class, fetchObj.get(0));
SortedMap<Long, Float> fetchSparse = (SortedMap<Long, Float>) fetchObj.get(0);
Assertions.assertEquals(fetchVector.size(), fetchSparse.size());
for (Long key : fetchVector.keySet()) {
Assertions.assertTrue(fetchSparse.containsKey(key));
Assertions.assertEquals(fetchVector.get(key), fetchSparse.get(key));
}

// pick some vectors to search with index
int nq = 5;
List<Long> targetVectorIDs = new ArrayList<>();
Expand Down Expand Up @@ -813,6 +881,11 @@ void testSparseVector() {
Object v = scores.get(0).get(field2Name);
SortedMap<Long, Float> sparse = (SortedMap<Long, Float>)v;
Assertions.assertTrue(sparse.equals(targetVectors.get(i)));
Assertions.assertEquals(targetVectors.get(i).size(), sparse.size());
for (Long key : sparse.keySet()) {
Assertions.assertTrue(targetVectors.get(i).containsKey(key));
Assertions.assertEquals(sparse.get(key), targetVectors.get(i).get(key));
}
}

// drop collection
Expand Down

0 comments on commit be258d0

Please sign in to comment.