From c7feff1944f7adac9aeb0fe8b54edffce92f190f Mon Sep 17 00:00:00 2001 From: Jacob Cable Date: Fri, 26 Jul 2024 09:34:08 +0100 Subject: [PATCH] fix: handle null metadata docs in bq retriever --- .../vertexai/src/vector-search/bigquery.ts | 57 ++++-- .../tests/vector-search/bigquery_test.ts | 168 ++++++++++++++++++ .../src/index.ts | 2 +- .../src/index.ts | 2 +- .../src/index.ts | 2 +- 5 files changed, 213 insertions(+), 18 deletions(-) create mode 100644 js/plugins/vertexai/tests/vector-search/bigquery_test.ts diff --git a/js/plugins/vertexai/src/vector-search/bigquery.ts b/js/plugins/vertexai/src/vector-search/bigquery.ts index b823e6e999..e36224ece2 100644 --- a/js/plugins/vertexai/src/vector-search/bigquery.ts +++ b/js/plugins/vertexai/src/vector-search/bigquery.ts @@ -15,8 +15,11 @@ */ import { Document, DocumentDataSchema } from '@genkit-ai/ai/retriever'; -import { BigQuery } from '@google-cloud/bigquery'; +import { logger } from '@genkit-ai/core/logging'; +import { BigQuery, QueryRowsResponse } from '@google-cloud/bigquery'; +import { ZodError } from 'zod'; import { DocumentIndexer, DocumentRetriever, Neighbor } from './types'; + /** * Creates a BigQuery Document Retriever. * @@ -36,34 +39,58 @@ export const getBigQueryDocumentRetriever = ( const bigQueryRetriever: DocumentRetriever = async ( neighbors: Neighbor[] ): Promise => { - const ids = neighbors + const ids: string[] = neighbors .map((neighbor) => neighbor.datapoint?.datapointId) - .filter(Boolean); + .filter(Boolean) as string[]; + const query = ` SELECT * FROM \`${datasetId}.${tableId}\` WHERE id IN UNNEST(@ids) `; + const options = { query, params: { ids }, }; - const [rows] = await bq.query(options); - const docs: Document[] = rows - .map((row) => { - const docData = { + + let rows: QueryRowsResponse[0]; + + try { + [rows] = await bq.query(options); + } catch (queryError) { + logger.error('Failed to execute BigQuery query:', queryError); + return []; + } + + const documents: Document[] = []; + + for (const row of rows) { + try { + const docData: { content: any; metadata?: any } = { content: JSON.parse(row.content), - metadata: JSON.parse(row.metadata), }; - const parsedDocData = DocumentDataSchema.safeParse(docData); - if (parsedDocData.success) { - return new Document(parsedDocData.data); + + if (row.metadata) { + docData.metadata = JSON.parse(row.metadata); + } + + const parsedDocData = DocumentDataSchema.parse(docData); + documents.push(new Document(parsedDocData)); + } catch (error) { + const id = row.id; + const errorPrefix = `Failed to parse document data for document with ID ${id}:`; + + if (error instanceof ZodError || error instanceof Error) { + logger.warn(`${errorPrefix} ${error.message}`); + } else { + logger.warn(errorPrefix); } - return null; - }) - .filter((doc): doc is Document => !!doc); + } + } - return docs; + return documents; }; + return bigQueryRetriever; }; diff --git a/js/plugins/vertexai/tests/vector-search/bigquery_test.ts b/js/plugins/vertexai/tests/vector-search/bigquery_test.ts new file mode 100644 index 0000000000..69378ab809 --- /dev/null +++ b/js/plugins/vertexai/tests/vector-search/bigquery_test.ts @@ -0,0 +1,168 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Document } from '@genkit-ai/ai/retriever'; +import { BigQuery } from '@google-cloud/bigquery'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { getBigQueryDocumentRetriever } from '../../src'; + +class MockBigQuery { + query: Function; + + constructor({ + mockRows, + shouldThrowError = false, + }: { + mockRows: any[]; + shouldThrowError?: boolean; + }) { + this.query = async (_options: { + query: string; + params: { ids: string[] }; + }) => { + if (shouldThrowError) { + throw new Error('Query failed'); + } + return [mockRows]; + }; + } +} + +describe('getBigQueryDocumentRetriever', () => { + it('returns a function that retrieves documents from BigQuery', async () => { + const doc1 = Document.fromText('content1'); + const doc2 = Document.fromText('content2'); + + const mockRows = [ + { + id: '1', + content: JSON.stringify(doc1.content), + metadata: null, + }, + { + id: '2', + content: JSON.stringify(doc2.content), + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [doc1, doc2]); + }); + + it('returns an empty array when no documents match', async () => { + const mockRows: any[] = []; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '3' } }, + ]); + + assert.deepStrictEqual(documents, []); + }); + + it('handles BigQuery query errors', async () => { + const mockBigQuery = new MockBigQuery({ + mockRows: [], + shouldThrowError: true, + }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + // no need to assert the error, just make sure it doesn't throw + await documentRetriever([{ datapoint: { datapointId: '1' } }]); + }); + + it('filters out invalid documents', async () => { + const validDoc = Document.fromText('valid content'); + const mockRows = [ + { + id: '1', + content: JSON.stringify(validDoc.content), + metadata: null, + }, + { + id: '2', + content: 'invalid JSON', + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [validDoc]); + }); + + it('handles missing content in documents', async () => { + const validDoc = Document.fromText('valid content'); + const mockRows = [ + { + id: '1', + content: JSON.stringify(validDoc.content), + metadata: null, + }, + { + id: '2', + content: null, + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [validDoc]); + }); +}); diff --git a/js/testapps/vertexai-vector-search-bigquery/src/index.ts b/js/testapps/vertexai-vector-search-bigquery/src/index.ts index b94b99553f..e0d3ab8993 100644 --- a/js/testapps/vertexai-vector-search-bigquery/src/index.ts +++ b/js/testapps/vertexai-vector-search-bigquery/src/index.ts @@ -84,7 +84,7 @@ configureGenkit({ googleAuth: { scopes: ['https://www.googleapis.com/auth/cloud-platform'], }, - vectorSearchIndexOptions: [ + vectorSearchOptions: [ { publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME, indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID, diff --git a/js/testapps/vertexai-vector-search-custom/src/index.ts b/js/testapps/vertexai-vector-search-custom/src/index.ts index 58c892368f..cbc70928ac 100644 --- a/js/testapps/vertexai-vector-search-custom/src/index.ts +++ b/js/testapps/vertexai-vector-search-custom/src/index.ts @@ -151,7 +151,7 @@ configureGenkit({ googleAuth: { scopes: ['https://www.googleapis.com/auth/cloud-platform'], }, - vectorSearchIndexOptions: [ + vectorSearchOptions: [ { publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME, indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID, diff --git a/js/testapps/vertexai-vector-search-firestore/src/index.ts b/js/testapps/vertexai-vector-search-firestore/src/index.ts index 438c9636b7..959b9ed2ba 100644 --- a/js/testapps/vertexai-vector-search-firestore/src/index.ts +++ b/js/testapps/vertexai-vector-search-firestore/src/index.ts @@ -83,7 +83,7 @@ configureGenkit({ googleAuth: { scopes: ['https://www.googleapis.com/auth/cloud-platform'], }, - vectorSearchIndexOptions: [ + vectorSearchOptions: [ { publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME, indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID,