Skip to content

Commit

Permalink
Cache collection schema for insert/upsert
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <yihua.mo@zilliz.com>
  • Loading branch information
yhmo committed Jun 11, 2024
1 parent cdd3757 commit 6741234
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 60 deletions.
116 changes: 72 additions & 44 deletions src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import javax.annotation.Nonnull;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

Expand All @@ -62,12 +63,64 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
protected static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
protected LogLevel logLevel = LogLevel.Info;

private ConcurrentHashMap<String, DescribeCollectionResponse> cacheCollectionInfo = new ConcurrentHashMap<>();

protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();

protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();

protected abstract boolean clientIsReady();

/**
* This method is for insert/upsert requests to reduce the rpc call of describeCollection()
* Always try to get the collection info from cache.
* If the cache doesn't have the collection info, call describeCollection() and cache it.
* If insert/upsert get server error, remove the cached collection info.
*/
private DescribeCollectionResponse getCollectionInfo(String databaseName, String collectionName) {
String key = combineCacheKey(databaseName, collectionName);
DescribeCollectionResponse info = cacheCollectionInfo.get(key);
if (info == null) {
String msg = String.format("Fail to describe collection '%s'", collectionName);
DescribeCollectionRequest.Builder builder = DescribeCollectionRequest.newBuilder()
.setCollectionName(collectionName);
if (StringUtils.isNotEmpty(databaseName)) {
builder.setDbName(databaseName);
msg = String.format("Fail to describe collection '%s' in database '%s'",
collectionName, databaseName);
}
DescribeCollectionRequest describeCollectionRequest = builder.build();
DescribeCollectionResponse response = blockingStub().describeCollection(describeCollectionRequest);
handleResponse(msg, response.getStatus());
info = response;
cacheCollectionInfo.put(key, info);
}

return info;
}

private String combineCacheKey(String databaseName, String collectionName) {
if (collectionName == null || StringUtils.isBlank(collectionName)) {
throw new ParamException("Collection name is empty, not able to get collection info.");
}
String key = collectionName;
if (StringUtils.isNotEmpty(databaseName)) {
key = String.format("%s|%s", databaseName, collectionName);
}
return key;
}

/**
* insert/upsert return an error, but is not a RateLimit error,
* clean the cache so that the next insert will call describeCollection() to get the latest info.
*/
private void cleanCacheIfFailed(Status status, String databaseName, String collectionName) {
if ((status.getCode() != 0 && status.getCode() != 8) ||
(!status.getErrorCode().equals(ErrorCode.Success) && status.getErrorCode() != ErrorCode.RateLimit)) {
cacheCollectionInfo.remove(combineCacheKey(databaseName, collectionName));
}
}

private void waitForLoadingCollection(String databaseName, String collectionName, List<String> partitionNames,
long waitingInterval, long timeout) throws IllegalResponseException {
long tsBegin = System.currentTimeMillis();
Expand Down Expand Up @@ -581,6 +634,7 @@ public R<RpcStatus> dropCollection(@NonNull DropCollectionParam requestParam) {

Status response = blockingStub().dropCollection(dropCollectionRequest);
handleResponse(title, response);
cacheCollectionInfo.remove(combineCacheKey(requestParam.getDatabaseName(), requestParam.getCollectionName()));
return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG));
} catch (StatusRuntimeException e) {
logError("{} RPC failed! Exception:{}", title, e);
Expand Down Expand Up @@ -1509,17 +1563,12 @@ public R<MutationResult> insert(@NonNull InsertParam requestParam) {
String title = String.format("InsertRequest collectionName:%s", requestParam.getCollectionName());

try {
DescribeCollectionParam.Builder builder = DescribeCollectionParam.newBuilder()
.withDatabaseName(requestParam.getDatabaseName())
.withCollectionName(requestParam.getCollectionName());
R<DescribeCollectionResponse> descResp = describeCollection(builder.build());
if (descResp.getStatus() != R.Status.Success.getCode()) {
return R.failed(descResp.getException());
}

DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
DescribeCollectionResponse descResp = getCollectionInfo(requestParam.getDatabaseName(),
requestParam.getCollectionName());
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp);
ParamUtils.InsertBuilderWrapper builderWraper = new ParamUtils.InsertBuilderWrapper(requestParam, wrapper);
MutationResult response = blockingStub().insert(builderWraper.buildInsertRequest());
cleanCacheIfFailed(response.getStatus(), requestParam.getDatabaseName(), requestParam.getCollectionName());
handleResponse(title, response.getStatus());
return R.success(response);
} catch (StatusRuntimeException e) {
Expand All @@ -1542,15 +1591,9 @@ public ListenableFuture<R<MutationResult>> insertAsync(InsertParam requestParam)
logDebug(requestParam.toString());
String title = String.format("InsertAsyncRequest collectionName:%s", requestParam.getCollectionName());

DescribeCollectionParam.Builder builder = DescribeCollectionParam.newBuilder()
.withDatabaseName(requestParam.getDatabaseName())
.withCollectionName(requestParam.getCollectionName());
R<DescribeCollectionResponse> descResp = describeCollection(builder.build());
if (descResp.getStatus() != R.Status.Success.getCode()) {
return Futures.immediateFuture(R.failed(descResp.getException()));
}

DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
DescribeCollectionResponse descResp = getCollectionInfo(requestParam.getDatabaseName(),
requestParam.getCollectionName());
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp);
ParamUtils.InsertBuilderWrapper builderWraper = new ParamUtils.InsertBuilderWrapper(requestParam, wrapper);
ListenableFuture<MutationResult> response = futureStub().insert(builderWraper.buildInsertRequest());

Expand All @@ -1559,6 +1602,7 @@ public ListenableFuture<R<MutationResult>> insertAsync(InsertParam requestParam)
new FutureCallback<MutationResult>() {
@Override
public void onSuccess(MutationResult result) {
cleanCacheIfFailed(result.getStatus(), requestParam.getDatabaseName(), requestParam.getCollectionName());
if (result.getStatus().getErrorCode() == ErrorCode.Success) {
logDebug("{} successfully!", title);
} else {
Expand Down Expand Up @@ -1596,17 +1640,12 @@ public R<MutationResult> upsert(UpsertParam requestParam) {
String title = String.format("UpsertRequest collectionName:%s", requestParam.getCollectionName());

try {
DescribeCollectionParam.Builder builder = DescribeCollectionParam.newBuilder()
.withDatabaseName(requestParam.getDatabaseName())
.withCollectionName(requestParam.getCollectionName());
R<DescribeCollectionResponse> descResp = describeCollection(builder.build());
if (descResp.getStatus() != R.Status.Success.getCode()) {
return R.failed(descResp.getException());
}

DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
DescribeCollectionResponse descResp = getCollectionInfo(requestParam.getDatabaseName(),
requestParam.getCollectionName());
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp);
ParamUtils.InsertBuilderWrapper builderWraper = new ParamUtils.InsertBuilderWrapper(requestParam, wrapper);
MutationResult response = blockingStub().upsert(builderWraper.buildUpsertRequest());
cleanCacheIfFailed(response.getStatus(), requestParam.getDatabaseName(), requestParam.getCollectionName());
handleResponse(title, response.getStatus());
return R.success(response);
} catch (StatusRuntimeException e) {
Expand All @@ -1628,15 +1667,9 @@ public ListenableFuture<R<MutationResult>> upsertAsync(UpsertParam requestParam)
logDebug(requestParam.toString());
String title = String.format("UpsertAsyncRequest collectionName:%s", requestParam.getCollectionName());

DescribeCollectionParam.Builder builder = DescribeCollectionParam.newBuilder()
.withDatabaseName(requestParam.getDatabaseName())
.withCollectionName(requestParam.getCollectionName());
R<DescribeCollectionResponse> descResp = describeCollection(builder.build());
if (descResp.getStatus() != R.Status.Success.getCode()) {
return Futures.immediateFuture(R.failed(descResp.getException()));
}

DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
DescribeCollectionResponse descResp = getCollectionInfo(requestParam.getDatabaseName(),
requestParam.getCollectionName());
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp);
ParamUtils.InsertBuilderWrapper builderWraper = new ParamUtils.InsertBuilderWrapper(requestParam, wrapper);
ListenableFuture<MutationResult> response = futureStub().upsert(builderWraper.buildUpsertRequest());

Expand All @@ -1645,6 +1678,7 @@ public ListenableFuture<R<MutationResult>> upsertAsync(UpsertParam requestParam)
new FutureCallback<MutationResult>() {
@Override
public void onSuccess(MutationResult result) {
cleanCacheIfFailed(result.getStatus(), requestParam.getDatabaseName(), requestParam.getCollectionName());
if (result.getStatus().getErrorCode() == ErrorCode.Success) {
logDebug("{} successfully!", title);
} else {
Expand Down Expand Up @@ -3088,14 +3122,8 @@ public R<DeleteResponse> delete(DeleteIdsParam requestParam) {
String title = String.format("DeleteIdsRequest collectionName:%s", requestParam.getCollectionName());

try {
DescribeCollectionParam.Builder builder = DescribeCollectionParam.newBuilder()
.withCollectionName(requestParam.getCollectionName());
R<DescribeCollectionResponse> descResp = describeCollection(builder.build());
if (descResp.getStatus() != R.Status.Success.getCode()) {
logError("Failed to describe collection: {}", requestParam.getCollectionName());
return R.failed(descResp.getException());
}
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
DescribeCollectionResponse descResp = getCollectionInfo("", requestParam.getCollectionName());
DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp);

String expr = VectorUtils.convertPksExpr(requestParam.getPrimaryIds(), wrapper);
DeleteParam deleteParam = DeleteParam.newBuilder()
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/io/milvus/client/MilvusClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,16 @@ default void close() {
R<ListDatabasesResponse> listDatabases();

/**
* Alter database with key value pair
* Alter database with key value pair. (Available from Milvus v2.4.4)
*
* @param requestParam {@link AlterDatabaseParam}
* @return {status:result code, data:RpcStatus{msg: result message}}
*/
R<RpcStatus> alterDatabase(AlterDatabaseParam requestParam);

/**
* show detail of database base, such as replica number and resource groups
* Show detail of database base, such as replica number and resource groups. (Available from Milvus v2.4.4)
*
* @param requestParam {@link DescribeDatabaseParam}
* @return {status:result code, data:DescribeDatabaseResponse{replica_number,resource_groups}}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ public DescribeCollectionResp describeCollection(MilvusServiceGrpc.MilvusService
.build();
DescribeCollectionResponse response = milvusServiceBlockingStub.describeCollection(describeCollectionRequest);
rpcUtils.handleResponse(title, response.getStatus());
return convertDescCollectionResp(response);
}

public static DescribeCollectionResp convertDescCollectionResp(DescribeCollectionResponse response) {
DescribeCollectionResp describeCollectionResp = DescribeCollectionResp.builder()
.collectionName(response.getCollectionName())
.description(response.getSchema().getDescription())
Expand All @@ -195,7 +199,6 @@ public DescribeCollectionResp describeCollection(MilvusServiceGrpc.MilvusService
.primaryFieldName(response.getSchema().getFieldsList().stream().filter(FieldSchema::getIsPrimaryKey).map(FieldSchema::getName).collect(java.util.stream.Collectors.toList()).get(0))
.createTime(response.getCreatedTimestamp())
.build();

return describeCollectionResp;
}

Expand Down
Loading

0 comments on commit 6741234

Please sign in to comment.