diff --git a/src/everyrow/__init__.py b/src/everyrow/__init__.py
index d0bf738e..91d90618 100644
--- a/src/everyrow/__init__.py
+++ b/src/everyrow/__init__.py
@@ -1,4 +1,5 @@
 from everyrow.api_utils import create_client
 from everyrow.session import create_session
+from everyrow.task import fetch_task_data
 
-__all__ = ["create_client", "create_session"]
+__all__ = ["create_client", "create_session", "fetch_task_data"]
diff --git a/src/everyrow/task.py b/src/everyrow/task.py
index 06c60c4d..2c122828 100644
--- a/src/everyrow/task.py
+++ b/src/everyrow/task.py
@@ -5,7 +5,7 @@
 from pandas import DataFrame
 from pydantic.main import BaseModel
 
-from everyrow.api_utils import handle_response
+from everyrow.api_utils import create_client, handle_response
 from everyrow.citations import render_citations_group, render_citations_standalone
 from everyrow.constants import EveryrowError
 from everyrow.generated.api.default import (
@@ -141,3 +141,54 @@ async def read_scalar_result[T: BaseModel](
         artifact = render_citations_standalone(artifact)
 
     return response_model(**artifact.data)
+
+
+async def fetch_task_data(
+    task_id: UUID | str,
+    client: AuthenticatedClient | None = None,
+) -> DataFrame:
+    """Fetch the result data for a completed task as a pandas DataFrame.
+
+    This is a convenience helper that retrieves the table-level group artifact
+    associated with a task and returns it as a DataFrame.
+
+    Args:
+        task_id: The UUID of the task to fetch data for (can be a string or UUID).
+        client: Optional authenticated client. If not provided, one will be created
+            using the EVERYROW_API_KEY environment variable.
+
+    Returns:
+        A pandas DataFrame containing the task result data.
+
+    Raises:
+        ValueError: If ``task_id`` is a string that is not a valid UUID.
+        EveryrowError: If the task's status is not ``COMPLETED``, or if a
+            completed task has no associated artifact.
+
+    Example:
+        >>> from everyrow import fetch_task_data
+        >>> df = await fetch_task_data("12345678-1234-1234-1234-123456789abc")
+        >>> print(df.head())
+    """
+    if isinstance(task_id, str):
+        task_id = UUID(task_id)
+
+    if client is None:
+        client = create_client()
+
+    status_response = await get_task_status(task_id, client)
+
+    if status_response.status != TaskStatus.COMPLETED:
+        message = (
+            f"Task {task_id} is not completed (status: {status_response.status.value})."
+        )
+        # Append the error detail only when one is present, so an absent
+        # error is not rendered into the message.
+        if status_response.error:
+            message += f" Error: {status_response.error}"
+        raise EveryrowError(message)
+
+    if status_response.artifact_id is None:
+        raise EveryrowError(f"Task {task_id} has no associated artifact.")
+
+    return await read_table_result(status_response.artifact_id, client)