Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Rework support of AstraDB and Cassandra (#548)
In the DataStax Astra DB SaaS solution, a new way to integrate with
vector databases has been introduced: using an HTTP API instead of the
Cassandra cluster driver. It is called the Data API and uses MongoDB
principles with collections.

The pull request includes the following:

### Update on previous implementations

- Previous implementations of embedding stores have been grouped in a
single `CassandraEmbeddingStore`. It can be instantiated for Astra or
OSS Cassandra based on 2 different constructor builders but everything
else is the same.

- Previous implementations of chat memory stores have been grouped in a
single `CassandraChatMemoryStore`. It can be instantiated for Astra or
OSS Cassandra based on 2 different constructor builders but everything
else is the same.

- Integration test for OSS Cassandra now using test containers (as
Cassandra 5-alpha2 image is out)

- Usage
```java
// Using with Astra (Cassandra AAS in the cloud)
CassandraEmbeddingStore.builderAstra()
  .token(token)
  .databaseId(dbId)
  .databaseRegion(TEST_REGION)
  .keyspace(KEYSPACE)
  .table(TEST_INDEX)
  .dimension(11)
  .metric(CassandraSimilarityMetric.COSINE)
  .build();

// Using OSS Cassandra
CassandraEmbeddingStore.builder()
  .contactPoints(Arrays.asList(contactPoint.getHostName()))
  .port(contactPoint.getPort())
  .localDataCenter(DATACENTER)
  .keyspace(KEYSPACE)
  .table(TEST_INDEX)
  .dimension(11)
  .metric(CassandraSimilarityMetric.COSINE)
  .build();
```

- Adding JDK 11 in the pom

```
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
```

- Introducing `insertMany()`, distributed across all bulk-loading operations

- Extending the variables of `EmbeddingStoreIT`

- Using `MessageWindowChatMemory` for the tests.
  • Loading branch information
clun committed Feb 8, 2024
1 parent b375c7b commit cd006b1
Show file tree
Hide file tree
Showing 25 changed files with 1,666 additions and 1,040 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Expand Up @@ -15,7 +15,7 @@ jobs:
java_version: [8, 11, 17, 21]
include:
- java_version: '8'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
- java_version: '11'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j'
- java_version: '17'
Expand Down
76 changes: 49 additions & 27 deletions langchain4j-cassandra/pom.xml
Expand Up @@ -15,7 +15,11 @@
</parent>

<properties>
<astra-sdk.version>0.6.11</astra-sdk.version>
<astra-db-client.version>1.2.4</astra-db-client.version>
<jackson.version>2.16.1</jackson.version>
<logback.version>1.4.14</logback.version>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>

<dependencies>
Expand All @@ -25,6 +29,18 @@
<artifactId>langchain4j-core</artifactId>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-db-client</artifactId>
<version>${astra-db-client.version}</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
Expand All @@ -36,53 +52,59 @@
<artifactId>slf4j-api</artifactId>
</dependency>

<!-- TESTS -->

<!-- Visibility for EmbeddingStoreIT -->
<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-sdk-vector</artifactId>
<version>${astra-sdk.version}</version>
<exclusions>
<exclusion>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</exclusion>
</exclusions>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<!-- removing cve -->
<!-- Same embeddings model to keep the 1% -->
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20231013</version>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-beanutils</groupId>
<artifactId>commons-beanutils</artifactId>
<version>1.9.4</version>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${project.parent.version}</version>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${project.parent.version}</version>
<groupId>org.testcontainers</groupId>
<artifactId>cassandra</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>

Expand Down
@@ -0,0 +1,243 @@
package dev.langchain4j.store.embedding.astradb;

import com.dtsx.astra.sdk.AstraDBCollection;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.stargate.sdk.data.domain.JsonDocument;
import io.stargate.sdk.data.domain.JsonDocumentMutationResult;
import io.stargate.sdk.data.domain.JsonDocumentResult;
import io.stargate.sdk.data.domain.odm.Document;
import io.stargate.sdk.data.domain.query.Filter;
import io.stargate.sdk.data.domain.query.SelectQuery;
import io.stargate.sdk.data.domain.query.SelectQueryBuilder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Implementation of {@link EmbeddingStore} using AstraDB (Data API collections).
 *
 * @see EmbeddingStore
 */
@Slf4j
@Getter @Setter
@Accessors(fluent = true)
public class AstraDbEmbeddingStore implements EmbeddingStore<TextSegment> {

    /**
     * Attribute name under which the text chunk is saved in the document.
     */
    public static final String KEY_ATTRIBUTES_BLOB = "body_blob";

    /**
     * Reserved attribute returned by the Data API holding the similarity score.
     */
    public static final String KEY_SIMILARITY = "$similarity";

    /**
     * Client to work with an Astra collection.
     */
    private final AstraDBCollection astraDBCollection;

    /**
     * Bulk loading is processed in chunks; the size of one chunk is between 1 and 20.
     */
    private final int itemsPerChunk;

    /**
     * Bulk loading is distributed; this is the number of concurrent threads.
     */
    private final int concurrentThreads;

    /**
     * Initialization of the store with an EXISTING collection.
     *
     * @param client
     *      astra db collection client
     */
    public AstraDbEmbeddingStore(@NonNull AstraDBCollection client) {
        this(client, 20, 8);
    }

    /**
     * Initialization of the store with an EXISTING collection.
     *
     * @param client
     *      astra db collection client
     * @param itemsPerChunk
     *      size of 1 chunk, between 1 and 20
     * @param concurrentThreads
     *      number of threads used for distributed bulk loading, at least 1
     * @throws IllegalArgumentException
     *      if a parameter is out of its documented range
     */
    public AstraDbEmbeddingStore(@NonNull AstraDBCollection client, int itemsPerChunk, int concurrentThreads) {
        if (itemsPerChunk > 20 || itemsPerChunk < 1) {
            throw new IllegalArgumentException("'itemsPerChunk' should be in between 1 and 20");
        }
        if (concurrentThreads < 1) {
            throw new IllegalArgumentException("'concurrentThreads' should be at least 1");
        }
        this.astraDBCollection = client;
        this.itemsPerChunk = itemsPerChunk;
        this.concurrentThreads = concurrentThreads;
    }

    /**
     * Delete all records from the collection.
     */
    public void clear() {
        astraDBCollection.deleteAll();
    }

    /** {@inheritDoc} */
    @Override
    public String add(Embedding embedding) {
        return add(embedding, null);
    }

    /** {@inheritDoc} */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        return astraDBCollection
                .insertOne(mapRecord(embedding, textSegment))
                .getDocument().getId();
    }

    /** {@inheritDoc} */
    @Override
    public void add(String id, Embedding embedding) {
        astraDBCollection.upsertOne(new JsonDocument().id(id).vector(embedding.vector()));
    }

    /** {@inheritDoc} */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        // Return an empty list (rather than null) for null/empty input so callers
        // never have to null-check, and no empty bulk call reaches the Data API.
        if (embeddings == null || embeddings.isEmpty()) {
            return new ArrayList<>();
        }

        // Map as a JsonDocument list.
        List<JsonDocument> recordList = embeddings
                .stream()
                .map(e -> mapRecord(e, null))
                .collect(Collectors.toList());

        // No upsert needed as ids will be generated.
        return astraDBCollection
                .insertManyChunkedJsonDocuments(recordList, itemsPerChunk, concurrentThreads)
                .stream()
                .map(JsonDocumentMutationResult::getDocument)
                .map(Document::getId)
                .collect(Collectors.toList());
    }

    /**
     * Add multiple embeddings as a single action.
     *
     * @param embeddingList
     *      list of embeddings
     * @param textSegmentList
     *      list of text segments (same size as {@code embeddingList})
     *
     * @return list of new row ids (same order as the input)
     * @throws IllegalArgumentException
     *      if either list is null or their sizes differ
     */
    public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
        if (embeddingList == null || textSegmentList == null || embeddingList.size() != textSegmentList.size()) {
            throw new IllegalArgumentException("embeddingList and textSegmentList must not be null and have the same size");
        }

        // Map as JsonDocument list, pairing each embedding with its text segment.
        List<JsonDocument> recordList = new ArrayList<>();
        for (int i = 0; i < embeddingList.size(); i++) {
            recordList.add(mapRecord(embeddingList.get(i), textSegmentList.get(i)));
        }

        // No upsert needed (ids will be generated).
        return astraDBCollection
                .insertManyChunkedJsonDocuments(recordList, itemsPerChunk, concurrentThreads)
                .stream()
                .map(JsonDocumentMutationResult::getDocument)
                .map(Document::getId)
                .collect(Collectors.toList());
    }

    /** {@inheritDoc} */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        return findRelevant(referenceEmbedding, (Filter) null, maxResults, minScore);
    }

    /**
     * Semantic search with metadata filtering.
     *
     * @param referenceEmbedding
     *      vector
     * @param metaDatafilter
     *      filter for metadata (may be null for no filtering)
     * @param maxResults
     *      limit
     * @param minScore
     *      similarity threshold; results scoring below it are discarded
     * @return
     *      records
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, Filter metaDatafilter, int maxResults, double minScore) {
        return astraDBCollection.findVector(referenceEmbedding.vector(), metaDatafilter, maxResults)
                .filter(r -> r.getSimilarity() >= minScore)
                .map(this::mapJsonResult)
                .collect(Collectors.toList());
    }

    /**
     * Mapping the output of the query to an {@link EmbeddingMatch}.
     *
     * @param jsonRes
     *      returned object as Json
     * @return
     *      embedding match as expected by langchain4j
     */
    private EmbeddingMatch<TextSegment> mapJsonResult(JsonDocumentResult jsonRes) {
        double score = jsonRes.getSimilarity();
        String embeddingId = jsonRes.getId();
        Embedding embedding = Embedding.from(jsonRes.getVector());
        TextSegment embedded = null;
        Map<String, Object> properties = jsonRes.getData();
        if (properties != null) {
            Object body = properties.get(KEY_ATTRIBUTES_BLOB);
            if (body != null) {
                // Everything except the text blob and the similarity score is user metadata.
                Metadata metadata = new Metadata(properties.entrySet().stream()
                        .collect(Collectors.toMap(Map.Entry::getKey,
                                entry -> entry.getValue() == null ? "" : entry.getValue().toString()
                        )));
                metadata.remove(KEY_ATTRIBUTES_BLOB);
                metadata.remove(KEY_SIMILARITY);
                embedded = new TextSegment(body.toString(), metadata);
            }
        }
        return new EmbeddingMatch<>(score, embeddingId, embedding, embedded);
    }

    /**
     * Map from LangChain4j record to AstraDB record.
     *
     * @param embedding
     *      embedding (vector)
     * @param textSegment
     *      text segment (text to encode); may be null, in which case only the vector is stored
     * @return
     *      a json document
     */
    private JsonDocument mapRecord(Embedding embedding, TextSegment textSegment) {
        JsonDocument record = new JsonDocument().vector(embedding.vector());
        if (textSegment != null) {
            record.put(KEY_ATTRIBUTES_BLOB, textSegment.text());
            textSegment.metadata().asMap().forEach(record::put);
        }
        return record;
    }

}

0 comments on commit cd006b1

Please sign in to comment.