4 changes: 4 additions & 0 deletions packages/graphrag-storage/graphrag_storage/file_storage.py
@@ -144,6 +144,10 @@ async def get_creation_date(self, key: str) -> str:

        return get_timestamp_formatted_with_local_tz(creation_time_utc)

    def get_path(self, key: str) -> Path:
        """Get the full file path for a key (for streaming access)."""
        return _join_path(self._base_dir, key)


def _join_path(file_path: Path, file_name: str) -> Path:
    """Join a path and a file. Independent of the OS."""
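Usage sketch (not part of the diff): how the new `get_path` hook can be used for direct file access. The `FileStorage(base_dir=...)` constructor arguments are an assumption not shown here; only `get_path(key)` comes from this change.

```python
# Sketch only: the FileStorage constructor arguments are assumptions;
# get_path(key) is the method added in this diff.
from graphrag_storage.file_storage import FileStorage

storage = FileStorage(base_dir="./output")  # assumed constructor signature
path = storage.get_path("documents.csv")    # full on-disk path for the key
print(path)                                 # e.g. output/documents.csv
```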
@@ -3,6 +3,7 @@

"""Table provider module for GraphRAG storage."""

from .table import Table
from .table_provider import TableProvider

__all__ = ["TableProvider"]
__all__ = ["Table", "TableProvider"]
144 changes: 144 additions & 0 deletions packages/graphrag-storage/graphrag_storage/tables/csv_table.py
@@ -0,0 +1,144 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""A CSV-based implementation of the Table abstraction for streaming row access."""

from __future__ import annotations

import csv
import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any

import aiofiles

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.tables.table import RowTransformer, Table

if TYPE_CHECKING:
    from collections.abc import AsyncIterator
    from io import TextIOWrapper

    from graphrag_storage import Storage


def _identity(row: dict[str, Any]) -> Any:
    """Return row unchanged (default transformer)."""
    return row


def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any:
"""Apply transformer to row, handling both callables and classes.

If transformer is a class (e.g., Pydantic model), calls it with **row.
Otherwise calls it with row as positional argument.
"""
if inspect.isclass(transformer):
return transformer(**row)
return transformer(row)


class CSVTable(Table):
"""Row-by-row streaming interface for CSV tables."""

def __init__(
self,
storage: Storage,
table_name: str,
transformer: RowTransformer | None = None,
):
"""Initialize with storage backend and table name.

Args:
storage: Storage instance (File, Blob, or Cosmos)
table_name: Name of the table (e.g., "documents")
transformer: Optional callable to transform each row before
yielding. Receives a dict, returns a transformed dict.
Defaults to identity (no transformation).
"""
self._storage = storage
self._table_name = table_name
self._file_key = f"{table_name}.csv"
self._transformer = transformer or _identity
self._write_file: TextIOWrapper | None = None
self._writer: csv.DictWriter | None = None
self._header_written = False

    def __aiter__(self) -> AsyncIterator[Any]:
        """Iterate through rows one at a time.

        The transformer is applied to each row before yielding.
        If transformer is a Pydantic model, yields model instances.

        Yields
        ------
        Any:
            Each row as dict or transformed type (e.g., Pydantic model).
        """
        return self._aiter_impl()

    async def _aiter_impl(self) -> AsyncIterator[Any]:
        """Implement async iteration over rows."""
        if isinstance(self._storage, FileStorage):
            file_path = self._storage.get_path(self._file_key)
            with Path.open(file_path, "r", encoding="utf-8") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    yield _apply_transformer(self._transformer, row)

    async def length(self) -> int:
        """Return the number of rows in the table."""
        if isinstance(self._storage, FileStorage):
            file_path = self._storage.get_path(self._file_key)
            count = 0
            async with aiofiles.open(file_path, "rb") as f:
                while True:
                    chunk = await f.read(65536)
                    if not chunk:
                        break
                    count += chunk.count(b"\n")
            return count - 1
        return 0

    async def has(self, row_id: str) -> bool:
        """Check if row with given ID exists."""
        async for row in self:
            # Handle both dict and object (e.g., Pydantic model)
            if isinstance(row, dict):
                if row.get("id") == row_id:
                    return True
            elif getattr(row, "id", None) == row_id:
                return True
        return False

    async def write(self, row: dict[str, Any]) -> None:
        """Write a single row to the CSV file.

        On first write, opens the file and writes the header row.
        Subsequent writes append rows to the file.

        Args
        ----
        row: Dictionary representing a single row to write.
        """
        if isinstance(self._storage, FileStorage) and self._write_file is None:
            file_path = self._storage.get_path(self._file_key)
            file_path.parent.mkdir(parents=True, exist_ok=True)
            self._write_file = Path.open(file_path, "w", encoding="utf-8", newline="")
Collaborator commented: should this be "a" for append mode?

Contributor (author) replied: Should it? This only opens once per instance, then calls .writerow(), so each instance should create the file from scratch, no?
            self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys()))
            self._writer.writeheader()
            self._header_written = True

        if self._writer is not None:
            self._writer.writerow(row)

    async def close(self) -> None:
        """Flush buffered writes and release resources.

        Closes the file handle if writing was performed.
        """
        if self._write_file is not None:
            self._write_file.close()
        self._write_file = None
        self._writer = None
        self._header_written = False
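Usage sketch (not part of the diff), illustrating the transformer dispatch above: plain callables receive the row dict positionally, while classes such as Pydantic models are called with `**row`. The `Document` model and the `FileStorage(base_dir=...)` constructor arguments are assumptions; `CSVTable`, `write`, `close`, and async iteration come from this file.

```python
# Sketch only: Document and the FileStorage constructor args are assumptions.
import asyncio

from pydantic import BaseModel

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.tables.csv_table import CSVTable


class Document(BaseModel):  # hypothetical row model
    id: str
    title: str


async def main() -> None:
    storage = FileStorage(base_dir="./output")  # assumed constructor signature

    # Write rows one at a time; the header is emitted on the first write.
    table = CSVTable(storage, "documents")
    await table.write({"id": "1", "title": "Intro"})
    await table.write({"id": "2", "title": "Methods"})
    await table.close()

    # Read back with a class transformer: each row becomes Document(**row).
    typed = CSVTable(storage, "documents", transformer=Document)
    async for doc in typed:
        print(doc.id, doc.title)


asyncio.run(main())
```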
@@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""CSV-based table provider implementation."""
@@ -9,7 +9,10 @@

import pandas as pd

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.storage import Storage
from graphrag_storage.tables.csv_table import CSVTable
from graphrag_storage.tables.table import RowTransformer
from graphrag_storage.tables.table_provider import TableProvider

logger = logging.getLogger(__name__)
@@ -32,6 +35,9 @@ def __init__(self, storage: Storage, **kwargs) -> None:
        **kwargs: Any
            Additional keyword arguments (currently unused).
        """
        if not isinstance(storage, FileStorage):
            msg = "CSVTableProvider only works with FileStorage backends for now."
            raise TypeError(msg)
        self._storage = storage

    async def read_dataframe(self, table_name: str) -> pd.DataFrame:
@@ -108,3 +114,9 @@ def list(self) -> list[str]:
file.replace(".csv", "")
for file in self._storage.find(re.compile(r"\.csv$"))
]

    def open(
        self, table_name: str, transformer: RowTransformer | None = None
    ) -> CSVTable:
        """Open table for streaming."""
        return CSVTable(self._storage, table_name, transformer=transformer)
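Usage sketch (not part of the diff): `open()` sits next to the existing DataFrame methods and streams rows without materializing a full `pd.DataFrame`. The provider's module path and the `FileStorage` constructor arguments are assumptions; `open`, `length`, and `has` come from this PR.

```python
# Sketch: import path and FileStorage(base_dir=...) are assumptions.
import asyncio

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.tables.csv_table_provider import CSVTableProvider


async def main() -> None:
    provider = CSVTableProvider(FileStorage(base_dir="./output"))  # assumed args
    table = provider.open("documents")

    print(await table.length())  # newline count minus the header row
    print(await table.has("1"))  # linear scan over rows

    async for row in table:      # dicts, since no transformer was given
        print(row["id"])


asyncio.run(main())
```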
141 changes: 141 additions & 0 deletions packages/graphrag-storage/graphrag_storage/tables/parquet_table.py
@@ -0,0 +1,141 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""A Parquet-based implementation of the Table abstraction with simulated streaming."""

from __future__ import annotations

import inspect
from io import BytesIO
from typing import TYPE_CHECKING, Any, cast

import pandas as pd

from graphrag_storage.tables.table import RowTransformer, Table

if TYPE_CHECKING:
    from collections.abc import AsyncIterator

    from graphrag_storage.storage import Storage


def _identity(row: dict[str, Any]) -> Any:
    """Return row unchanged (default transformer)."""
    return row


def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any:
"""Apply transformer to row, handling both callables and classes.

If transformer is a class (e.g., Pydantic model), calls it with **row.
Otherwise calls it with row as positional argument.
"""
if inspect.isclass(transformer):
return transformer(**row)
return transformer(row)


class ParquetTable(Table):
"""Simulated streaming interface for Parquet tables.

Parquet format doesn't support true row-by-row streaming, so this
implementation simulates streaming via:
- Read: Loads DataFrame, yields rows via iterrows()
- Write: Accumulates rows in memory, writes all at once on close()

This provides API compatibility with CSVTable while maintaining
Parquet's performance characteristics for bulk operations.
"""

def __init__(
self,
storage: Storage,
table_name: str,
transformer: RowTransformer | None = None,
):
"""Initialize with storage backend and table name.

Args:
storage: Storage instance (File, Blob, or Cosmos)
table_name: Name of the table (e.g., "documents")
transformer: Optional callable to transform each row before
yielding. Receives a dict, returns a transformed dict.
Defaults to identity (no transformation).
"""
self._storage = storage
self._table_name = table_name
self._file_key = f"{table_name}.parquet"
self._transformer = transformer or _identity
self._df: pd.DataFrame | None = None
self._write_rows: list[dict[str, Any]] = []

    def __aiter__(self) -> AsyncIterator[Any]:
        """Iterate through rows one at a time.

        Loads the entire DataFrame on first iteration, then yields rows
        one at a time with the transformer applied.

        Yields
        ------
        Any:
            Each row as dict or transformed type (e.g., Pydantic model).
        """
        return self._aiter_impl()

    async def _aiter_impl(self) -> AsyncIterator[Any]:
        """Implement async iteration over rows."""
        if self._df is None:
            if await self._storage.has(self._file_key):
                data = await self._storage.get(self._file_key, as_bytes=True)
                self._df = pd.read_parquet(BytesIO(data))
            else:
                self._df = pd.DataFrame()

        for _, row in self._df.iterrows():
            row_dict = cast("dict[str, Any]", row.to_dict())
            yield _apply_transformer(self._transformer, row_dict)

    async def length(self) -> int:
        """Return the number of rows in the table."""
        if self._df is None:
            if await self._storage.has(self._file_key):
                data = await self._storage.get(self._file_key, as_bytes=True)
                self._df = pd.read_parquet(BytesIO(data))
            else:
                return 0
        return len(self._df)

    async def has(self, row_id: str) -> bool:
        """Check if row with given ID exists."""
        async for row in self:
            if isinstance(row, dict):
                if row.get("id") == row_id:
                    return True
            elif getattr(row, "id", None) == row_id:
                return True
        return False

    async def write(self, row: dict[str, Any]) -> None:
        """Accumulate a single row for later batch write.

        Rows are stored in memory and written to Parquet format
        when close() is called.

        Args
        ----
        row: Dictionary representing a single row to write.
        """
        self._write_rows.append(row)

    async def close(self) -> None:
        """Flush accumulated rows to Parquet file and release resources.

        Converts all accumulated rows to a DataFrame and writes
        to storage as a Parquet file.
        """
        if self._write_rows:
            df = pd.DataFrame(self._write_rows)
            await self._storage.set(self._file_key, df.to_parquet())
            self._write_rows = []

        self._df = None
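Usage sketch (not part of the diff) of the accumulate-then-flush behavior described above: `write()` only appends to an in-memory list, and nothing reaches storage until `close()` converts the batch to a DataFrame and calls `storage.set(...)`. The `FileStorage(base_dir=...)` call is an assumption; the rest mirrors the methods in this file.

```python
# Sketch: rows are buffered in memory and only persisted on close().
import asyncio

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.tables.parquet_table import ParquetTable


async def main() -> None:
    storage = FileStorage(base_dir="./output")  # assumed constructor signature
    table = ParquetTable(storage, "entities")

    await table.write({"id": "e1", "name": "alpha"})
    await table.write({"id": "e2", "name": "beta"})
    # Not yet on disk: the rows live in the in-memory buffer until close().
    await table.close()  # builds a DataFrame and writes entities.parquet

    # Reading loads the whole file once, then iterates rows as dicts.
    async for row in ParquetTable(storage, "entities"):
        print(row["id"], row["name"])


asyncio.run(main())
```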
@@ -10,6 +10,8 @@
import pandas as pd

from graphrag_storage.storage import Storage
from graphrag_storage.tables.parquet_table import ParquetTable
from graphrag_storage.tables.table import RowTransformer, Table
from graphrag_storage.tables.table_provider import TableProvider

logger = logging.getLogger(__name__)
@@ -106,3 +108,23 @@ def list(self) -> list[str]:
file.replace(".parquet", "")
for file in self._storage.find(re.compile(r"\.parquet$"))
]

    def open(self, table_name: str, transformer: RowTransformer | None = None) -> Table:
        """Open a table for streaming row operations.

        Returns a ParquetTable that simulates streaming by loading the
        DataFrame and iterating rows, or accumulating writes for batch output.

        Args
        ----
        table_name: str
            The name of the table to open.
        transformer: RowTransformer | None
            Optional callable to transform each row on read.

        Returns
        -------
        Table:
            A ParquetTable instance for row-by-row access.
        """
        return ParquetTable(self._storage, table_name, transformer)
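Because both `open()` implementations return a `Table`, downstream code can be written against the shared interface. A minimal sketch (not part of the diff), assuming default dict rows (no transformer) and providers configured elsewhere:

```python
# Provider-agnostic consumption: only the Table abstraction from this PR is used.
from graphrag_storage.tables.table import Table


async def copy_ids(source: Table, destination: Table) -> int:
    """Stream ids from one table into another, regardless of backing format."""
    copied = 0
    async for row in source:          # dict rows, since no transformer is set
        await destination.write({"id": row["id"]})
        copied += 1
    await destination.close()         # flushes CSV handle or Parquet buffer
    return copied
```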