Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,10 @@
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "always"
}
},
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
55 changes: 1 addition & 54 deletions src/source_msgraph/async_interator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,10 @@
import asyncio
from typing import AsyncGenerator, Iterator, Any

class AsyncToSyncIterator:
    """
    Converts an async generator into a synchronous iterator while ensuring proper event loop handling.

    The generator is drained eagerly when the instance is constructed and the
    collected items are then served from memory, so very large result sets can
    exhaust memory. Any exception raised by the generator surfaces from the
    constructor.
    """

    def __init__(self, async_gen: AsyncGenerator[Any, None]):
        """
        Initializes the iterator by consuming an async generator synchronously.

        Args:
            async_gen (AsyncGenerator): The async generator yielding results.
        """
        self.async_gen = async_gen
        self.loop = self._get_event_loop()
        self.iterator = self._to_iterator()

    def _get_event_loop(self) -> "asyncio.AbstractEventLoop | None":
        """
        Returns a private event loop, or None when a loop is already running.

        Returns:
            A freshly created loop when the calling thread has no running loop.
            It is deliberately not installed as the thread's current loop, so
            nothing global is left behind once it has been closed. None when a
            loop is already running in this thread (handled in `_to_iterator()`).
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: we may drive one ourselves.
            return asyncio.new_event_loop()
        return None

    def _to_iterator(self) -> Iterator:
        """
        Drains the async generator on a suitable event loop and returns an
        iterator over the collected items.

        Raises:
            Exception: Whatever the async generator raises while being consumed.
        """
        if self.loop is None:
            # A loop is already running in this thread (Jupyter, async callers).
            # Neither `asyncio.run()` nor `run_until_complete()` may be called
            # here, so the generator is drained on a worker thread that owns
            # its own loop. The calling thread blocks until the worker is done.
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                return iter(pool.submit(asyncio.run, self._stream_results()).result())

        try:
            return iter(self.loop.run_until_complete(self._stream_results()))
        finally:
            # Finalize a partially consumed generator, then release the loop's
            # resources so repeated use does not leak event loops.
            try:
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                self.loop.close()

    # Caution : prone to OOM errors
    async def _stream_results(self) -> list:
        """Collects every item produced by the async generator into a list."""
        return [item async for item in self.async_gen]

    def __iter__(self) -> Iterator:
        """Returns the synchronous iterator."""
        return self.iterator

    def __next__(self) -> Any:
        """Returns the next collected item, so `next(instance, default)` works."""
        return next(self.iterator)


import asyncio
from typing import AsyncGenerator, Iterator, Any

class AsyncToSyncIteratorV2:
class AsyncToSyncIterator:
"""
Converts an async generator into a synchronous iterator while ensuring proper event loop handling.
"""
Expand Down
10 changes: 4 additions & 6 deletions src/source_msgraph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from kiota_abstractions.base_request_configuration import RequestConfiguration
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
from azure.identity import ClientSecretCredential
from source_msgraph.async_interator import AsyncToSyncIterator, AsyncToSyncIteratorV2
from source_msgraph.async_interator import AsyncToSyncIterator
from source_msgraph.models import ConnectorOptions
from source_msgraph.utils import get_python_schema, to_json, to_pyspark_schema
from typing import Dict, Any

class GraphClient:
def __init__(self, options: ConnectorOptions):
Expand Down Expand Up @@ -48,7 +49,6 @@ async def fetch_data(self):
builder = self.options.resource.get_request_builder_cls()(self.graph_client.request_adapter, self.options.resource.resource_params)
items = await builder.get(request_configuration=request_configuration)
while True:
print("Page fetched....")
for item in items.value:
yield item
if not items.odata_next_link:
Expand All @@ -72,9 +72,7 @@ def iter_records(options: ConnectorOptions):
async_gen = fetcher.fetch_data()
return AsyncToSyncIterator(async_gen)

import json
from typing import Dict, Any
from dataclasses import asdict


def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]:
"""
Expand All @@ -89,7 +87,7 @@ def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]:
async_gen = fetcher.fetch_data()

try:
record = next(AsyncToSyncIteratorV2(async_gen), None)
record = next(AsyncToSyncIterator(async_gen), None)
if not record:
raise ValueError(f"No records found for resource: {options.resource.resource_name}")
record = to_json(record)
Expand Down
5 changes: 3 additions & 2 deletions src/source_msgraph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from typing import Any, Dict
from source_msgraph.constants import MSGRAPH_SDK_PACKAGE
from urllib.parse import unquote, quote
from urllib.parse import unquote
from kiota_abstractions.base_request_builder import BaseRequestBuilder

@dataclass
Expand Down Expand Up @@ -133,8 +133,9 @@ def map_options_to_params(self, options: Dict[str, Any]) -> 'BaseResource':
if missing_params:
raise ValueError(f"Missing required resource parameters: {', '.join(missing_params)}")

# TODO: add max $top value validation.

mapped_query_params = {"%24"+k: v for k, v in options.items() if k in self.query_params}

mapped_resource_params = {k.replace("-", "%2D"): v for k, v in options.items() if k in self.resource_params}

invalid_params = {k: v for k, v in options.items() if k not in self.query_params and k not in self.resource_params}
Expand Down
5 changes: 4 additions & 1 deletion src/source_msgraph/source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Any, Dict, Union
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType
Expand All @@ -7,6 +8,7 @@
from source_msgraph.resources import get_resource
# Reference https://learn.microsoft.com/en-us/azure/databricks/pyspark/datasources

logger = logging.getLogger(__name__)

class MSGraphDataSource(DataSource):
"""
Expand Down Expand Up @@ -37,8 +39,9 @@ def name(cls):
return "msgraph"

def schema(self):
print("getting aschema")
logger.info("Schema not provided, infering from the source.")
_, schema = get_resource_schema(self.connector_options)
logger.debug(f"Infered schema : {schema}")
return schema

def reader(self, schema: StructType):
Expand Down