Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community[minor]: feat: QdrantTranslator for self-query retrieval #5163

Merged
merged 11 commits into from
Apr 26, 2024
18 changes: 18 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,10 @@
"retrievers/self_query/pinecone.js",
"retrievers/self_query/pinecone.d.ts",
"retrievers/self_query/pinecone.d.cts",
"retrievers/self_query/qdrant.cjs",
"retrievers/self_query/qdrant.js",
"retrievers/self_query/qdrant.d.ts",
"retrievers/self_query/qdrant.d.cts",
"retrievers/self_query/supabase.cjs",
"retrievers/self_query/supabase.js",
"retrievers/self_query/supabase.d.ts",
Expand Down Expand Up @@ -1232,6 +1236,7 @@
"@langchain/scripts": "~0.0",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! I noticed that this PR adds a new regular dependency, "@qdrant/js-client-rest", to the project. I've flagged this for your review to ensure it aligns with our dependency management strategy. Keep up the great work!

"@notionhq/client": "^2.2.10",
"@pinecone-database/pinecone": "^1.1.0",
"@qdrant/js-client-rest": "^1.8.2",
"@supabase/supabase-js": "^2.10.0",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
Expand Down Expand Up @@ -1316,6 +1321,7 @@
"@google-cloud/storage": "^6.10.1 || ^7.7.0",
"@notionhq/client": "^2.2.10",
"@pinecone-database/pinecone": "*",
"@qdrant/js-client-rest": "^1.8.2",
"@supabase/supabase-js": "^2.10.0",
"@vercel/kv": "^0.2.3",
"@xata.io/client": "^0.28.0",
Expand Down Expand Up @@ -1392,6 +1398,9 @@
"@pinecone-database/pinecone": {
"optional": true
},
"@qdrant/js-client-rest": {
"optional": true
},
"@supabase/supabase-js": {
"optional": true
},
Expand Down Expand Up @@ -3608,6 +3617,15 @@
"import": "./retrievers/self_query/pinecone.js",
"require": "./retrievers/self_query/pinecone.cjs"
},
"./retrievers/self_query/qdrant": {
"types": {
"import": "./retrievers/self_query/qdrant.d.ts",
"require": "./retrievers/self_query/qdrant.d.cts",
"default": "./retrievers/self_query/qdrant.d.ts"
},
"import": "./retrievers/self_query/qdrant.js",
"require": "./retrievers/self_query/qdrant.cjs"
},
"./retrievers/self_query/supabase": {
"types": {
"import": "./retrievers/self_query/supabase.d.ts",
Expand Down
185 changes: 185 additions & 0 deletions langchain/src/retrievers/self_query/qdrant.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import {
QdrantVectorStore,
QdrantFilter,
QdrantCondition,
} from "@langchain/community/vectorstores/qdrant";

import {
Comparator,
Comparators,
Comparison,
Operation,
Operator,
Operators,
StructuredQuery,
Visitor,
} from "../../chains/query_constructor/ir.js";
import { BaseTranslator } from "./base.js";
import { isFilterEmpty, castValue, isInt, isFloat } from "./utils.js";

/**
 * Translator that converts a `StructuredQuery` (and the operations and
 * comparisons it contains) into the equivalent Qdrant filter objects.
 * @example
 * ```typescript
 * const selfQueryRetriever = new SelfQueryRetriever({
 *   llm: new ChatOpenAI(),
 *   vectorStore: new QdrantVectorStore(...),
 *   documentContents: "Brief summary of a movie",
 *   attributeInfo: [],
 *   structuredQueryTranslator: new QdrantTranslator(),
 * });
 *
 * const relevantDocuments = await selfQueryRetriever.getRelevantDocuments(
 *   "Which movies are rated higher than 8.5?",
 * );
 * ```
 */
export class QdrantTranslator<
  T extends QdrantVectorStore
> extends BaseTranslator<T> {
  declare VisitOperationOutput: QdrantFilter;

  declare VisitComparisonOutput: QdrantCondition;

  /** Logical operators this translator declares support for. */
  allowedOperators: Operator[] = [Operators.and, Operators.or, Operators.not];

  /** Comparators this translator declares support for. */
  allowedComparators: Comparator[] = [
    Comparators.eq,
    Comparators.ne,
    Comparators.lt,
    Comparators.lte,
    Comparators.gt,
    Comparators.gte,
  ];

  /**
   * Translates a logical operation into a QdrantFilter.
   * Every argument is visited with this translator, and the results are
   * placed under the clause matching the operator:
   * `and` -> `must`, `or` -> `should`, `not` -> `must_not`.
   * @param operation The operation to translate.
   * @returns A QdrantFilter holding the translated arguments.
   */
  visitOperation(operation: Operation): this["VisitOperationOutput"] {
    // Translate the children first, exactly once each.
    const conditions = operation.args?.map((child) =>
      child.accept(this as Visitor)
    );

    const clauseByOperator = {
      [Operators.and]: "must",
      [Operators.or]: "should",
      [Operators.not]: "must_not",
    };
    const clause = clauseByOperator[operation.operator];

    return { [clause]: conditions };
  }

  /**
   * Translates a single comparison into a QdrantCondition.
   * The comparison value is cast with `castValue`, and the attribute name
   * is prefixed with "metadata." because metadata is nested in the Qdrant
   * payload.
   * - `eq` becomes a `match.value` condition.
   * - `ne` becomes a `match.except` condition.
   * - Any other comparator becomes a `range` condition keyed by the
   *   comparator name.
   * @param comparison The comparison to translate.
   * @returns A QdrantCondition.
   * @throws If a range comparison is requested with a non-numeric value.
   */
  visitComparison(comparison: Comparison): this["VisitComparisonOutput"] {
    const key = `metadata.${comparison.attribute}`;
    const value = castValue(comparison.value);

    switch (comparison.comparator) {
      case "eq":
        return { key, match: { value } };
      case "ne":
        return { key, match: { except: [value] } };
      default:
        break;
    }

    // Everything that is not eq / ne is expressed as a range filter,
    // which only makes sense for numeric values.
    if (!(isInt(value) || isFloat(value))) {
      throw new Error("Value for gt, gte, lt, lte must be a number");
    }

    return {
      key,
      range: { [comparison.comparator]: value },
    };
  }

  /**
   * Translates a structured query into a VisitStructuredQueryOutput.
   * When the query carries a filter, the filter is visited and the result
   * is wrapped in a top-level `must` clause; otherwise an empty object is
   * returned.
   * @param query The structured query to translate.
   * @returns An object that has a `filter` entry only when the query has one.
   */
  visitStructuredQuery(
    query: StructuredQuery
  ): this["VisitStructuredQueryOutput"] {
    const structuredOutput = {};
    if (query.filter) {
      Object.assign(structuredOutput, {
        filter: { must: [query.filter.accept(this as Visitor)] },
      });
    }
    return structuredOutput;
  }

  /**
   * Combines a default filter with a generated one.
   * - Both empty: `undefined`.
   * - Default empty, or merge type 'replace': the generated filter
   *   (`undefined` when it is empty).
   * - Generated empty: the default filter when `forceDefaultFilter` is set;
   *   otherwise `undefined` for 'and' and the default filter for any other
   *   merge type.
   * - Both present: both filters under `must` for 'and', under `should`
   *   for 'or'.
   * @param defaultFilter The default filter to merge.
   * @param generatedFilter The generated filter to merge.
   * @param mergeType The type of merge to perform. Can be 'and', 'or', or 'replace'. Defaults to 'and'.
   * @param forceDefaultFilter If true, the default filter is returned when the generated filter is empty. Defaults to false.
   * @returns The merged QdrantFilter, or undefined when there is nothing to apply.
   * @throws If both filters are present and the merge type is not recognised.
   */
  mergeFilters(
    defaultFilter: QdrantFilter | undefined,
    generatedFilter: QdrantFilter | undefined,
    mergeType = "and",
    forceDefaultFilter = false
  ): QdrantFilter | undefined {
    if (isFilterEmpty(generatedFilter)) {
      // Nothing was generated: the outcome depends only on the default filter.
      if (isFilterEmpty(defaultFilter) || mergeType === "replace") {
        return undefined;
      }
      if (forceDefaultFilter) {
        return defaultFilter;
      }
      return mergeType === "and" ? undefined : defaultFilter;
    }

    if (isFilterEmpty(defaultFilter) || mergeType === "replace") {
      return generatedFilter;
    }

    // Both filters are present: combine them according to the merge type.
    switch (mergeType) {
      case "and":
        return { must: [defaultFilter, generatedFilter] };
      case "or":
        return { should: [defaultFilter, generatedFilter] };
      default:
        throw new Error("Unknown merge type");
    }
  }

  /**
   * Not supported by this translator.
   * @throws Always.
   */
  formatFunction(): string {
    throw new Error("Not implemented");
  }
}
Loading