[FEATURE] Is milvus support final? #114

Closed
AbdullahGheith opened this issue Aug 19, 2023 · 4 comments
Labels
enhancement New feature or request

Comments


AbdullahGheith commented Aug 19, 2023

I had created my own Milvus implementation, but I haven't shared it because I did it in a somewhat different way.
I see now that Milvus has been added, and maybe I should use the official one. (But I am in doubt, because it's not mentioned in the changelog and I don't know if it can do what I want to do.)

The main difference is that in mine, I've added an additional ID to my vectors, which is my document ID.
So I can save these objects separately:

DocumentId: Book1
Text: The sky is blue

DocumentId: Book2
Text: The sky is red

Then I can do:

findRelevant("Book1", "What color is the sky", 1);

Then my Milvus search will only search the "Book1" vectors.

I don't know if this is possible with the current EmbeddingStore interface. I haven't figured out a way to do it, so I made it myself.

Is this something that is possible to do, or should I share my code?
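
To make the request concrete, here is a minimal in-memory sketch of a document-scoped search. It is not the Milvus implementation and not part of the EmbeddingStore interface: the class `DocumentScopedStore`, its methods, and the tiny two-dimensional vectors are all hypothetical and only illustrate the idea of filtering by document ID before ranking by distance.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical in-memory store used only to illustrate filtering by document ID.
final class DocumentScopedStore {

    // One stored vector together with the document it belongs to.
    private static final class Entry {
        final String documentId;
        final String text;
        final float[] vector;

        Entry(String documentId, String text, float[] vector) {
            this.documentId = documentId;
            this.text = text;
            this.vector = vector.clone();
        }
    }

    private final List<Entry> entries = new ArrayList<>();

    void add(String documentId, String text, float[] vector) {
        entries.add(new Entry(documentId, text, vector));
    }

    // Returns up to maxResults texts, nearest first.
    // Only entries of the given document are considered; null searches all documents.
    List<String> findRelevant(String documentId, float[] query, int maxResults) {
        return entries.stream()
                .filter(e -> documentId == null || e.documentId.equals(documentId))
                .sorted(Comparator.comparingDouble((Entry e) -> squaredL2(e.vector, query)))
                .limit(maxResults)
                .map(e -> e.text)
                .collect(Collectors.toList());
    }

    // Squared Euclidean distance between two vectors of equal length.
    private static double squaredL2(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vector dimensions differ");
        }
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }
}
```

With "Book1" and "Book2" stored as above, passing "Book1" as the first argument restricts the candidates to that document, even if a vector from "Book2" is closer to the query.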

@AbdullahGheith AbdullahGheith added the enhancement New feature or request label Aug 19, 2023
langchain4j (Owner) commented

Hi @AbdullahGheith, we've added support for Milvus, but some problems were found at the last minute before the release, so we decided not to announce it until we fix them. Please share your implementation; it would be interesting to see your approach. If I understood correctly, you are talking about "metadata filtering". This is not possible yet in our implementation, but we are going to add it soon, as it is an important feature. So if you already have this, it would be awesome to get your contribution! Thanks!
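
As a rough illustration of what "metadata filtering" could look like on the Milvus side, the sketch below turns a map of metadata values into a boolean filter expression string. The class `MilvusFilterExpressions` and its methods are hypothetical and not part of langchain4j or the Milvus SDK. It also assumes that Milvus accepts `==` comparisons joined with `&&` and double-quoted string literals with backslash escapes, which should be checked against the Milvus expression documentation for the version in use.

```java
import java.util.Map;
import java.util.StringJoiner;
import java.util.TreeMap;

// Hypothetical helper: builds an "all fields equal" filter expression from metadata.
final class MilvusFilterExpressions {

    // Example result: author == "Smith" && text_segment_id == "Book1"
    // An empty map yields an empty string, meaning "no filter".
    static String allEqual(Map<String, String> metadata) {
        StringJoiner expr = new StringJoiner(" && ");
        // TreeMap sorts by field name so the generated expression is deterministic.
        for (Map.Entry<String, String> e : new TreeMap<>(metadata).entrySet()) {
            requireFieldName(e.getKey());
            expr.add(e.getKey() + " == \"" + escape(e.getValue()) + "\"");
        }
        return expr.toString();
    }

    // Rejects field names that could change the structure of the expression.
    private static void requireFieldName(String name) {
        if (!name.matches("[A-Za-z_][A-Za-z0-9_]*")) {
            throw new IllegalArgumentException("Invalid field name: " + name);
        }
    }

    // Escapes backslashes and double quotes so the value stays inside its string literal.
    static String escape(String value) {
        return value.replace("\\", "\\\\").replace("\"", "\\\"");
    }
}
```

The resulting string would be passed to the search as its filter expression. Escaping the values avoids the problem that a quote character inside a document ID would otherwise break the expression.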


AbdullahGheith commented Aug 19, 2023

Here's my implementation. It works, and I currently use it:


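// Imports assumed for milvus-sdk-java 2.2.x (high-level insert API, fastjson rows) and langchain4j 0.2x;
// package names may differ in other versions.
import static java.util.Collections.singletonList;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

import com.alibaba.fastjson.JSONObject;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.FieldData;
import io.milvus.grpc.SearchResultData;
import io.milvus.grpc.SearchResults;
import io.milvus.param.ConnectParam;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.RpcStatus;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.highlevel.dml.InsertRowsParam;
import io.milvus.param.highlevel.dml.response.InsertResponse;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.partition.CreatePartitionParam;

public class MilvusEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {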

	private final MilvusServiceClient milvusClient;
	private static final String METADATA_TEXT_SEGMENT = "text_segment";
	private static final String METADATA_TEXT_SEGMENT_ID = "text_segment_id";
	private static final String METADATA_TEXT_SEGMENT_VECTOR = "text_segment_vector";
	private final String collectionName;

	public MilvusEmbeddingStoreImpl(String host, Integer port, String collectionName, Integer dimension) {
		this.milvusClient = new MilvusServiceClient(
				ConnectParam.newBuilder()
						.withHost(host)
						.withPort(port)
						.build());
		this.collectionName = collectionName;

		if (!milvusClient.hasCollection(HasCollectionParam.newBuilder().withCollectionName(collectionName).build()).getData()) {
			R<RpcStatus> id = milvusClient.createCollection(
					CreateCollectionParam.newBuilder()
							.withCollectionName(collectionName)
							.addFieldType(FieldType.newBuilder().withName("id").withPrimaryKey(true).withDataType(DataType.Int64).withAutoID(true).build())
							.addFieldType(FieldType.newBuilder().withName(METADATA_TEXT_SEGMENT_ID).withDataType(DataType.VarChar).withMaxLength(200).build())
							.addFieldType(FieldType.newBuilder().withName(METADATA_TEXT_SEGMENT).withDataType(DataType.VarChar).withMaxLength(2000).build())
							.addFieldType(FieldType.newBuilder().withName(METADATA_TEXT_SEGMENT_VECTOR).withDataType(DataType.FloatVector).withDimension(dimension).build())
							.build()
			);
			R<RpcStatus> index = milvusClient.createIndex(
					CreateIndexParam.newBuilder()
							.withCollectionName(collectionName)
							.withFieldName(METADATA_TEXT_SEGMENT_VECTOR)
							.withIndexType(IndexType.IVF_SQ8)
							.withMetricType(MetricType.L2)
							.withExtraParam("{\"nlist\":1024}")
							.build());
			R<RpcStatus> partition = milvusClient.createPartition(
					CreatePartitionParam.newBuilder()
							.withCollectionName(collectionName)
							.withPartitionName(collectionName + "_partition")
							.build()
			);

		}

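		// Load the collection into memory so it can be searched; fail fast if loading fails
		// (getData() can be null on failure, so the status is checked instead of printing the message).
		R<RpcStatus> loadStatus = milvusClient.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(collectionName).build());
		if (loadStatus.getStatus() != 0) {
			throw new RuntimeException("Failed to load collection: " + loadStatus.getMessage());
		}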

	}

	@Override
	public String add(Embedding embedding) {
		String id = UUID.randomUUID().toString();
		add(id, embedding);
		return id;
	}

	@Override
	public void add(String id, Embedding embedding) {
		addInternal(id, embedding, null);
	}

	@Override
	public String add(Embedding embedding, TextSegment textSegment) {
		String id = UUID.randomUUID().toString();
		addInternal(id, embedding, textSegment);
		return id;
	}


	public String add(String id, Embedding embedding, TextSegment textSegment) {
		addInternal(id, embedding, textSegment);
		return id;
	}

	public List<String> add(String id, List<TextSegment> data, EmbeddingModel embeddingModel) {
		List<String> ids = new ArrayList<>();
		List<Embedding> embeddings = new ArrayList<>();
		List<TextSegment> textSegments = new ArrayList<>();
		for (TextSegment textSegment : data) {
			ids.add(id);
			embeddings.add(embeddingModel.embed(textSegment));
			textSegments.add(textSegment);
		}
		return addAllInternal(ids, embeddings, textSegments);
	}

	@Override
	public List<String> addAll(List<Embedding> embeddings) {
		return addAll(embeddings, null);
	}

	@Override
	public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegmentList) {
		return addAllInternal(embeddings.stream().map(MilvusEmbeddingStoreImpl::generateRandomId).collect(Collectors.toList()), embeddings, textSegmentList);
	}

	private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
		addAllInternal(singletonList(id), singletonList(embedding), textSegment == null ? null : singletonList(textSegment));
	}

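	private List<String> addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {

		List<JSONObject> rows = new ArrayList<>();
		for (int i = 0; i < ids.size(); i++) {
			JSONObject row = new JSONObject();
			row.put(METADATA_TEXT_SEGMENT_ID, ids.get(i));
			// textSegments is null when only embeddings are added; store an empty text in that case
			row.put(METADATA_TEXT_SEGMENT, textSegments == null ? "" : textSegments.get(i).text());
			row.put(METADATA_TEXT_SEGMENT_VECTOR, embeddings.get(i).vectorAsList());
			rows.add(row);
		}

		InsertRowsParam insertRowsParam = InsertRowsParam.newBuilder()
				.withCollectionName(collectionName)
				.withRows(rows)
				.build();

		R<InsertResponse> insert = milvusClient.insert(insertRowsParam);
		if (insert.getStatus() != 0) {
			throw new RuntimeException("Insert failed: " + insert.getMessage());
		}

		// The primary key is an auto-generated Int64, so the insert IDs returned by Milvus are not Strings.
		// Return the text_segment_id values that were stored instead.
		return ids;
	}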

	@Override
	public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
		return findRelevant(referenceEmbedding, maxResults, 0.0);
	}

	@Override
	public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minSimilarity) {
		return findRelevant(null, referenceEmbedding, maxResults, minSimilarity);
	}

	public List<EmbeddingMatch<TextSegment>> findRelevant(String filterId, Embedding referenceEmbedding, int maxResults, double minSimilarity) {

		//todo: implement minSimilarity maybe

		List<List<Float>> searchVectors = Collections.singletonList(referenceEmbedding.vectorAsList());

		SearchParam.Builder searchBuilder = SearchParam.newBuilder()
				.withCollectionName(collectionName)
				.withTopK(maxResults)
				.withMetricType(MetricType.L2)
				.withVectorFieldName(METADATA_TEXT_SEGMENT_VECTOR)
				.withOutFields(List.of(METADATA_TEXT_SEGMENT_ID, METADATA_TEXT_SEGMENT))
				.withVectors(searchVectors);

		if (filterId != null) {
			// Restrict the search to rows whose text_segment_id equals filterId.
			// Note: filterId is concatenated as-is, so it must not contain a single quote.
			String expr = METADATA_TEXT_SEGMENT_ID + " == '" + filterId + "'";
			searchBuilder.withExpr(expr);
		}

		SearchParam searchParam = searchBuilder.build();

		R<SearchResults> respSearch = milvusClient.search(searchParam);

		if (respSearch.getStatus() != 0) {
			throw new RuntimeException("Search failed: " + respSearch.getMessage());
		}

		List<EmbeddingMatch<TextSegment>> embeddingMatches = new ArrayList<>();
		for (int i = 0; i < respSearch.getData().getResults().getScoresCount(); i++) {
			SearchResultData results = respSearch.getData().getResults();
			String id = null;
			String textSegment = null;
			float score = results.getScores(i);

			for (FieldData fieldData : results.getFieldsDataList()) {

				switch (fieldData.getFieldName()) {
					case METADATA_TEXT_SEGMENT_ID ->
							id = fieldData.getScalars().getStringData().getData(i);
					case METADATA_TEXT_SEGMENT ->
							textSegment = fieldData.getScalars().getStringData().getData(i);
					//todo: maybe add the vector later, but i dont need it right now
				}

			}
			embeddingMatches.add(new EmbeddingMatch<>((double) score, id, new Embedding(new float[0]), new TextSegment(textSegment, null)));
		}

		// keep the collection loaded in memory instead of releasing it after each search

//        milvusClient.releaseCollection(
//            ReleaseCollectionParam.newBuilder()
//                .withCollectionName(collectionName)
//                .build());

		return embeddingMatches;
	}

	private static String generateRandomId(Embedding embedding) {
		return UUID.randomUUID().toString();
	}

}
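
One open point in the code above is the `minSimilarity` todo: the raw L2 score from Milvus (lower is better) is passed straight into `EmbeddingMatch`. Below is a minimal sketch of one way to turn such a score into a "higher is better" value and filter on it. It assumes the embeddings are normalized to unit length and that the reported score is the squared Euclidean distance; both need to be verified for the actual setup. The class `L2Scores`, its methods, and the mapping from cosine similarity to a score in [0, 1] are illustrative assumptions, not something taken from langchain4j or Milvus.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for converting L2 distances into similarity scores.
final class L2Scores {

    // For unit-length vectors, |a - b|^2 = 2 - 2 * cos(a, b), so cos(a, b) = 1 - d2 / 2.
    static double cosineFromSquaredL2(double squaredDistance) {
        return 1.0 - squaredDistance / 2.0;
    }

    // Maps a cosine similarity in [-1, 1] to a relevance score in [0, 1].
    static double relevance(double cosine) {
        return (cosine + 1.0) / 2.0;
    }

    // Converts each squared distance to a relevance score and keeps
    // only the scores that are at least minSimilarity, preserving order.
    static List<Double> filterByMinSimilarity(List<Double> squaredDistances, double minSimilarity) {
        List<Double> kept = new ArrayList<>();
        for (double d2 : squaredDistances) {
            double score = relevance(cosineFromSquaredL2(d2));
            if (score >= minSimilarity) {
                kept.add(score);
            }
        }
        return kept;
    }
}
```

In `findRelevant`, the same conversion could be applied to `results.getScores(i)` before building each `EmbeddingMatch`, skipping matches that fall below `minSimilarity`.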

langchain4j (Owner) commented

@AbdullahGheith thanks a lot! I will incorporate some of your code into our current implementation.

langchain4j (Owner) commented

@AbdullahGheith we've updated Milvus support, please check 0.23.0: https://github.com/langchain4j/langchain4j-examples/blob/main/milvus-example/src/main/java/MilvusEmbeddingStoreExample.java
