# Reconstructed post-image of scripts/utils.py from a whitespace-mangled
# unified diff.  The companion Makefile hunk in the same patch adds a `stats`
# target running `poetry run python scripts/utils.py stats`.
#
# NOTE(review): `wipe_file` and `wipe_tree` are unchanged context lines in the
# patch and their full bodies are not visible here; they are assumed to remain
# defined in the surrounding file (delete one file / wipe a directory tree) --
# TODO confirm against the repository before applying.

import argparse
import os
import shutil
from typing import Any, ClassVar

from private_gpt.paths import local_data_path
from private_gpt.settings.settings import settings


class Postgres:
    """Maintenance handler for node/index/vector stores kept in Postgres."""

    # Tables owned by each store type.  Identifiers are hard-coded here and
    # never taken from user input, so interpolating them into SQL below is
    # safe (psycopg2 parameter binding cannot quote identifiers anyway).
    tables: ClassVar[dict[str, list[str]]] = {
        "nodestore": ["data_docstore", "data_indexstore"],
        "vectorstore": ["data_embeddings"],
    }

    def __init__(self) -> None:
        try:
            import psycopg2
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Postgres dependencies not found") from None

        connection = settings().postgres.model_dump(exclude_none=True)
        # `schema_name` is part of our settings model, not a psycopg2
        # connect() keyword -- pop it out before connecting.
        self.schema = connection.pop("schema_name")
        self.conn = psycopg2.connect(**connection)

    def wipe(self, store_type: str) -> None:
        """Drop every table belonging to *store_type* (idempotent)."""
        # psycopg2 cursors are context managers: closed on exit, even on error.
        with self.conn.cursor() as cur:
            for table in self.tables[store_type]:
                cur.execute(f"DROP TABLE IF EXISTS {self.schema}.{table}")
                print(f"Table {self.schema}.{table} dropped.")
            self.conn.commit()

    def stats(self, store_type: str) -> None:
        """Print row count and on-disk size for each *store_type* table."""
        template = (
            "SELECT '{table}', COUNT(*), "
            "pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"
        )
        # One round-trip: UNION ALL the per-table stat queries together.
        sql = " UNION ALL ".join(
            template.format(table=tbl) for tbl in self.tables[store_type]
        )

        with self.conn.cursor() as cur:
            print(f"Storage for Postgres {store_type}.")
            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
            print("-" * 45)  # Print a line separator

            cur.execute(sql)
            for table_name, row_count, size in cur.fetchall():
                # `:>15,` = right-align to 15 with thousands separators.
                print(f"{table_name:<15} | {row_count:>15,} | {size:>9}")

            print()

    def __del__(self) -> None:
        # Best-effort close; guard with hasattr because __init__ may have
        # raised before self.conn was ever assigned.
        if hasattr(self, "conn") and self.conn:
            self.conn.close()


class Simple:
    """Handler for the 'simple' (on-disk JSON) node store."""

    def wipe(self, store_type: str) -> None:
        assert store_type == "nodestore"
        from llama_index.core.storage.docstore.types import (
            DEFAULT_PERSIST_FNAME as DOCSTORE,
        )
        from llama_index.core.storage.index_store.types import (
            DEFAULT_PERSIST_FNAME as INDEXSTORE,
        )

        for store in (DOCSTORE, INDEXSTORE):
            wipe_file(str((local_data_path / store).absolute()))


class Chroma:
    """Handler for the Chroma vector store (local directory)."""

    def wipe(self, store_type: str) -> None:
        assert store_type == "vectorstore"
        wipe_tree(str((local_data_path / "chroma_db").absolute()))


class Qdrant:
    """Handler for the Qdrant vector store."""

    COLLECTION = (
        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
    )

    def __init__(self) -> None:
        try:
            from qdrant_client import QdrantClient  # type: ignore
        except ImportError:
            raise ImportError("Qdrant dependencies not found") from None
        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))

    def wipe(self, store_type: str) -> None:
        assert store_type == "vectorstore"
        # Best-effort: report rather than crash if the collection is missing
        # or the server rejects the call.
        try:
            self.client.delete_collection(self.COLLECTION)
            print("Collection dropped successfully.")
        except Exception as e:
            print("Error dropping collection:", e)

    def stats(self, store_type: str) -> None:
        print(f"Storage for Qdrant {store_type}.")
        try:
            collection_data = self.client.get_collection(self.COLLECTION)
            if collection_data:
                # Collection Info
                # https://qdrant.tech/documentation/concepts/collections/
                print(f"\tPoints: {collection_data.points_count:,}")
                print(f"\tVectors: {collection_data.vectors_count:,}")
                print(f"\tIndex Vectors: {collection_data.indexed_vectors_count:,}")
                return
        except ValueError:
            pass
        print("\t- Qdrant collection not found or empty")


class Command:
    """Dispatch a CLI command to the configured handler of each store."""

    DB_HANDLERS: ClassVar[dict[str, Any]] = {
        "simple": Simple,  # node store
        "chroma": Chroma,  # vector store
        "postgres": Postgres,  # node, index and vector store
        "qdrant": Qdrant,  # vector store
    }

    def for_each_store(self, cmd: str) -> None:
        for store_type in ("nodestore", "vectorstore"):
            database = getattr(settings(), store_type).database
            handler_class = self.DB_HANDLERS.get(database)
            if handler_class is None:
                print(f"No handler found for database '{database}'")
                continue
            handler_instance = handler_class()  # Instantiate the class
            # If the DB can handle this cmd dispatch it.
            if hasattr(handler_instance, cmd) and callable(
                func := getattr(handler_instance, cmd)
            ):
                func(store_type)
            else:
                print(
                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
                )

    def execute(self, cmd: str) -> None:
        if cmd in ("wipe", "stats"):
            self.for_each_store(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    args = parser.parse_args()

    Command().execute(args.mode.lower())