Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
averikitsch committed Feb 9, 2024
1 parent 37a0ec2 commit 4f9474d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 29 deletions.
26 changes: 14 additions & 12 deletions src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
import uuid
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union

import nest_asyncio
import nest_asyncio # type: ignore
import numpy as np
from langchain_community.vectorstores.utils import maximal_marginal_relevance
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pgvector.sqlalchemy import Vector
from sqlalchemy import text

from .indexes import (
Expand Down Expand Up @@ -218,7 +217,10 @@ async def aadd_embeddings(
# )

async def aadd_documents(
self, documents: List[Document], ids: List[str] = None, **kwargs: Any
self,
documents: List[Document],
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
Expand Down Expand Up @@ -349,7 +351,7 @@ async def __query_collection(
self,
embedding: List[float],
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
) -> List[Any]:
k = self.k if self.k else k
if self.distance_strategy == DistanceStrategy.EUCLIDEAN:
Expand All @@ -371,7 +373,7 @@ async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
embedding = self.embedding_service.embed_query(text=query)
Expand All @@ -384,7 +386,7 @@ def similarity_search(
self,
query: str,
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
return self.loop.create_task(
Expand All @@ -395,7 +397,7 @@ async def asimilarity_search_with_score(
self,
query: str,
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding = self.embedding_service.embed_query(query)
Expand All @@ -417,7 +419,7 @@ async def asimilarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
docs_and_scores = await self.asimilarity_search_with_score_by_vector(
Expand All @@ -439,7 +441,7 @@ async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
results = await self.__query_collection(embedding=embedding, k=k, filter=filter)
Expand Down Expand Up @@ -484,7 +486,7 @@ async def amax_marginal_relevance_search(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
embedding = self.embedding_service.embed_query(text=query)
Expand Down Expand Up @@ -519,7 +521,7 @@ async def amax_marginal_relevance_search_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: str = None,
filter: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
Expand All @@ -542,7 +544,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: str = None,
filter: Optional[str] = None,
) -> List[Tuple[Document, float]]:
results = await self.__query_collection(
embedding=embedding, k=fetch_k, filter=filter
Expand Down
27 changes: 10 additions & 17 deletions src/langchain_google_cloud_sql_pg/postgresql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,22 @@

# import requests
# import sqlalchemy
import asyncio
import asyncio # type: ignore
from threading import Thread
from typing import TYPE_CHECKING, Dict, List, Optional, Type

import aiohttp
import google.auth
import google.auth.transport.requests
import nest_asyncio
import google.auth # type: ignore
import google.auth.transport.requests # type: ignore
from google.cloud.sql.connector import Connector, create_async_connector

# from pgvector.asyncpg import register_vector
from sqlalchemy import Column, text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine

# nest_asyncio.apply()
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
create_async_engine,
)

if TYPE_CHECKING:
import asyncpg
Expand Down Expand Up @@ -109,7 +110,7 @@ def from_instance(
region: str,
instance: str,
database: str,
project_id: str = None,
project_id: Optional[str] = None,
) -> PostgreSQLEngine:
"""Create PostgreSQLEngine connection to the postgres database in the CloudSQL instance.
Args:
Expand Down Expand Up @@ -148,9 +149,7 @@ async def get_conn():
conn = await connector.connect_async(
f"{self.project_id}:{self.region}:{self.instance}",
"asyncpg",
# user=await _get_iam_principal_email(credentials),
user="postgres",
password="my-pg-pass",
user=await _get_iam_principal_email(credentials),
enable_iam_auth=True,
db=self.database,
)
Expand Down Expand Up @@ -189,18 +188,12 @@ async def init_vectorstore_table(
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
# async with self.engine.connect() as conn:
# Enable pgvector
# await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await self._aexecute_update("CREATE EXTENSION IF NOT EXISTS vector")
# Register the vector type
# await register_vector(conn)

if overwrite_existing:
await self._aexecute_update(f"DROP TABLE {table_name}")
# await conn.execute(
# text(f"TRUNCATE TABLE {table_name} RESET IDENTITY")
# ) # TODO?

query = f"""CREATE TABLE IF NOT EXISTS {table_name}(
{id_column} UUID PRIMARY KEY,
Expand Down

0 comments on commit 4f9474d

Please sign in to comment.