From 4f9474d020e116afec80ff6db8e0d8ed5dc6522b Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Thu, 8 Feb 2024 21:21:42 -0800 Subject: [PATCH] mypy --- .../cloudsql_vectorstore.py | 26 +++++++++--------- .../postgresql_engine.py | 27 +++++++------------ 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py b/src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py index 99e34f4b..310ad0de 100644 --- a/src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py @@ -20,13 +20,12 @@ import uuid from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union -import nest_asyncio +import nest_asyncio # type: ignore import numpy as np from langchain_community.vectorstores.utils import maximal_marginal_relevance from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from pgvector.sqlalchemy import Vector from sqlalchemy import text from .indexes import ( @@ -218,7 +217,10 @@ async def aadd_embeddings( # ) async def aadd_documents( - self, documents: List[Document], ids: List[str] = None, **kwargs: Any + self, + documents: List[Document], + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -349,7 +351,7 @@ async def __query_collection( self, embedding: List[float], k: int = 4, - filter: str = None, + filter: Optional[str] = None, ) -> List[Any]: k = self.k if self.k else k if self.distance_strategy == DistanceStrategy.EUCLIDEAN: @@ -371,7 +373,7 @@ async def asimilarity_search( self, query: str, k: int = 4, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Document]: embedding = self.embedding_service.embed_query(text=query) @@ -384,7 +386,7 @@ def similarity_search( self, query: str, k: int = 4, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Document]: return self.loop.create_task( @@ -395,7 +397,7 @@ async def asimilarity_search_with_score( self, query: str, k: int = 4, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: embedding = self.embedding_service.embed_query(query) @@ -417,7 +419,7 @@ async def asimilarity_search_by_vector( self, embedding: List[float], k: int = 4, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = await self.asimilarity_search_with_score_by_vector( @@ -439,7 +441,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: results = await self.__query_collection(embedding=embedding, k=k, filter=filter) @@ -484,7 +486,7 @@ async def amax_marginal_relevance_search( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Document]: embedding = self.embedding_service.embed_query(text=query) @@ -519,7 +521,7 @@ async def amax_marginal_relevance_search_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: str = None, + filter: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -542,7 +544,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: str = None, + filter: Optional[str] = None, ) -> List[Tuple[Document, float]]: results = await self.__query_collection( embedding=embedding, k=fetch_k, filter=filter diff --git a/src/langchain_google_cloud_sql_pg/postgresql_engine.py b/src/langchain_google_cloud_sql_pg/postgresql_engine.py index 97fe8be2..79d7a1c2 100644 --- a/src/langchain_google_cloud_sql_pg/postgresql_engine.py +++ b/src/langchain_google_cloud_sql_pg/postgresql_engine.py @@ -17,21 +17,22 @@ # import requests # import sqlalchemy -import asyncio +import asyncio # type: ignore from threading import Thread from typing import TYPE_CHECKING, Dict, List, Optional, Type import aiohttp -import google.auth -import google.auth.transport.requests -import nest_asyncio +import google.auth # type: ignore +import google.auth.transport.requests # type: ignore from google.cloud.sql.connector import Connector, create_async_connector # from pgvector.asyncpg import register_vector from sqlalchemy import Column, text -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine - -# nest_asyncio.apply() +from sqlalchemy.ext.asyncio import ( + AsyncConnection, + AsyncEngine, + create_async_engine, +) if TYPE_CHECKING: import asyncpg @@ -109,7 +110,7 @@ def from_instance( region: str, instance: str, database: str, - project_id: str = None, + project_id: Optional[str] = None, ) -> PostgreSQLEngine: """Create PostgreSQLEngine connection to the postgres database in the CloudSQL instance. Args: @@ -148,9 +149,7 @@ async def get_conn(): conn = await connector.connect_async( f"{self.project_id}:{self.region}:{self.instance}", "asyncpg", - # user=await _get_iam_principal_email(credentials), - user="postgres", - password="my-pg-pass", + user=await _get_iam_principal_email(credentials), enable_iam_auth=True, db=self.database, ) @@ -189,18 +188,12 @@ async def init_vectorstore_table( overwrite_existing: bool = False, store_metadata: bool = True, ) -> None: - # async with self.engine.connect() as conn: - # Enable pgvector - # await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) await self._aexecute_update("CREATE EXTENSION IF NOT EXISTS vector") # Register the vector type # await register_vector(conn) if overwrite_existing: await self._aexecute_update(f"DROP TABLE {table_name}") - # await conn.execute( - # text(f"TRUNCATE TABLE {table_name} RESET IDENTITY") - # ) # TODO? query = f"""CREATE TABLE IF NOT EXISTS {table_name}( {id_column} UUID PRIMARY KEY,