Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ci: Add mypy function type check #134

Merged
merged 12 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ profile = "black"
[tool.mypy]
python_version = "3.8"
warn_unused_configs = true
disallow_incomplete_defs = true
averikitsch marked this conversation as resolved.
Show resolved Hide resolved

exclude = [
'docs/*',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AlloyDBChatMessageHistory(BaseChatMessageHistory):

def __init__(
self,
key,
key: object,
engine: AlloyDBEngine,
session_id: str,
table_name: str,
Expand All @@ -66,7 +66,7 @@ async def create(
engine: AlloyDBEngine,
session_id: str,
table_name: str,
):
) -> AlloyDBChatMessageHistory:
table_schema = await engine._aload_table_schema(table_name)
column_names = table_schema.columns.keys()

Expand All @@ -93,7 +93,7 @@ def create_sync(
engine: AlloyDBEngine,
session_id: str,
table_name: str,
):
) -> AlloyDBChatMessageHistory:
coro = cls.create(engine, session_id, table_name)
return engine._run_as_sync(coro)

Expand Down
45 changes: 28 additions & 17 deletions src/langchain_google_alloydb_pg/alloydb_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from threading import Thread
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
)

import aiohttp
import google.auth # type: ignore
import google.auth.transport.requests # type: ignore
from google.cloud.alloydb.connector import AsyncConnector, IPTypes
from sqlalchemy import MetaData, Table, text
from sqlalchemy import MetaData, RowMapping, Table, text
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

Expand Down Expand Up @@ -81,7 +91,7 @@ class Column:
data_type: str
nullable: bool = True

def __post_init__(self):
def __post_init__(self) -> None:
if not isinstance(self.name, str):
raise ValueError("Column name must be type string")
if not isinstance(self.data_type, str):
Expand All @@ -98,14 +108,14 @@ def __init__(
engine: AsyncEngine,
loop: Optional[asyncio.AbstractEventLoop],
thread: Optional[Thread],
):
) -> None:
self._engine = engine
self._loop = loop
self._thread = thread

@classmethod
def from_instance(
cls,
cls: Type[AlloyDBEngine],
project_id: str,
region: str,
cluster: str,
Expand Down Expand Up @@ -136,7 +146,7 @@ def from_instance(

@classmethod
async def _create(
cls,
cls: Type[AlloyDBEngine],
project_id: str,
region: str,
cluster: str,
Expand Down Expand Up @@ -193,7 +203,7 @@ async def getconn() -> asyncpg.Connection:

@classmethod
async def afrom_instance(
cls,
cls: Type[AlloyDBEngine],
project_id: str,
region: str,
cluster: str,
Expand All @@ -215,22 +225,24 @@ async def afrom_instance(
)

@classmethod
def from_engine(cls, engine: AsyncEngine) -> AlloyDBEngine:
def from_engine(cls: Type[AlloyDBEngine], engine: AsyncEngine) -> AlloyDBEngine:
return cls(engine, None, None)

async def _aexecute(self, query: str, params: Optional[dict] = None):
async def _aexecute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
async with self._engine.connect() as conn:
await conn.execute(text(query), params)
await conn.commit()

async def _aexecute_outside_tx(self, query: str):
async def _aexecute_outside_tx(self, query: str) -> None:
"""Execute a SQL query."""
async with self._engine.connect() as conn:
await conn.execute(text("COMMIT"))
await conn.execute(text(query))

async def _afetch(self, query: str, params: Optional[dict] = None):
async def _afetch(
self, query: str, params: Optional[dict] = None
) -> Sequence[RowMapping]:
async with self._engine.connect() as conn:
"""Fetch results from a SQL query."""
result = await conn.execute(text(query), params)
Expand All @@ -239,10 +251,10 @@ async def _afetch(self, query: str, params: Optional[dict] = None):

return result_fetch

def _execute(self, query: str, params: Optional[dict] = None):
def _execute(self, query: str, params: Optional[dict] = None) -> None:
return self._run_as_sync(self._aexecute(query, params))

def _fetch(self, query: str, params: Optional[dict] = None):
def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
return self._run_as_sync(self._afetch(query, params))

def _run_as_sync(self, coro: Awaitable[T]) -> T:
Expand Down Expand Up @@ -306,7 +318,7 @@ def init_vectorstore_table(
)
)

async def ainit_chat_history_table(self, table_name) -> None:
async def ainit_chat_history_table(self, table_name: str) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
Expand All @@ -315,7 +327,7 @@ async def ainit_chat_history_table(self, table_name) -> None:
);"""
await self._aexecute(create_table_query)

def init_chat_history_table(self, table_name) -> None:
def init_chat_history_table(self, table_name: str) -> None:
return self._run_as_sync(
self.ainit_chat_history_table(
table_name,
Expand All @@ -340,7 +352,6 @@ async def ainit_document_table(
store_metadata (bool): Whether to store extra metadata in a metadata column
if not described in 'metadata' field list (Default: True).
"""

query = f"""CREATE TABLE "{table_name}"(
{content_column} TEXT NOT NULL
"""
Expand Down
53 changes: 28 additions & 25 deletions src/langchain_google_alloydb_pg/alloydb_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Iterator,
List,
Optional,
Type,
)

import sqlalchemy
Expand All @@ -36,22 +37,22 @@
DEFAULT_METADATA_COL = "langchain_metadata"


def text_formatter(row, content_columns) -> str:
def text_formatter(row: Dict[str, Any], content_columns: Iterable[str]) -> str:
return " ".join(str(row[column]) for column in content_columns if column in row)


def csv_formatter(row, content_columns) -> str:
def csv_formatter(row: Dict[str, Any], content_columns: Iterable[str]) -> str:
return ", ".join(str(row[column]) for column in content_columns if column in row)


def yaml_formatter(row, content_columns) -> str:
def yaml_formatter(row: Dict[str, Any], content_columns: Iterable[str]) -> str:
return "\n".join(
f"{column}: {str(row[column])}" for column in content_columns if column in row
)


def json_formatter(row, content_columns) -> str:
dictionary = {}
def json_formatter(row: Dict[str, Any], content_columns: Iterable[str]) -> str:
dictionary: Dict[str, Any] = {}
for column in content_columns:
if column in row:
dictionary[column] = row[column]
Expand All @@ -61,9 +62,9 @@ def json_formatter(row, content_columns) -> str:
def _parse_doc_from_row(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
row: dict,
row: Dict[str, Any],
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
formatter: Callable = text_formatter,
formatter: Callable[[Dict[str, Any], Iterable[str]], str] = text_formatter,
) -> Document:
page_content = formatter(row, content_columns)
metadata: Dict[str, Any] = {}
Expand All @@ -84,7 +85,7 @@ def _parse_row_from_doc(
column_names: Iterable[str],
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> Dict:
) -> Dict[str, Any]:
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
Expand All @@ -98,10 +99,10 @@ def _parse_row_from_doc(


class AlloyDBLoader(BaseLoader):
"""Load documents from Alloydb`.
"""Load documents from AlloyDB`.

Each document represents one row of the result. The `content_columns` are
written into the `content_columns`of the document. The `metadata_columns` are written
written into the `content_columns` of the document. The `metadata_columns` are written
into the `metadata_columns` of the document. By default, first columns is written into
the `page_content` and everything else into the `metadata`.
"""
Expand All @@ -110,12 +111,12 @@ class AlloyDBLoader(BaseLoader):

def __init__(
self,
key,
key: object,
engine: AlloyDBEngine,
query: str,
content_columns: List[str],
metadata_columns: List[str],
formatter: Callable,
formatter: Callable[[Dict[str, Any], Iterable[str]], str],
metadata_json_column: Optional[str] = None,
) -> None:
if key != AlloyDBLoader.__create_key:
Expand All @@ -132,7 +133,7 @@ def __init__(

@classmethod
async def create(
cls,
cls: Type[AlloyDBLoader],
engine: AlloyDBEngine,
query: Optional[str] = None,
table_name: Optional[str] = None,
Expand All @@ -141,7 +142,7 @@ async def create(
metadata_json_column: Optional[str] = None,
format: Optional[str] = None,
formatter: Optional[Callable] = None,
):
) -> AlloyDBLoader:
"""Constructor for AlloyDBLoader

Args:
Expand Down Expand Up @@ -225,7 +226,7 @@ async def create(

@classmethod
def create_sync(
cls,
cls: Type[AlloyDBLoader],
engine: AlloyDBEngine,
query: Optional[str] = None,
table_name: Optional[str] = None,
Expand All @@ -234,7 +235,7 @@ def create_sync(
metadata_json_column: Optional[str] = None,
format: Optional[str] = None,
formatter: Optional[Callable] = None,
):
) -> AlloyDBLoader:
coro = cls.create(
engine,
query,
Expand All @@ -247,18 +248,20 @@ def create_sync(
)
return engine._run_as_sync(coro)

async def _collect_async_items(self, docs_generator):
async def _collect_async_items(
self, docs_generator: AsyncIterator[Document]
) -> List[Document]:
return [doc async for doc in docs_generator]

def load(self) -> List[Document]:
"""Load Alloydb data into Document objects."""
"""Load AlloyDB data into Document objects."""
documents = self.engine._run_as_sync(
self._collect_async_items(self.alazy_load())
)
return documents

async def aload(self) -> List[Document]:
"""Load Alloydb data into Document objects."""
"""Load AlloyDB data into Document objects."""
return [doc async for doc in self.alazy_load()]

def lazy_load(self) -> Iterator[Document]:
Expand Down Expand Up @@ -303,13 +306,13 @@ class AlloyDBDocumentSaver:

def __init__(
self,
key,
key: object,
engine: AlloyDBEngine,
table_name: str,
content_column: str,
metadata_columns: List[str] = [],
metadata_json_column: Optional[str] = None,
):
) -> None:
if key != AlloyDBDocumentSaver.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods!"
Expand All @@ -322,13 +325,13 @@ def __init__(

@classmethod
async def create(
cls,
cls: Type[AlloyDBDocumentSaver],
engine: AlloyDBEngine,
table_name: str,
content_column: str = DEFAULT_CONTENT_COL,
metadata_columns: List[str] = [],
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
):
) -> AlloyDBDocumentSaver:
table_schema = await engine._aload_table_schema(table_name)
column_names = table_schema.columns.keys()
if content_column not in column_names:
Expand Down Expand Up @@ -367,13 +370,13 @@ async def create(

@classmethod
def create_sync(
cls,
cls: Type[AlloyDBDocumentSaver],
engine: AlloyDBEngine,
table_name: str,
content_column: str = DEFAULT_CONTENT_COL,
metadata_columns: List[str] = [],
metadata_json_column: str = DEFAULT_METADATA_COL,
):
) -> AlloyDBDocumentSaver:
coro = cls.create(
engine,
table_name,
Expand Down
Loading