Skip to content

Commit

Permalink
Allow multiple metadata keys on RedisVectorStoreFilterType langchain-…
Browse files Browse the repository at this point in the history
  • Loading branch information
mauriciocirelli committed Apr 9, 2024
1 parent aec025f commit f79d7d6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 65 deletions.
8 changes: 5 additions & 3 deletions libs/langchain-redis/src/tests/vectorstores.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/* eslint-disable no-promise-executor-return */

import { RedisClientType, createClient } from "redis";
import { SchemaFieldTypes } from "redis";
import { v4 as uuidv4 } from "uuid";
import { test, expect } from "@jest/globals";
import { faker } from "@faker-js/faker";
Expand All @@ -21,6 +22,9 @@ describe("RedisVectorStore", () => {
redisClient: client as RedisClientType,
indexName: "test-index",
keyPrefix: "test:",
metadataSchema: {
["foo"]: SchemaFieldTypes.TEXT,
}
});
});

Expand Down Expand Up @@ -66,9 +70,7 @@ describe("RedisVectorStore", () => {
]);

// If the filter wasn't working, we'd get all 3 documents back
const results = await vectorStore.similaritySearch(pageContent, 3, [
`${uuid}`,
]);
const results = await vectorStore.similaritySearch(pageContent, 3, `@foo:(${uuid})`);

expect(results).toEqual([
new Document({ metadata: { foo: uuid }, pageContent }),
Expand Down
46 changes: 12 additions & 34 deletions libs/langchain-redis/src/tests/vectorstores.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { jest, test, expect, describe } from "@jest/globals";
import { FakeEmbeddings } from "@langchain/core/utils/testing";

import { RedisVectorStore } from "../vectorstores.js";
import { SchemaFieldTypes } from "redis";

const createRedisClientMockup = () => {
const hSetMock = jest.fn();
Expand Down Expand Up @@ -34,6 +35,9 @@ test("RedisVectorStore with external keys", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
metadataSchema: {
["a"]: SchemaFieldTypes.NUMERIC,
}
});

expect(store).toBeDefined();
Expand All @@ -44,7 +48,6 @@ test("RedisVectorStore with external keys", async () => {
pageContent: "hello",
metadata: {
a: 1,
b: { nested: [1, { a: 4 }] },
},
},
],
Expand All @@ -55,7 +58,7 @@ test("RedisVectorStore with external keys", async () => {
expect(client.hSet).toHaveBeenCalledWith("id1", {
content_vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
content: "hello",
metadata: `{\\"a\\"\\:1,\\"b\\"\\:{\\"nested\\"\\:[1,{\\"a\\"\\:4}]}}`,
a: 1,
});

const results = await store.similaritySearch("goodbye", 1);
Expand All @@ -70,6 +73,9 @@ test("RedisVectorStore with generated keys", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
metadataSchema: {
["a"]: SchemaFieldTypes.NUMERIC,
}
});

expect(store).toBeDefined();
Expand All @@ -90,46 +96,18 @@ test("RedisVectorStore with filters", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
});

expect(store).toBeDefined();

await store.similaritySearch("hello", 1, ["a", "b", "c"]);

expect(client.ft.search).toHaveBeenCalledWith(
"documents",
"@metadata:(a|b|c) => [KNN 1 @content_vector $vector AS vector_score]",
{
PARAMS: {
vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
},
RETURN: ["metadata", "content", "vector_score"],
SORTBY: "vector_score",
DIALECT: 2,
LIMIT: {
from: 0,
size: 1,
},
metadataSchema: {
["metadata"]: SchemaFieldTypes.TEXT
}
);
});

test("RedisVectorStore with raw filter", async () => {
const client = createRedisClientMockup();
const embeddings = new FakeEmbeddings();

const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
});

expect(store).toBeDefined();

await store.similaritySearch("hello", 1, "a b c");
await store.similaritySearch("hello", 1, "@metadata:(a|b|c)");

expect(client.ft.search).toHaveBeenCalledWith(
"documents",
"@metadata:(a b c) => [KNN 1 @content_vector $vector AS vector_score]",
"@metadata:(a|b|c) => [KNN 1 @content_vector $vector AS vector_score]",
{
PARAMS: {
vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
Expand Down
57 changes: 29 additions & 28 deletions libs/langchain-redis/src/vectorstores.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ export interface RedisVectorStoreConfig {
createIndexOptions?: Omit<RedisVectorStoreIndexOptions, "PREFIX">; // PREFIX must be set with keyPrefix
keyPrefix?: string;
contentKey?: string;
metadataKey?: string;
vectorKey?: string;
metadataSchema?: RediSearchSchema;
filter?: RedisVectorStoreFilterType;
}

Expand All @@ -88,12 +88,9 @@ export interface RedisAddOptions {
}

/**
* Type for the filter used in the RedisVectorStore. It is an array of
* strings.
* If a string is passed instead of an array the value is used directly, this
* allows custom filters to be passed.
* Type for the filter used in the RedisVectorStore. It is a Redis filter setence, such as @field:{value1}.
*/
export type RedisVectorStoreFilterType = string[] | string;
export type RedisVectorStoreFilterType = string;

/**
* Class representing a RedisVectorStore. It extends the VectorStore class
Expand All @@ -117,10 +114,10 @@ export class RedisVectorStore extends VectorStore {

contentKey: string;

metadataKey: string;

vectorKey: string;

metadataSchema: RediSearchSchema;

filter?: RedisVectorStoreFilterType;

_vectorstoreType(): string {
Expand All @@ -141,8 +138,8 @@ export class RedisVectorStore extends VectorStore {
};
this.keyPrefix = _dbConfig.keyPrefix ?? `doc:${this.indexName}:`;
this.contentKey = _dbConfig.contentKey ?? "content";
this.metadataKey = _dbConfig.metadataKey ?? "metadata";
this.vectorKey = _dbConfig.vectorKey ?? "content_vector";
this.metadataSchema = _dbConfig.metadataSchema ?? {};
this.filter = _dbConfig.filter;
this.createIndexOptions = {
ON: "HASH",
Expand Down Expand Up @@ -185,6 +182,7 @@ export class RedisVectorStore extends VectorStore {
if (!vectors.length || !vectors[0].length) {
throw new Error("No vectors provided");
}

// check if the index exists and create it if it doesn't
await this.createIndex(vectors[0].length);

Expand All @@ -202,12 +200,19 @@ export class RedisVectorStore extends VectorStore {
? documents[idx].metadata
: {};

multi.hSet(key, {
[this.vectorKey]: this.getFloat32Buffer(vector),
[this.contentKey]: documents[idx].pageContent,
[this.metadataKey]: this.escapeSpecialChars(JSON.stringify(metadata)),
var t = {
[this.vectorKey]: this.getFloat32Buffer(vector),
[this.contentKey]: documents[idx].pageContent,
};

Object.keys(this.metadataSchema).forEach((key) => {
if(metadata[key]) {
t[key] = (Array.isArray(metadata[key])) ? this.escapeSpecialChars(metadata[key].map((val: any) => val.toString()).join(",")) : this.escapeSpecialChars(metadata[key].toString());
}
});

multi.hSet(key, t);

// write batch
if (idx % batchSize === 0) {
await multi.exec();
Expand Down Expand Up @@ -361,9 +366,12 @@ export class RedisVectorStore extends VectorStore {
...this.indexOptions,
},
[this.contentKey]: SchemaFieldTypes.TEXT,
[this.metadataKey]: SchemaFieldTypes.TEXT,
};

Object.keys(this.metadataSchema).forEach((key) => {
schema[key] = this.metadataSchema[key];
});

await this.redisClient.ft.create(
this.indexName,
schema,
Expand Down Expand Up @@ -407,16 +415,16 @@ export class RedisVectorStore extends VectorStore {
): [string, SearchOptions] {
const vectorScoreField = "vector_score";

let hybridFields = "*";
let hybridFields: string;
// if a filter is set, modify the hybrid query
if (filter && filter.length) {
// `filter` is a list of strings, then it's applied using the OR operator in the metadata key
// for example: filter = ['foo', 'bar'] => this will filter all metadata containing either 'foo' OR 'bar'
hybridFields = `@${this.metadataKey}:(${this.prepareFilter(filter)})`;
if (typeof filter === "string") {
hybridFields = `${filter}`;
} else {
hybridFields = "*";
}

const baseQuery = `${hybridFields} => [KNN ${k} @${this.vectorKey} $vector AS ${vectorScoreField}]`;
const returnFields = [this.metadataKey, this.contentKey, vectorScoreField];
const returnFields = [...Object.keys(this.metadataSchema), this.contentKey, vectorScoreField];

const options: SearchOptions = {
PARAMS: {
Expand All @@ -434,13 +442,6 @@ export class RedisVectorStore extends VectorStore {
return [baseQuery, options];
}

private prepareFilter(filter: RedisVectorStoreFilterType) {
if (Array.isArray(filter)) {
return filter.map(this.escapeSpecialChars).join("|");
}
return filter;
}

/**
* Escapes all '-', ':', and '"' characters.
* RediSearch considers these all as special characters, so we need
Expand Down Expand Up @@ -480,4 +481,4 @@ export class RedisVectorStore extends VectorStore {
private getFloat32Buffer(vector: number[]) {
return Buffer.from(new Float32Array(vector).buffer);
}
}
}

0 comments on commit f79d7d6

Please sign in to comment.