diff --git a/.vscode/settings.json b/.vscode/settings.json index 660ca61..a899567 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,10 @@ "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": "always" - } + }, + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } \ No newline at end of file diff --git a/src/source_msgraph/async_interator.py b/src/source_msgraph/async_interator.py index 2069c2a..b2c121a 100644 --- a/src/source_msgraph/async_interator.py +++ b/src/source_msgraph/async_interator.py @@ -3,63 +3,10 @@ import asyncio from typing import AsyncGenerator, Iterator, Any -class AsyncToSyncIterator: - """ - Converts an async generator into a synchronous iterator while ensuring proper event loop handling. - """ - - def __init__(self, async_gen: AsyncGenerator[Any, None]): - """ - Initializes the iterator by consuming an async generator synchronously. - - Args: - async_gen (AsyncGenerator): The async generator yielding results. - """ - self.async_gen = async_gen - self.loop = self._get_event_loop() - self.iterator = self._to_iterator() - - def _get_event_loop(self) -> asyncio.AbstractEventLoop: - """Returns the currently running event loop or creates a new one if none exists.""" - try: - loop = asyncio.get_running_loop() - if loop.is_running(): - return None # Indicate an already running loop (handled in `_to_iterator()`) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - def _to_iterator(self) -> Iterator: - """ - Ensures that the async generator is consumed using the correct event loop. - Uses streaming (does not load all results into memory). 
- """ - if self.loop: - return iter(self.loop.run_until_complete(self._stream_results())) - else: - return iter(asyncio.run(self._stream_results())) # Safe for Jupyter, PySpark - - # Caution : prone to OOM errors - async def _stream_results(self): - # """Streams async generator results without collecting all in memory.""" - # page_count = 0 - # async for item in self.async_gen: - # if page_count >= self.max_pages: - # raise RuntimeError("Pagination limit reached, possible infinite loop detected!") - # yield item - # page_count += 1 # Track pages to prevent infinite loops - return [item async for item in self.async_gen] - - def __iter__(self) -> Iterator: - """Returns the synchronous iterator.""" - return self.iterator - - import asyncio from typing import AsyncGenerator, Iterator, Any -class AsyncToSyncIteratorV2: +class AsyncToSyncIterator: """ Converts an async generator into a synchronous iterator while ensuring proper event loop handling. """ diff --git a/src/source_msgraph/client.py b/src/source_msgraph/client.py index 3b888b6..b429073 100644 --- a/src/source_msgraph/client.py +++ b/src/source_msgraph/client.py @@ -2,9 +2,10 @@ from kiota_abstractions.base_request_configuration import RequestConfiguration from msgraph.generated.models.o_data_errors.o_data_error import ODataError from azure.identity import ClientSecretCredential -from source_msgraph.async_interator import AsyncToSyncIterator, AsyncToSyncIteratorV2 +from source_msgraph.async_interator import AsyncToSyncIterator from source_msgraph.models import ConnectorOptions from source_msgraph.utils import get_python_schema, to_json, to_pyspark_schema +from typing import Dict, Any class GraphClient: def __init__(self, options: ConnectorOptions): @@ -48,7 +49,6 @@ async def fetch_data(self): builder = self.options.resource.get_request_builder_cls()(self.graph_client.request_adapter, self.options.resource.resource_params) items = await builder.get(request_configuration=request_configuration) while True: - 
print("Page fetched....") for item in items.value: yield item if not items.odata_next_link: @@ -72,9 +72,7 @@ def iter_records(options: ConnectorOptions): async_gen = fetcher.fetch_data() return AsyncToSyncIterator(async_gen) -import json -from typing import Dict, Any -from dataclasses import asdict + def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]: """ @@ -89,7 +87,7 @@ def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]: async_gen = fetcher.fetch_data() try: - record = next(AsyncToSyncIteratorV2(async_gen), None) + record = next(AsyncToSyncIterator(async_gen), None) if not record: raise ValueError(f"No records found for resource: {options.resource.resource_name}") record = to_json(record) diff --git a/src/source_msgraph/models.py b/src/source_msgraph/models.py index 08a834f..1f5c046 100644 --- a/src/source_msgraph/models.py +++ b/src/source_msgraph/models.py @@ -4,7 +4,7 @@ import re from typing import Any, Dict from source_msgraph.constants import MSGRAPH_SDK_PACKAGE -from urllib.parse import unquote, quote +from urllib.parse import unquote from kiota_abstractions.base_request_builder import BaseRequestBuilder @dataclass @@ -133,8 +133,9 @@ def map_options_to_params(self, options: Dict[str, Any]) -> 'BaseResource': if missing_params: raise ValueError(f"Missing required resource parameters: {', '.join(missing_params)}") + # TODO: add max $top value validation. 
+ mapped_query_params = {"%24"+k: v for k, v in options.items() if k in self.query_params} - mapped_resource_params = {k.replace("-", "%2D"): v for k, v in options.items() if k in self.resource_params} invalid_params = {k: v for k, v in options.items() if k not in self.query_params and k not in self.resource_params} diff --git a/src/source_msgraph/source.py b/src/source_msgraph/source.py index d02383f..7366e3d 100644 --- a/src/source_msgraph/source.py +++ b/src/source_msgraph/source.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, Union from pyspark.sql.datasource import DataSource, DataSourceReader from pyspark.sql.types import StructType @@ -7,6 +8,7 @@ from source_msgraph.resources import get_resource # Reference https://learn.microsoft.com/en-us/azure/databricks/pyspark/datasources +logger = logging.getLogger(__name__) class MSGraphDataSource(DataSource): """ @@ -37,8 +39,9 @@ def name(cls): return "msgraph" def schema(self): - print("getting aschema") + logger.info("Schema not provided, inferring from the source.") _, schema = get_resource_schema(self.connector_options) + logger.debug(f"Inferred schema : {schema}") return schema def reader(self, schema: StructType):