diff --git a/cookbook/cds_discharge_summarizer_hf_chat.py b/cookbook/cds_discharge_summarizer_hf_chat.py index ad0a15c0..e1b6b528 100644 --- a/cookbook/cds_discharge_summarizer_hf_chat.py +++ b/cookbook/cds_discharge_summarizer_hf_chat.py @@ -81,12 +81,11 @@ def start_api(): # Create sandbox client and load test data client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/cds/cds-services/discharge-summarizer", + url="http://localhost:8000/cds/cds-services/discharge-summarizer", + workflow="encounter-discharge", ) # Load discharge notes from CSV client.load_free_text( - workflow="encounter-discharge", csv_path="data/discharge_notes.csv", column_name="text", ) diff --git a/cookbook/cds_discharge_summarizer_hf_trf.py b/cookbook/cds_discharge_summarizer_hf_trf.py index 2cb49baa..6d08332c 100644 --- a/cookbook/cds_discharge_summarizer_hf_trf.py +++ b/cookbook/cds_discharge_summarizer_hf_trf.py @@ -54,12 +54,11 @@ def start_api(): # Create sandbox client and load test data client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/cds/cds-services/discharge-summarizer", + url="http://localhost:8000/cds/cds-services/discharge-summarizer", + workflow="encounter-discharge", ) # Load discharge notes from CSV client.load_free_text( - workflow="encounter-discharge", csv_path="data/discharge_notes.csv", column_name="text", ) diff --git a/cookbook/notereader_clinical_coding_fhir.py b/cookbook/notereader_clinical_coding_fhir.py index 76e9e8b3..24f1f343 100644 --- a/cookbook/notereader_clinical_coding_fhir.py +++ b/cookbook/notereader_clinical_coding_fhir.py @@ -132,7 +132,9 @@ def run_server(): # Create sandbox client for testing client = SandboxClient( - api_url="http://localhost:8000", endpoint="/notereader/fhir/", protocol="soap" + url="http://localhost:8000/notereader/fhir/", + workflow="sign-note-inpatient", + protocol="soap", ) # Load clinical document from file client.load_from_path("./data/notereader_cda.xml") diff --git a/docs/cookbook/clinical_coding.md b/docs/cookbook/clinical_coding.md index 9ea0b277..cd4bf482 100644 --- a/docs/cookbook/clinical_coding.md +++ b/docs/cookbook/clinical_coding.md @@ -224,8 +224,7 @@ from healthchain.sandbox import SandboxClient # Create sandbox client for SOAP/CDA testing client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/notereader/ProcessDocument", + url="http://localhost:8000/notereader/ProcessDocument", workflow="sign-note-inpatient", protocol="soap" ) diff --git a/docs/cookbook/discharge_summarizer.md b/docs/cookbook/discharge_summarizer.md index 96a9a7af..1af15122 100644 --- a/docs/cookbook/discharge_summarizer.md +++ b/docs/cookbook/discharge_summarizer.md @@ -159,16 +159,14 @@ from healthchain.sandbox import SandboxClient # Create sandbox client for testing client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/cds/cds-services/discharge-summarizer", + url="http://localhost:8000/cds/cds-services/discharge-summarizer", workflow="encounter-discharge" ) # Load discharge notes from CSV and generate FHIR data client.load_free_text( csv_path="data/discharge_notes.csv", - column_name="text", - workflow="encounter-discharge" + column_name="text" ) ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 2a93fe66..be40bbb1 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -185,15 +185,19 @@ Test your AI applications in realistic healthcare contexts with `SandboxClient` ```python from healthchain.sandbox import SandboxClient -# Create client and load test data +# Create client with 
service URL and workflow
 client = SandboxClient(
-    api_url="http://localhost:8000",
-    endpoint="/cds/cds-services/my-service",
+    url="http://localhost:8000/cds/cds-services/my-service",
     workflow="encounter-discharge"
 )
 
 # Load from datasets or files
-client.load_from_registry("synthea", num_patients=5)
+client.load_from_registry(
+    "synthea-patient",
+    data_dir="./data/synthea",
+    resource_types=["Condition", "DocumentReference"],
+    sample_size=3
+)
 
 responses = client.send_requests()
 ```
diff --git a/docs/reference/utilities/data_generator.md b/docs/reference/utilities/data_generator.md
index a7324cf8..33958143 100644
--- a/docs/reference/utilities/data_generator.md
+++ b/docs/reference/utilities/data_generator.md
@@ -10,7 +10,7 @@ According to the [UK ONS synthetic data classification](https://www.ons.gov.uk/m
 
 ## CDS Data Generator
 
-The `.generate_prefetch()` method will return a `Prefetch` model with the `prefetch` field populated with a dictionary of FHIR resources. Each key in the dictionary corresponds to a FHIR resource type, and the value is a list of FHIR resources of that type. For more information, check out the [CDS Hooks documentation](https://cds-hooks.org/specification/current/#providing-fhir-resources-to-a-cds-service).
+The `.generate_prefetch()` method will return a dictionary of resources. Each key in the dictionary corresponds to a FHIR resource type, and the value is a single FHIR resource or a Bundle of resources of that type. For more information, check out the [CDS Hooks documentation](https://cds-hooks.org/specification/current/#providing-fhir-resources-to-a-cds-service).
 
 For each workflow, a pre-configured list of FHIR resources is randomly generated and placed in the `prefetch` field of a `CDSRequest`.
 
@@ -33,8 +33,7 @@ You can use the data generator with `SandboxClient.load_free_text()` or standalo
 
     # Create client
     client = SandboxClient(
-        api_url="http://localhost:8000",
-        endpoint="/cds/cds-services/my-service",
+        url="http://localhost:8000/cds/cds-services/my-service",
         workflow="encounter-discharge"
     )
 
@@ -42,7 +41,6 @@ You can use the data generator with `SandboxClient.load_free_text()` or standalo
     client.load_free_text(
         csv_path="./data/discharge_notes.csv",
         column_name="text",
-        workflow="encounter-discharge",
         random_seed=42
     )
 
diff --git a/docs/reference/utilities/sandbox.md b/docs/reference/utilities/sandbox.md
index b1d1a6a1..a64bcaf3 100644
--- a/docs/reference/utilities/sandbox.md
+++ b/docs/reference/utilities/sandbox.md
@@ -9,15 +9,19 @@ Test CDS Hooks workflows with synthetic data:
 
 ```python
 from healthchain.sandbox import SandboxClient
 
-# Create client
+# Create client with full service URL and workflow
 client = SandboxClient(
-    api_url="http://localhost:8000",
-    endpoint="/cds/cds-services/my-service",
+    url="http://localhost:8000/cds/cds-services/my-service",
     workflow="encounter-discharge"
 )
 
 # Load data and send requests
-client.load_from_registry("synthea", num_patients=5)
+client.load_from_registry(
+    "synthea-patient",
+    data_dir="./data/synthea",
+    resource_types=["Condition", "MedicationStatement"],
+    sample_size=5
+)
 responses = client.send_requests()
 ```
 
@@ -29,25 +33,36 @@ responses = client.send_requests()
 from healthchain.sandbox import SandboxClient
 
 client = SandboxClient(
-    api_url="http://localhost:8000",
-    endpoint="/cds/cds-services/my-service",
-    workflow="encounter-discharge",  # Optional, auto-detected if not provided
+    url="http://localhost:8000/cds/cds-services/my-service",
+    workflow="encounter-discharge",  # Required
     protocol="rest",  # 
"rest" for CDS Hooks, "soap" for CDA timeout=10.0 ) ``` +### Workflow-Protocol Compatibility + +The client validates workflow-protocol combinations at initialization: + +| Protocol | Compatible Workflows | +|----------|---------------------| +| **REST** | `patient-view`, `encounter-discharge`, `order-select`, `order-sign` | +| **SOAP** | `sign-note-inpatient`, `sign-note-outpatient` | + + ### Loading Data === "From Registry" ```python # Load from pre-configured datasets - client.load_from_registry("mimic-on-fhir", sample_size=10) - client.load_from_registry("synthea", num_patients=5) + client.load_from_registry( + "mimic-on-fhir", + data_dir="./data/mimic-fhir", + resource_types=["MimicMedication"], + sample_size=10 + ) - # See available datasets - from healthchain.sandbox import list_available_datasets - print(list_available_datasets()) + # Available datasets: "mimic-on-fhir", "synthea-patient" ``` === "From Files" @@ -55,7 +70,7 @@ client = SandboxClient( # Load single file client.load_from_path("./data/clinical_note.xml") - # Load directory + # Load directory with glob pattern client.load_from_path("./data/cda_files/", pattern="*.xml") ``` @@ -65,38 +80,180 @@ client = SandboxClient( client.load_free_text( csv_path="./data/discharge_notes.csv", column_name="text", - workflow="encounter-discharge", + generate_synthetic=True, # Include synthetic FHIR resources random_seed=42 ) ``` +## Dataset Loaders + +HealthChain provides two pre-configured dataset loaders for testing with common FHIR testing datasets. Use `load_from_registry()` to access these datasets. + +### Overview + +| Dataset | Type | Use Case | File Format | +|---------|------|----------|-------------| +| **MIMIC-on-FHIR** | Real de-identified | Testing with realistic clinical patterns | `.ndjson.gz` per resource type | +| **Synthea** | Synthetic | Quick demos, single patient testing | `.json` Bundle per patient | + + +**When to use:** + +- **MIMIC**: Test with real-world data distributions and clinical patterns from a major hospital +- **Synthea**: Quick demos without downloading large datasets; ideal for single-patient workflows + +### MIMIC-on-FHIR Loader + +Real de-identified clinical data from Beth Israel Deaconess Medical Center in FHIR R4 format. + +**Directory Structure:** + +``` +data_dir/ +└── fhir/ + ├── MimicMedication.ndjson.gz + ├── MimicCondition.ndjson.gz + ├── MimicObservation.ndjson.gz + └── ... (other resource types) +``` + +**Usage:** + +=== "Basic" + ```python + client.load_from_registry( + "mimic-on-fhir", + data_dir="./data/mimic-iv-fhir", + resource_types=["MimicMedication", "MimicCondition"] + ) + ``` + +=== "With Sampling" + ```python + # Load random sample for faster testing + client.load_from_registry( + "mimic-on-fhir", + data_dir="./data/mimic-iv-fhir", + resource_types=["MimicMedication", "MimicObservation"], + sample_size=5, # 5 resources per type + random_seed=42 # Reproducible sampling + ) + ``` + +**Available Resource Types:** + +`MimicMedication`, `MimicCondition`, `MimicObservation`, `MimicProcedure`, `MimicEncounter`, `MimicPatient`, and more. Check your dataset's `/fhir` directory for available types. + +!!! note "Setup Requirements" + The full MIMIC-on-FHIR dataset requires credentialed PhysioNet access, but you can download the [demo dataset without credentials](https://physionet.org/content/mimic-iv-fhir-demo/2.1.0/) (100 patients). + +### Synthea Loader + +Synthetic patient data generated by Synthea, containing realistic FHIR Bundles (typically 100-500 resources per patient). 
+
+**Directory Structure:**
+
+```
+data_dir/
+├── FirstName123_LastName456_uuid.json
+├── FirstName789_LastName012_uuid.json
+└── ... (one .json file per patient)
+```
+
+**Usage:**
+
+=== "First Patient (Quick Demo)"
+    ```python
+    # Automatically loads first .json file found
+    client.load_from_registry(
+        "synthea-patient",
+        data_dir="./synthea_sample_data_fhir_latest",
+        resource_types=["Condition"],  # Finds all Condition resources; loads all types if not specified
+    )
+    ```
+
+=== "By Patient ID"
+    ```python
+    client.load_from_registry(
+        "synthea-patient",
+        data_dir="./synthea_sample_data_fhir_latest",
+        patient_id="a969c177-a995-7b89-7b6d-885214dfa253",
+        resource_types=["Condition"],
+    )
+    ```
+
+=== "With Resource Filtering"
+    ```python
+    # Load specific resource types with sampling
+    client.load_from_registry(
+        "synthea-patient",
+        data_dir="./synthea_sample_data_fhir_latest",
+        resource_types=["Condition", "MedicationRequest", "Observation"],
+        sample_size=5,  # 5 resources per type
+        random_seed=42,
+    )
+    ```
+
+!!! tip "Getting Synthea Data"
+    Generate synthetic patients using [Synthea](https://github.com/synthetichealth/synthea) or [download sample data](https://synthea.mitre.org/downloads) from their releases. Each patient Bundle is self-contained with all clinical history.
+
+### Managing Requests
+
+```python
+# Preview queued requests before sending
+previews = client.preview_requests(limit=3)
+for preview in previews:
+    print(f"Request {preview['index']}: {preview['type']}")
+
+# Get full request data for inspection
+requests_dict = client.get_request_data(format="dict")
+requests_json = client.get_request_data(format="json")
+requests_raw = client.get_request_data(format="raw")
+
+# Clear queued requests to start fresh
+client.clear_requests()
+client.load_from_path("./different_data.xml")
+```
+
 ### Sending Requests
 
 ```python
 # Send all queued requests
 responses = client.send_requests()
 
-# Save results
+# Save results to disk
 client.save_results("./output/")
 
-# Get status
+# Get client status
 status = client.get_status()
 print(status)
+# {
+#   "sandbox_id": "...",
+#   "url": "http://localhost:8000/cds/...",
+#   "protocol": "rest",
+#   "workflow": "encounter-discharge",
+#   "requests_queued": 5,
+#   "responses_received": 5
+# }
 ```
 
-## Available Testing Scenarios
-
-**CDS Hooks** (REST protocol):
-
-- `workflow`: "patient-view", "encounter-discharge", "order-select", etc. 
-- Load FHIR Prefetch data
-- Test clinical decision support services
+### Using Context Manager
 
-**Clinical Documentation** (SOAP protocol):
+For automatic result saving on successful completion:
 
-- `workflow`: "sign-note-inpatient", "sign-note-outpatient"
-- Load CDA XML documents
-- Test SOAP/CDA document processing
+```python
+with SandboxClient(
+    url="http://localhost:8000/cds/cds-services/my-service",
+    workflow="encounter-discharge"
+) as client:
+    client.load_free_text(
+        csv_path="./data/notes.csv",
+        column_name="text"
+    )
+    responses = client.send_requests()
+    # Results automatically saved to ./output/ on successful exit
+```
 
 ## Complete Examples
 
@@ -106,14 +263,17 @@ print(status)
 
     # Initialize for CDS Hooks
     client = SandboxClient(
-        api_url="http://localhost:8000",
-        endpoint="/cds/cds-services/discharge-summarizer",
-        workflow="encounter-discharge",
-        protocol="rest"
+        url="http://localhost:8000/cds/cds-services/sepsis-alert",
+        workflow="patient-view"
     )
 
     # Load and send
-    client.load_from_registry("synthea", num_patients=3)
+    client.load_from_registry(
+        "mimic-on-fhir",
+        data_dir="./data/mimic-iv-fhir",
+        resource_types=["MimicConditionED", "MimicObservation"],
+        sample_size=10,
+    )
     responses = client.send_requests()
     client.save_results("./output/")
     ```
 
@@ -124,24 +284,46 @@ print(status)
 
     # Initialize for SOAP/CDA
     client = SandboxClient(
-        api_url="http://localhost:8000",
-        endpoint="/notereader/fhir/",
+        url="http://localhost:8000/notereader/ProcessDocument/",
         workflow="sign-note-inpatient",
         protocol="soap"
    )
 
-    # Load CDA files
+    # Load CDA files from directory
     client.load_from_path("./data/cda_files/", pattern="*.xml")
 
     responses = client.send_requests()
     client.save_results("./output/")
     ```
 
+=== "Free Text CSV"
+    ```python
+    from healthchain.sandbox import SandboxClient
+
+    # Initialize client
+    client = SandboxClient(
+        url="http://localhost:8000/cds/cds-services/my-service",
+        workflow="patient-view"
+    )
+
+    # Load text data
+    client.load_free_text(
+        csv_path="./data/clinical_notes.csv",
+        column_name="note_text",
+        generate_synthetic=True
+    )
+
+    # Send and save
+    responses = client.send_requests()
+    client.save_results("./output/")
+    ```
+
 ## Migration Guide
 
 !!! warning "Decorator Pattern Deprecated"
     The `@hc.sandbox` and `@hc.ehr` decorators with `ClinicalDecisionSupport` and `ClinicalDocumentation` base classes are deprecated. Use `SandboxClient` instead.
 
 **Before:**
+
 ```python
 @hc.sandbox
 class TestCDS(ClinicalDecisionSupport):
@@ -151,18 +333,17 @@
 ```
 
 **After:**
+
 ```python
 client = SandboxClient(
-    api_url="http://localhost:8000",
-    endpoint="/cds/cds-services/my-service",
+    url="http://localhost:8000/cds/cds-services/my-service",
     workflow="patient-view"
 )
-client.load_from_registry("synthea", num_patients=5)
+client.load_from_registry(
+    "synthea-patient",
+    data_dir="./data/synthea",
+    resource_types=["Condition", "Observation"],
+    sample_size=10
+)
 responses = client.send_requests()
 ```
-
-## Next Steps
-
-1. **Testing**: Use `SandboxClient` for local development and testing
-2. **Production**: Migrate to [HealthChainAPI Gateway](../gateway/gateway.md)
-3. 
**Protocols**: See [CDS Hooks](../gateway/cdshooks.md) and [SOAP/CDA](../gateway/soap_cda.md) diff --git a/healthchain/fhir/__init__.py b/healthchain/fhir/__init__.py index 96a85f40..78022bf5 100644 --- a/healthchain/fhir/__init__.py +++ b/healthchain/fhir/__init__.py @@ -11,6 +11,7 @@ create_document_reference, create_single_attachment, create_resource_from_dict, + convert_prefetch_to_fhir_objects, add_provenance_metadata, add_coding_to_codeable_concept, ) @@ -22,6 +23,7 @@ set_resources, merge_bundles, extract_resources, + count_resources, ) __all__ = [ @@ -36,6 +38,7 @@ "create_document_reference", "create_single_attachment", "create_resource_from_dict", + "convert_prefetch_to_fhir_objects", # Resource modification "add_provenance_metadata", "add_coding_to_codeable_concept", @@ -46,4 +49,5 @@ "set_resources", "merge_bundles", "extract_resources", + "count_resources", ] diff --git a/healthchain/fhir/bundle_helpers.py b/healthchain/fhir/bundle_helpers.py index 14ac08e4..3edbc7ed 100644 --- a/healthchain/fhir/bundle_helpers.py +++ b/healthchain/fhir/bundle_helpers.py @@ -256,3 +256,35 @@ def extract_resources( bundle.entry = remaining_entries return extracted + + +def count_resources(bundle: Bundle) -> dict[str, int]: + """Count resources by type in a bundle. + + Args: + bundle: The FHIR Bundle to analyze + + Returns: + Dictionary mapping resource type names to their counts. + Example: {"Condition": 2, "MedicationStatement": 1, "Patient": 1} + + Example: + >>> bundle = create_bundle() + >>> add_resource(bundle, create_condition(...)) + >>> add_resource(bundle, create_condition(...)) + >>> add_resource(bundle, create_medication_statement(...)) + >>> counts = count_resources(bundle) + >>> print(counts) + {'Condition': 2, 'MedicationStatement': 1} + """ + if not bundle or not bundle.entry: + return {} + + counts: dict[str, int] = {} + for entry in bundle.entry: + if entry.resource: + # Get the resource type from the class name + resource_type = entry.resource.__resource_type__ + counts[resource_type] = counts.get(resource_type, 0) + 1 + + return counts diff --git a/healthchain/fhir/helpers.py b/healthchain/fhir/helpers.py index 20f8c106..d89ec14b 100644 --- a/healthchain/fhir/helpers.py +++ b/healthchain/fhir/helpers.py @@ -61,6 +61,62 @@ def create_resource_from_dict( return None +def convert_prefetch_to_fhir_objects( + prefetch_dict: Dict[str, Any], +) -> Dict[str, Resource]: + """Convert a dictionary of FHIR resource dicts to FHIR Resource objects. + + Takes a prefetch dictionary where values may be either dict representations of FHIR + resources or already instantiated FHIR Resource objects, and ensures all values are + FHIR Resource objects. + + Args: + prefetch_dict: Dictionary mapping keys to FHIR resource dicts or objects + + Returns: + Dict[str, Resource]: Dictionary with same keys but all values as FHIR Resource objects + + Example: + >>> prefetch = { + ... "patient": {"resourceType": "Patient", "id": "123"}, + ... "condition": Condition(id="456", ...) + ... 
} + >>> fhir_objects = convert_prefetch_to_fhir_objects(prefetch) + >>> isinstance(fhir_objects["patient"], Patient) # True + >>> isinstance(fhir_objects["condition"], Condition) # True + """ + from fhir.resources import get_fhir_model_class + + result: Dict[str, Resource] = {} + + for key, resource_data in prefetch_dict.items(): + if isinstance(resource_data, dict): + # Convert dict to FHIR Resource object + resource_type = resource_data.get("resourceType") + if resource_type: + try: + resource_class = get_fhir_model_class(resource_type) + result[key] = resource_class(**resource_data) + except Exception as e: + logger.warning( + f"Failed to convert {resource_type} to FHIR object: {e}" + ) + result[key] = resource_data + else: + logger.warning( + f"No resourceType found for key '{key}', keeping as dict" + ) + result[key] = resource_data + elif isinstance(resource_data, Resource): + # Already a FHIR object + result[key] = resource_data + else: + logger.warning(f"Unexpected type for key '{key}': {type(resource_data)}") + result[key] = resource_data + + return result + + def create_single_codeable_concept( code: str, display: Optional[str] = None, diff --git a/healthchain/gateway/fhir/aio.py b/healthchain/gateway/fhir/aio.py index de88d6c3..44c849bb 100644 --- a/healthchain/gateway/fhir/aio.py +++ b/healthchain/gateway/fhir/aio.py @@ -216,19 +216,23 @@ async def search( client_kwargs={"params": params}, ) - # Handle pagination if requested + # Handle pagination if requested if follow_pagination: all_entries = bundle.entry or [] page_count = 1 while bundle.link: - next_link = next((link for link in bundle.link if link.relation == "next"), None) + next_link = next( + (link for link in bundle.link if link.relation == "next"), None + ) if not next_link or (max_pages and page_count >= max_pages): break # Extract the relative URL from the next link - next_url = next_link.url.split("/")[-2:] # Get resource_type/_search part - next_params = dict(pair.split("=") for pair in next_link.url.split("?")[1].split("&")) + # next_url = next_link.url.split("/")[-2:] # Get resource_type/_search part + next_params = dict( + pair.split("=") for pair in next_link.url.split("?")[1].split("&") + ) bundle = await self._execute_with_client( "search", diff --git a/healthchain/gateway/fhir/sync.py b/healthchain/gateway/fhir/sync.py index ab0ebdb6..6a2bf464 100644 --- a/healthchain/gateway/fhir/sync.py +++ b/healthchain/gateway/fhir/sync.py @@ -280,13 +280,17 @@ def search( page_count = 1 while bundle.link: - next_link = next((link for link in bundle.link if link.relation == "next"), None) + next_link = next( + (link for link in bundle.link if link.relation == "next"), None + ) if not next_link or (max_pages and page_count >= max_pages): break # Extract the relative URL from the next link - next_url = next_link.url.split("/")[-2:] # Get resource_type/_search part - next_params = dict(pair.split("=") for pair in next_link.url.split("?")[1].split("&")) + # next_url = next_link.url.split("/")[-2:] # Get resource_type/_search part + next_params = dict( + pair.split("=") for pair in next_link.url.split("?")[1].split("&") + ) bundle = self._execute_with_client( "search", diff --git a/healthchain/io/adapters/cdsfhiradapter.py b/healthchain/io/adapters/cdsfhiradapter.py index 882071a3..7d3be0e7 100644 --- a/healthchain/io/adapters/cdsfhiradapter.py +++ b/healthchain/io/adapters/cdsfhiradapter.py @@ -7,8 +7,7 @@ from healthchain.io.base import BaseAdapter from healthchain.models.requests.cdsrequest import CDSRequest from 
healthchain.models.responses.cdsresponse import CDSResponse -from healthchain.fhir import read_content_attachment -from healthchain.models.hooks.prefetch import Prefetch +from healthchain.fhir import read_content_attachment, convert_prefetch_to_fhir_objects log = logging.getLogger(__name__) @@ -69,7 +68,6 @@ def parse( Raises: ValueError: If neither prefetch nor fhirServer is provided in cds_request - ValueError: If the prefetch data is invalid or cannot be processed NotImplementedError: If fhirServer is provided (FHIR server support not implemented) """ if cds_request.prefetch is None and cds_request.fhirServer is None: @@ -83,14 +81,13 @@ def parse( # Create an empty Document object doc = Document(data="") - # Validate the prefetch data - validated_prefetch = Prefetch(prefetch=cds_request.prefetch) - - # Set the prefetch resources - doc.fhir.prefetch_resources = validated_prefetch.prefetch + # Convert prefetch dict resources to FHIR objects + doc.fhir.prefetch_resources = convert_prefetch_to_fhir_objects( + cds_request.prefetch or {} + ) # Extract text content from DocumentReference resource if provided - document_resource = validated_prefetch.prefetch.get(prefetch_document_key) + document_resource = doc.fhir.prefetch_resources.get(prefetch_document_key) if not document_resource: log.warning( diff --git a/healthchain/models/__init__.py b/healthchain/models/__init__.py index 8b8caba2..13de5201 100644 --- a/healthchain/models/__init__.py +++ b/healthchain/models/__init__.py @@ -13,7 +13,6 @@ CDSServiceInformation, CdaResponse, ) -from .hooks import Prefetch __all__ = [ "CDSRequest", @@ -31,5 +30,4 @@ "CDSResponse", "CdaRequest", "CdaResponse", - "Prefetch", ] diff --git a/healthchain/models/hooks/__init__.py b/healthchain/models/hooks/__init__.py index e19b9e2b..62a0edcd 100644 --- a/healthchain/models/hooks/__init__.py +++ b/healthchain/models/hooks/__init__.py @@ -2,7 +2,6 @@ from .encounterdischarge import EncounterDischargeContext from .orderselect import OrderSelectContext from .ordersign import OrderSignContext -from .prefetch import Prefetch __all__ = [ @@ -10,5 +9,4 @@ "EncounterDischargeContext", "OrderSelectContext", "OrderSignContext", - "Prefetch", ] diff --git a/healthchain/models/hooks/prefetch.py b/healthchain/models/hooks/prefetch.py deleted file mode 100644 index 085c1678..00000000 --- a/healthchain/models/hooks/prefetch.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Dict, Any -from pydantic import BaseModel, field_validator -from fhir.resources.resource import Resource -from fhir.resources import get_fhir_model_class - - -class Prefetch(BaseModel): - prefetch: Dict[str, Any] - - @field_validator("prefetch") - @classmethod - def validate_fhir_resources(cls, v: Dict[str, Any]) -> Dict[str, Resource]: - if not v: - return v - - validated = {} - for key, resource_dict in v.items(): - if not isinstance(resource_dict, dict): - continue - - resource_type = resource_dict.get("resourceType") - if not resource_type: - continue - - try: - # Get the appropriate FHIR resource class - resource_class = get_fhir_model_class(resource_type) - # Convert the dict to a FHIR resource - validated[key] = resource_class.model_validate(resource_dict) - except Exception as e: - raise ValueError(f"Failed to validate FHIR resource {key}: {str(e)}") - - return validated diff --git a/healthchain/models/requests/cdsrequest.py b/healthchain/models/requests/cdsrequest.py index 99e08004..cbcf5da9 100644 --- a/healthchain/models/requests/cdsrequest.py +++ b/healthchain/models/requests/cdsrequest.py 
@@ -40,9 +40,7 @@ class CDSRequest(BaseModel): fhirAuthorization: Optional[FHIRAuthorization] = ( None # TODO: note this is required if fhirserver is given ) - prefetch: Optional[Dict[str, Any]] = ( - None # fhir resource is passed either thru prefetched template of fhir server - ) + prefetch: Optional[Dict[str, Any]] = None extension: Optional[List[Dict[str, Any]]] = None def model_dump(self, **kwargs): diff --git a/healthchain/sandbox/__init__.py b/healthchain/sandbox/__init__.py index 5c1c708c..af22abc3 100644 --- a/healthchain/sandbox/__init__.py +++ b/healthchain/sandbox/__init__.py @@ -3,8 +3,8 @@ from .sandboxclient import SandboxClient from .datasets import DatasetRegistry, DatasetLoader, list_available_datasets - # Import loaders to trigger auto-registration +from . import loaders # noqa: F401 __all__ = [ "SandboxClient", diff --git a/healthchain/sandbox/base.py b/healthchain/sandbox/base.py index 5db7a742..47e1d1d2 100644 --- a/healthchain/sandbox/base.py +++ b/healthchain/sandbox/base.py @@ -2,7 +2,6 @@ from typing import Dict from enum import Enum -from healthchain.models.hooks.prefetch import Prefetch from healthchain.sandbox.workflows import Workflow @@ -36,20 +35,20 @@ class DatasetLoader(ABC): """ Abstract base class for dataset loaders. - Subclasses should implement the load() method to return Prefetch data + Subclasses should implement the load() method to return data from their specific dataset source. """ @abstractmethod - def load(self, **kwargs) -> Prefetch: + def load(self, **kwargs) -> Dict: """ - Load dataset and return as Prefetch object. + Load dataset and return as dict of FHIR resources. Args: **kwargs: Loader-specific parameters Returns: - Prefetch object containing FHIR resources + Dict containing FHIR resources Raises: FileNotFoundError: If dataset files are not found diff --git a/healthchain/sandbox/datasets.py b/healthchain/sandbox/datasets.py index 336cacf7..44ff9f4d 100644 --- a/healthchain/sandbox/datasets.py +++ b/healthchain/sandbox/datasets.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List -from healthchain.models import Prefetch from healthchain.sandbox.base import DatasetLoader @@ -43,16 +42,17 @@ def register(cls, loader: DatasetLoader) -> None: log.debug(f"Registered dataset: {name}") @classmethod - def load(cls, name: str, **kwargs) -> Prefetch: + def load(cls, name: str, data_dir: str, **kwargs) -> Dict: """ Load a dataset by name. 
Args: name: Name of the dataset to load + data_dir: Path to the directory containing dataset files **kwargs: Dataset-specific parameters Returns: - Prefetch object containing FHIR resources + Dict containing FHIR resources Raises: KeyError: If dataset name is not registered @@ -65,7 +65,7 @@ def load(cls, name: str, **kwargs) -> Prefetch: loader = cls._datasets[name] log.info(f"Loading dataset: {name}") - return loader.load(**kwargs) + return loader.load(data_dir=data_dir, **kwargs) @classmethod def list_datasets(cls) -> List[str]: diff --git a/healthchain/sandbox/generators/cdsdatagenerator.py b/healthchain/sandbox/generators/cdsdatagenerator.py index f9d9742f..2837c315 100644 --- a/healthchain/sandbox/generators/cdsdatagenerator.py +++ b/healthchain/sandbox/generators/cdsdatagenerator.py @@ -8,7 +8,6 @@ from fhir.resources.resource import Resource from healthchain.sandbox.generators.basegenerators import generator_registry -from healthchain.models import Prefetch from healthchain.fhir import create_document_reference from healthchain.sandbox.workflows import Workflow @@ -19,24 +18,21 @@ # TODO: generate test context - move from hook models class CdsDataGenerator: """ - A class to generate CDS (Clinical Decision Support) data based on specified workflows and constraints. + Generates synthetic CDS (Clinical Decision Support) data for specified workflows. - This class provides functionality to generate synthetic FHIR resources for testing CDS systems. - It uses registered data generators to create resources like Patients, Encounters, Conditions etc. - based on configured workflows. It can also incorporate free text data from CSV files. + Uses registered generators to create FHIR resources (e.g., Patient, Encounter, Condition) according to workflow configuration. + Can optionally include free text data from a CSV file as DocumentReference. Attributes: - registry (dict): A registry mapping generator names to generator classes. - mappings (dict): A mapping of workflow names to lists of required generators. - generated_data (Dict[str, Resource]): The most recently generated FHIR resources. - workflow (str): The currently active workflow. + registry (dict): Maps generator names to classes. + mappings (dict): Maps workflows to required generators. + generated_data (Dict[str, Resource]): Most recently generated resources. + workflow (str): Currently active workflow. Example: >>> generator = CdsDataGenerator() >>> generator.set_workflow("encounter_discharge") - >>> data = generator.generate_prefetch( - ... random_seed=42 - ... ) + >>> data = generator.generate_prefetch(random_seed=42) """ # TODO: Add ordering and logic so that patient/encounter IDs are passed to subsequent generators @@ -63,27 +59,25 @@ def __init__(self): def fetch_generator(self, generator_name: str) -> Callable: """ - Fetches a data generator class by its name from the registry. + Return the generator class by name from the registry. Args: - generator_name (str): The name of the data generator to fetch (e.g. "PatientGenerator", "EncounterGenerator") + generator_name (str): Name of the data generator. Returns: - Callable: The data generator class that can be used to generate FHIR resources. Returns None if generator not found. + Callable: Generator class, or None if not found. 
Example: - >>> generator = CdsDataGenerator() - >>> patient_gen = generator.fetch_generator("PatientGenerator") - >>> patient = patient_gen.generate() + >>> gen = CdsDataGenerator().fetch_generator("PatientGenerator") """ return self.registry.get(generator_name) def set_workflow(self, workflow: str) -> None: """ - Sets the current workflow to be used for data generation. + Set the current workflow name to use for data generation. - Parameters: - workflow (str): The name of the workflow to set. + Args: + workflow (str): Workflow name. """ self.workflow = workflow @@ -93,48 +87,38 @@ def generate_prefetch( free_text_path: Optional[str] = None, column_name: Optional[str] = None, random_seed: Optional[int] = None, - ) -> Prefetch: + generate_resources: bool = True, + ) -> Dict[str, Resource]: """ - Generates CDS data based on the current workflow, constraints, and optional free text data. - - This method generates FHIR resources according to the configured workflow mapping. For each - resource type in the workflow, it uses the corresponding generator to create a FHIR resource. - If free text data is provided via CSV, it will also generate a DocumentReference containing - randomly selected text from the CSV. + Generate prefetch FHIR resources and/or DocumentReference. Args: - constraints (Optional[list]): A list of constraints to apply to the data generation. - Each constraint should match the format expected by the individual generators. - free_text_path (Optional[str]): Path to a CSV file containing free text data to be - included as DocumentReferences. If provided, column_name must also be specified. - column_name (Optional[str]): The name of the column in the CSV file containing the - free text data to use. Required if free_text_path is provided. - random_seed (Optional[int]): Seed value for random number generation to ensure - reproducible results. If not provided, generation will be truly random. + constraints (Optional[list]): Constraints for resource generation. + free_text_path (Optional[str]): CSV file containing free text. + column_name (Optional[str]): CSV column for free text. + random_seed (Optional[int]): Random seed. + generate_resources (bool): If True, generate synthetic FHIR resources. Returns: - Prefetch: A dictionary mapping resource types to generated FHIR resources. - The keys are lowercase resource type names (e.g. "patient", "encounter"). - If free text is provided, includes a "document" key with a DocumentReference. + Dict[str, Resource]: Generated resources keyed by resource type (lowercase), plus "document" if a free text entry is used. Raises: - ValueError: If the configured workflow is not found in the mappings - FileNotFoundError: If the free_text_path is provided but file not found - ValueError: If free_text_path provided without column_name + ValueError: If workflow is not recognized, or column name is missing. + FileNotFoundError: If free_text_path does not exist. 
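+
+        Example (illustrative; the exact keys depend on the workflow mapping):
+            >>> generator = CdsDataGenerator()
+            >>> generator.set_workflow("encounter_discharge")
+            >>> prefetch = generator.generate_prefetch(random_seed=42)
+            >>> sorted(prefetch.keys())  # e.g. ['encounter', 'patient', ...]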
""" - prefetch = Prefetch(prefetch={}) + prefetch = {} - if self.workflow not in self.mappings.keys(): - raise ValueError(f"Workflow {self.workflow} not found in mappings") + if generate_resources: + if self.workflow not in self.mappings: + raise ValueError(f"Workflow {self.workflow} not found in mappings") - for resource in self.mappings[self.workflow]: - generator_name = resource["generator"] - generator = self.fetch_generator(generator_name) - resource = generator.generate( - constraints=constraints, random_seed=random_seed - ) - - prefetch.prefetch[resource.__resource_type__.lower()] = resource + for resource in self.mappings[self.workflow]: + generator_name = resource["generator"] + generator = self.fetch_generator(generator_name) + resource = generator.generate( + constraints=constraints, random_seed=random_seed + ) + prefetch[resource.__resource_type__.lower()] = resource parsed_free_text = ( self.free_text_parser(free_text_path, column_name) @@ -142,7 +126,7 @@ def generate_prefetch( else None ) if parsed_free_text: - prefetch.prefetch["document"] = create_document_reference( + prefetch["document"] = create_document_reference( data=random.choice(parsed_free_text), content_type="text/plain", status="current", @@ -156,26 +140,21 @@ def generate_prefetch( def free_text_parser(self, path_to_csv: str, column_name: str) -> List[str]: """ - Parse free text data from a CSV file. - - This method reads a CSV file and extracts text data from a specified column. The text data - can later be used to create DocumentReference resources. + Read a column of free text from a CSV file. Args: - path_to_csv (str): Path to the CSV file containing the free text data. - column_name (str): Name of the column in the CSV file to extract text from. + path_to_csv (str): Path to CSV file. + column_name (str): Column name to extract. Returns: - List[str]: List of text strings extracted from the specified column. + List[str]: Extracted text values. Raises: - FileNotFoundError: If the specified CSV file does not exist or is not a file. + FileNotFoundError: If CSV file does not exist. ValueError: If column_name is not provided. - Exception: If any other error occurs while reading/parsing the CSV file. """ text_data = [] - # Check that path_to_csv is a valid path with pathlib path = Path(path_to_csv) if not path.is_file(): raise FileNotFoundError( diff --git a/healthchain/sandbox/loaders/__init__.py b/healthchain/sandbox/loaders/__init__.py new file mode 100644 index 00000000..da09d135 --- /dev/null +++ b/healthchain/sandbox/loaders/__init__.py @@ -0,0 +1,15 @@ +""" +Dataset loaders package. + +Auto-registers all available dataset loaders on import. +""" + +from healthchain.sandbox.datasets import DatasetRegistry +from healthchain.sandbox.loaders.mimic import MimicOnFHIRLoader +from healthchain.sandbox.loaders.synthea import SyntheaFHIRPatientLoader + +# Register loaders +DatasetRegistry.register(MimicOnFHIRLoader()) +DatasetRegistry.register(SyntheaFHIRPatientLoader()) + +__all__ = ["MimicOnFHIRLoader", "SyntheaFHIRPatientLoader"] diff --git a/healthchain/sandbox/loaders/mimic.py b/healthchain/sandbox/loaders/mimic.py new file mode 100644 index 00000000..be79adc3 --- /dev/null +++ b/healthchain/sandbox/loaders/mimic.py @@ -0,0 +1,234 @@ +""" +MIMIC-on-FHIR dataset loader. + +Loads patient data from the MIMIC-IV-on-FHIR dataset for testing and demos. 
+""" + +import logging +import random +from pathlib import Path +from typing import Dict, List, Optional + +from fhir.resources.R4B.bundle import Bundle + +from healthchain.sandbox.datasets import DatasetLoader + +log = logging.getLogger(__name__) + + +class MimicOnFHIRLoader(DatasetLoader): + """ + Loader for MIMIC-IV-on-FHIR dataset. + + This loader supports loading FHIR resources from the MIMIC-IV dataset + that has been converted to FHIR format. It can load specific patients, + sample random patients, or filter by resource types. + + Examples: + Load via SandboxClient: + >>> client = SandboxClient(...) + >>> client.load_from_registry( + ... "mimic-on-fhir", + ... data_dir="./data/mimic-fhir", + ... resource_types=["MimicMedication", "MimicCondition"], + ... sample_size=10 + ... ) + """ + + @property + def name(self) -> str: + """Dataset name for registration.""" + return "mimic-on-fhir" + + @property + def description(self) -> str: + """Dataset description.""" + return ( + "MIMIC-IV-on-FHIR: Real de-identified clinical data from " + "Beth Israel Deaconess Medical Center in FHIR format" + ) + + def load( + self, + data_dir: str, + resource_types: Optional[List[str]] = None, + sample_size: Optional[int] = None, + random_seed: Optional[int] = None, + **kwargs, + ) -> Dict: + """ + Load MIMIC-on-FHIR data as a dict of FHIR Bundles. + + Args: + data_dir: Path to root MIMIC-on-FHIR directory (expects a /fhir subdir with .ndjson.gz files) + resource_types: Resource type names to load (e.g., ["MimicMedication"]). Required. + sample_size: Number of resources to randomly sample per type (loads all if None) + random_seed: Seed for sampling + **kwargs: Reserved for future use + + Returns: + Dict mapping resource type (e.g., "MedicationStatement") to FHIR R4B Bundle + + Raises: + FileNotFoundError: If directory or resource files not found + ValueError: If resource_types is None/empty or resources fail validation + + Example: + >>> loader = MimicOnFHIRLoader() + >>> loader.load(data_dir="./data/mimic-iv-fhir", resource_types=["MimicMedication"], sample_size=100) + """ + + data_dir = Path(data_dir) + if not data_dir.exists(): + raise FileNotFoundError( + f"MIMIC-on-FHIR data directory not found at: {data_dir}\n" + f"Please ensure the directory exists and contains a 'fhir' subdirectory with .ndjson.gz files.\n" + f"Expected structure: {data_dir}/fhir/MimicMedication.ndjson.gz, etc." + ) + + # Check if /fhir subdirectory exists + fhir_dir = data_dir / "fhir" + if not fhir_dir.exists(): + raise FileNotFoundError( + f"MIMIC-on-FHIR 'fhir' subdirectory not found at: {fhir_dir}\n" + f"The loader expects data_dir to contain a 'fhir' subdirectory with .ndjson.gz resource files.\n" + f"Expected structure:\n" + f" {data_dir}/\n" + f" └── fhir/\n" + f" ├── MimicMedication.ndjson.gz\n" + f" ├── MimicCondition.ndjson.gz\n" + f" └── ... (other resource files)" + ) + + if not resource_types: + raise ValueError( + "resource_types parameter is required. " + "Provide a list of MIMIC resource types to load (e.g., ['MimicMedication', 'MimicCondition'])." 
+ ) + + # Set random seed if provided + if random_seed is not None: + random.seed(random_seed) + + # Load resources and group by FHIR resource type + resources_by_type: Dict[str, List[Dict]] = {} + + for resource_type in resource_types: + try: + resources = self._load_resource_file( + data_dir, resource_type, sample_size + ) + + # Group by FHIR resourceType (not filename) + for resource in resources: + fhir_type = resource.get("resourceType") + if fhir_type not in resources_by_type: + resources_by_type[fhir_type] = [] + resources_by_type[fhir_type].append(resource) + + log.info( + f"Loaded {len(resources)} resources from {resource_type}.ndjson.gz" + ) + except FileNotFoundError as e: + log.error(f"Failed to load {resource_type}: {e}") + raise + except Exception as e: + log.error(f"Error loading {resource_type}: {e}") + raise ValueError(f"Failed to load {resource_type}: {e}") + + if not resources_by_type: + raise ValueError( + f"No valid resources loaded from specified resource types: {resource_types}" + ) + + bundles = {} + for fhir_type, resources in resources_by_type.items(): + bundles[fhir_type.lower()] = Bundle( + type="collection", + entry=[{"resource": resource} for resource in resources], + ) + + return bundles + + def _load_resource_file( + self, data_dir: Path, resource_type: str, sample_size: Optional[int] = None + ) -> List[Dict]: + """ + Load resources from a single MIMIC-on-FHIR .ndjson.gz file. + + Args: + data_dir: Path to MIMIC-on-FHIR data directory + resource_type: MIMIC resource type (e.g., "MimicMedication") + sample_size: Number of resources to randomly sample + + Returns: + List of resource dicts + + Raises: + FileNotFoundError: If the resource file doesn't exist + ValueError: If no valid resources found + """ + import gzip + import json + + # Construct file path - MIMIC-on-FHIR stores resources in /fhir subdirectory + fhir_dir = data_dir / "fhir" + file_path = fhir_dir / f"{resource_type}.ndjson.gz" + + if not file_path.exists(): + # Provide helpful error with available files + available_files = [] + if fhir_dir.exists(): + available_files = [f.stem for f in fhir_dir.glob("*.ndjson.gz")] + + error_msg = f"Resource file not found: {file_path}\n" + error_msg += ( + f"Expected MIMIC-on-FHIR file at {fhir_dir}/{resource_type}.ndjson.gz\n" + ) + + if available_files: + error_msg += f"\nAvailable resource files in {fhir_dir}:\n" + error_msg += "\n".join(f" - {f}" for f in available_files[:10]) + if len(available_files) > 10: + error_msg += f"\n ... and {len(available_files) - 10} more" + else: + error_msg += f"\nNo .ndjson.gz files found in {fhir_dir}" + + raise FileNotFoundError(error_msg) + + # Read all resources from file as dicts + resources = [] + line_num = 0 + + with gzip.open(file_path, "rt") as f: + for line in f: + line_num += 1 + try: + data = json.loads(line) + + if not data.get("resourceType"): + log.warning( + f"Skipping line {line_num} in {resource_type}.ndjson.gz: " + "No resourceType field found" + ) + continue + + resources.append(data) + + except json.JSONDecodeError as e: + log.warning( + f"Skipping malformed JSON at line {line_num} in {resource_type}.ndjson.gz: {e}" + ) + continue + + if not resources: + raise ValueError( + f"No valid resources found in {file_path}. " + "File may be empty or contain only invalid resources." 
+ ) + + # Apply random sampling if requested + if sample_size is not None and sample_size < len(resources): + resources = random.sample(resources, sample_size) + + return resources diff --git a/healthchain/sandbox/loaders/synthea.py b/healthchain/sandbox/loaders/synthea.py new file mode 100644 index 00000000..aded1aa9 --- /dev/null +++ b/healthchain/sandbox/loaders/synthea.py @@ -0,0 +1,374 @@ +""" +Synthea dataset loader. + +Loads synthetic patient data generated by Synthea. +""" + +import json +import logging +import random +from pathlib import Path +from typing import Dict, List, Optional + +from fhir.resources.R4B.bundle import Bundle + +from healthchain.sandbox.datasets import DatasetLoader + +log = logging.getLogger(__name__) + + +class SyntheaFHIRPatientLoader(DatasetLoader): + """ + Loader for Synthea-generated FHIR patient data. + + Synthea is an open-source synthetic patient generator that produces + realistic patient records in FHIR format. This loader loads a single + patient's Bundle (typically containing 100-500 resources), which is + sufficient for quick demos and testing. + + The loader supports multiple ways to specify which patient file to load: + - By patient_id (UUID portion of filename) + - By patient_file (exact filename) + - Default: first .json file found + + Examples: + Load by patient ID: + >>> client = SandboxClient(...) + >>> client.load_from_registry( + ... "synthea-patient", + ... data_dir="./synthea_sample_data_fhir_latest", + ... patient_id="a969c177-a995-7b89-7b6d-885214dfa253", + ... resource_types=["Condition", "MedicationRequest"], + ... sample_size=5 + ... ) + + Load by filename: + >>> client.load_from_registry( + ... "synthea-patient", + ... data_dir="./synthea_sample_data_fhir_latest", + ... patient_file="Alton320_Gutkowski940_a969c177-a995-7b89-7b6d-885214dfa253.json" + ... ) + + Load first patient (quick demo): + >>> client.load_from_registry( + ... "synthea-patient", + ... data_dir="./synthea_sample_data_fhir_latest" + ... ) + """ + + @property + def name(self) -> str: + """Dataset name for registration.""" + return "synthea-patient" + + @property + def description(self) -> str: + """Dataset description.""" + return "Synthea: Synthetic FHIR patient data generated by SyntheaTM (single patient per load)" + + def load( + self, + data_dir: str, + patient_id: Optional[str] = None, + patient_file: Optional[str] = None, + resource_types: Optional[List[str]] = None, + sample_size: Optional[int] = None, + random_seed: Optional[int] = None, + **kwargs, + ) -> Dict[str, Bundle]: + """ + Load a single Synthea FHIR patient Bundle. 
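+
+        Note: returned Bundles are "collection" Bundles keyed by lowercased
+        resource type (e.g., "condition"), matching the CDS prefetch format.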
+
+        Args:
+            data_dir: Path to Synthea FHIR output directory
+            patient_id: Patient UUID (the ID portion of the filename after the name)
+                e.g., "a969c177-a995-7b89-7b6d-885214dfa253"
+            patient_file: Exact filename to load
+                e.g., "Alton320_Gutkowski940_a969c177-a995-7b89-7b6d-885214dfa253.json"
+            resource_types: FHIR resource types to include (e.g., ["Condition", "MedicationRequest"])
+                If None, all resource types are included
+            sample_size: Number of resources to randomly sample per resource type
+                If None, all resources of each type are included
+            random_seed: Random seed for reproducible sampling
+            **kwargs: Additional parameters (reserved for future use)
+
+        Returns:
+            Dict mapping lowercased resource type to FHIR R4B Bundle in prefetch format
+            e.g., {"condition": Bundle(...), "medicationstatement": Bundle(...)}
+
+        Raises:
+            FileNotFoundError: If data directory or patient file not found
+            ValueError: If patient file is not a valid FHIR Bundle or no resources found
+
+        Example:
+            >>> loader = SyntheaFHIRPatientLoader()
+            >>> data = loader.load(
+            ...     data_dir="./synthea_sample_data_fhir_latest",
+            ...     patient_id="a969c177-a995-7b89-7b6d-885214dfa253",
+            ...     resource_types=["Condition", "MedicationRequest"],
+            ...     sample_size=3
+            ... )
+            >>> # Returns: {"condition": Bundle(...), "medicationrequest": Bundle(...)}
+        """
+        data_dir = Path(data_dir)
+        if not data_dir.exists():
+            raise FileNotFoundError(
+                f"Synthea data directory not found at: {data_dir}\n"
+                "Please provide a valid data_dir containing Synthea FHIR patient files."
+            )
+
+        # Find the patient file
+        patient_file_path = self._find_patient_file(data_dir, patient_id, patient_file)
+
+        # Load and validate the Bundle
+        bundle_dict = self._load_bundle(patient_file_path)
+
+        # Log patient information
+        self._log_patient_info(bundle_dict, patient_file_path.name)
+
+        # Set random seed if provided
+        if random_seed is not None:
+            random.seed(random_seed)
+
+        # Extract and group resources by type
+        resources_by_type = self._extract_resources(bundle_dict, resource_types)
+
+        if not resources_by_type:
+            available_types = self._get_available_resource_types(bundle_dict)
+            if resource_types:
+                raise ValueError(
+                    f"No resources found for requested types: {resource_types}\n"
+                    f"Available resource types in this patient file: {available_types}"
+                )
+            else:
+                raise ValueError(
+                    f"No valid resources found in patient file: {patient_file_path.name}"
+                )
+
+        # Apply sampling if requested
+        if sample_size is not None:
+            resources_by_type = self._sample_resources(resources_by_type, sample_size)
+
+        # Convert to Bundle objects
+        bundles = {}
+        for resource_type, resources in resources_by_type.items():
+            bundles[resource_type.lower()] = Bundle(
+                type="collection",
+                entry=[{"resource": resource} for resource in resources],
+            )
+            log.info(
+                f"Loaded {len(resources)} {resource_type} resource(s) from {patient_file_path.name}"
+            )
+
+        return bundles
+
+    def _find_patient_file(
+        self,
+        data_dir: Path,
+        patient_id: Optional[str] = None,
+        patient_file: Optional[str] = None,
+    ) -> Path:
+        """
+        Find the patient file to load based on provided parameters. 
+ + Args: + data_dir: Directory containing patient files + patient_id: Patient UUID to search for + patient_file: Exact filename + + Returns: + Path to the patient file + + Raises: + FileNotFoundError: If no matching file is found + ValueError: If multiple files match the patient_id + """ + # Option 1: Exact filename provided + if patient_file: + file_path = data_dir / patient_file + if not file_path.exists(): + raise FileNotFoundError( + f"Patient file not found: {file_path}\n" + f"Please check that the file exists in {data_dir}" + ) + return file_path + + # Option 2: Patient ID provided - search for matching file + if patient_id: + matching_files = list(data_dir.glob(f"*{patient_id}*.json")) + if not matching_files: + # List available files for helpful error message + available_files = list(data_dir.glob("*.json")) + error_msg = f"No patient file found with ID: {patient_id}\n" + if available_files: + error_msg += f"\nAvailable patient files in {data_dir}:\n" + error_msg += "\n".join(f" - {f.name}" for f in available_files[:5]) + if len(available_files) > 5: + error_msg += f"\n ... and {len(available_files) - 5} more" + else: + error_msg += f"\nNo .json files found in {data_dir}" + raise FileNotFoundError(error_msg) + + if len(matching_files) > 1: + raise ValueError( + f"Multiple patient files found with ID '{patient_id}':\n" + + "\n".join(f" - {f.name}" for f in matching_files) + + "\nPlease use patient_file parameter to specify the exact file." + ) + return matching_files[0] + + # Option 3: Default - use first .json file + json_files = list(data_dir.glob("*.json")) + if not json_files: + raise FileNotFoundError( + f"No patient files (.json) found in {data_dir}\n" + "Please ensure the directory contains Synthea FHIR patient files." + ) + + log.info( + f"No patient_id or patient_file specified, using first file: {json_files[0].name}" + ) + return json_files[0] + + def _load_bundle(self, file_path: Path) -> Dict: + """ + Load and validate a Synthea FHIR Bundle from JSON file. + + Args: + file_path: Path to the patient Bundle JSON file + + Returns: + Bundle as dict + + Raises: + ValueError: If file is not a valid FHIR Bundle + """ + try: + with open(file_path, "r") as f: + bundle_dict = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in file {file_path.name}: {e}") + + # Validate it's a FHIR Bundle + if not isinstance(bundle_dict, dict): + raise ValueError( + f"File {file_path.name} does not contain a valid JSON object" + ) + + if bundle_dict.get("resourceType") != "Bundle": + raise ValueError( + f"File {file_path.name} is not a FHIR Bundle. " + f"Found resourceType: {bundle_dict.get('resourceType')}" + ) + + if "entry" not in bundle_dict: + raise ValueError(f"Bundle in {file_path.name} has no 'entry' field") + + return bundle_dict + + def _log_patient_info(self, bundle_dict: Dict, filename: str) -> None: + """ + Log information about the loaded patient. 
+ + Args: + bundle_dict: Bundle dictionary + filename: Name of the patient file + """ + entries = bundle_dict.get("entry", []) + total_resources = len(entries) + + # Try to find Patient resource for additional info + patient_info = None + for entry in entries: + resource = entry.get("resource", {}) + if resource.get("resourceType") == "Patient": + patient_id = resource.get("id", "unknown") + name_list = resource.get("name", []) + if name_list: + name = name_list[0] + given = " ".join(name.get("given", [])) + family = name.get("family", "") + patient_info = f"{given} {family} (ID: {patient_id})" + break + + if patient_info: + log.info( + f"Loaded patient: {patient_info} from {filename} ({total_resources} resources)" + ) + else: + log.info(f"Loaded patient file {filename} ({total_resources} resources)") + + def _get_available_resource_types(self, bundle_dict: Dict) -> List[str]: + """ + Get list of available resource types in the bundle. + + Args: + bundle_dict: Bundle dictionary + + Returns: + List of unique resource types + """ + resource_types = set() + for entry in bundle_dict.get("entry", []): + resource = entry.get("resource", {}) + resource_type = resource.get("resourceType") + if resource_type: + resource_types.add(resource_type) + return sorted(resource_types) + + def _extract_resources( + self, bundle_dict: Dict, resource_types: Optional[List[str]] = None + ) -> Dict[str, List[Dict]]: + """ + Extract and group resources by type from the bundle. + + Args: + bundle_dict: Bundle dictionary + resource_types: Optional list of resource types to filter by + + Returns: + Dict mapping resource type to list of resource dicts + """ + resources_by_type: Dict[str, List[Dict]] = {} + + for entry in bundle_dict.get("entry", []): + resource = entry.get("resource", {}) + resource_type = resource.get("resourceType") + + if not resource_type: + log.warning("Skipping entry with no resourceType") + continue + + # Filter by resource types if specified + if resource_types and resource_type not in resource_types: + continue + + if resource_type not in resources_by_type: + resources_by_type[resource_type] = [] + resources_by_type[resource_type].append(resource) + + return resources_by_type + + def _sample_resources( + self, resources_by_type: Dict[str, List[Dict]], sample_size: int + ) -> Dict[str, List[Dict]]: + """ + Randomly sample resources from each type. 
+ + Args: + resources_by_type: Dict mapping resource type to list of resources + sample_size: Number of resources to sample per type + + Returns: + Dict with sampled resources + """ + sampled = {} + for resource_type, resources in resources_by_type.items(): + if len(resources) <= sample_size: + sampled[resource_type] = resources + else: + sampled[resource_type] = random.sample(resources, sample_size) + log.info( + f"Sampled {sample_size} of {len(resources)} {resource_type} resources" + ) + return sampled diff --git a/healthchain/sandbox/requestconstructors.py b/healthchain/sandbox/requestconstructors.py index 600ef8c5..5b0222f6 100644 --- a/healthchain/sandbox/requestconstructors.py +++ b/healthchain/sandbox/requestconstructors.py @@ -10,8 +10,7 @@ import pkgutil import xmltodict -from typing import Dict, Optional -from fhir.resources.resource import Resource +from typing import Any, Dict, Optional from healthchain.sandbox.base import BaseRequestConstructor, ApiProtocol from healthchain.sandbox.workflows import ( @@ -27,7 +26,6 @@ OrderSignContext, PatientViewContext, EncounterDischargeContext, - Prefetch, ) @@ -59,7 +57,7 @@ def __init__(self) -> None: @validate_workflow(UseCaseMapping.ClinicalDecisionSupport) def construct_request( self, - prefetch_data: Dict[str, Resource], + prefetch_data: Dict[str, Any], workflow: Workflow, context: Optional[Dict[str, str]] = {}, ) -> CDSRequest: @@ -67,7 +65,7 @@ def construct_request( Build a CDS Hooks request including context and prefetch data. Args: - prefetch_data (Dict[str, Resource]): Dictionary mapping prefetch template names to FHIR resource objects. + prefetch_data (Dict[str, Any]): Dict containing FHIR resource objects. workflow (Workflow): The name of the CDS Hooks workflow (e.g., Workflow.patient_view). context (Optional[Dict[str, str]]): Optional context values for initializing the workflow's context model. @@ -76,30 +74,25 @@ def construct_request( Raises: ValueError: If the workflow is not supported or lacks a defined context model. - TypeError: If prefetch_data is not an instance of Prefetch. Note: Only CDS workflows supported by UseCaseMapping.ClinicalDecisionSupport are valid. - The expected prefetch_data argument is a Prefetch object encapsulating FHIR resources. # TODO: Add FhirServer support in future. """ - log.debug(f"Constructing CDS request for {workflow.value} from {prefetch_data}") + log.debug(f"Constructing CDS request for {workflow.value}") context_model = self.context_mapping.get(workflow, None) if context_model is None: raise ValueError( f"Invalid workflow {workflow.value} or workflow model not implemented." 
) - if not isinstance(prefetch_data, Prefetch): - raise TypeError( - f"Prefetch data must be a Prefetch object, but got {type(prefetch_data)}" - ) + request = CDSRequest( hook=workflow.value, context=context_model(**context), - prefetch=prefetch_data.prefetch, + prefetch=prefetch_data, ) return request diff --git a/healthchain/sandbox/sandboxclient.py b/healthchain/sandbox/sandboxclient.py index 2587fcb0..db7a26e9 100644 --- a/healthchain/sandbox/sandboxclient.py +++ b/healthchain/sandbox/sandboxclient.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from healthchain.sandbox.base import ApiProtocol -from healthchain.models import CDSRequest, CDSResponse, Prefetch +from healthchain.models import CDSRequest, CDSResponse from healthchain.models.responses.cdaresponse import CdaResponse from healthchain.sandbox.workflows import Workflow from healthchain.sandbox.utils import ensure_directory_exists, save_data_to_directory @@ -31,7 +31,7 @@ class SandboxClient: Simplified client for testing healthcare services with various data sources. This class provides an intuitive interface for: - - Loading test datasets (MIMIC-on-FHIR, Synthea, CSV) + - Loading test datasets (MIMIC-on-FHIR, Synthea) - Generating synthetic FHIR data - Sending requests to healthcare services - Managing request/response lifecycle @@ -39,16 +39,14 @@ class SandboxClient: Examples: Load from dataset registry: >>> client = SandboxClient( - ... api_url="http://localhost:8000", - ... endpoint="/cds/cds-services/my-service" + ... url="http://localhost:8000/cds/cds-services/my-service" ... ) >>> client.load_from_registry("mimic-on-fhir", sample_size=10) >>> responses = client.send_requests() Load CDA file from path: >>> client = SandboxClient( - ... api_url="http://localhost:8000", - ... endpoint="/notereader/fhir/", + ... url="http://localhost:8000/notereader/fhir/", ... protocol="soap" ... ) >>> client.load_from_path("./data/clinical_note.xml") @@ -56,8 +54,7 @@ class SandboxClient: Generate data from free text: >>> client = SandboxClient( - ... api_url="http://localhost:8000", - ... endpoint="/cds/cds-services/discharge-summarizer" + ... url="http://localhost:8000/cds/cds-services/discharge-summarizer" ... ) >>> client.load_free_text( ... csv_path="./data/notes.csv", @@ -69,9 +66,8 @@ class SandboxClient: def __init__( self, - api_url: str, - endpoint: str, - workflow: Optional[Union[Workflow, str]] = None, + url: str, + workflow: Union[Workflow, str], protocol: Literal["rest", "soap"] = "rest", timeout: float = 10.0, ): @@ -79,37 +75,64 @@ def __init__( Initialize SandboxClient. 
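+
+        The workflow must be compatible with the chosen protocol: REST pairs
+        with CDS Hooks workflows, SOAP with Clinical Documentation workflows,
+        and an incompatible pairing raises ValueError immediately.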
Args: - api_url: Base URL of the service (e.g., "http://localhost:8000") - endpoint: Service endpoint path (e.g., "/cds/cds-services/my-service") - workflow: Optional workflow specification (auto-detected if not provided) + url: Full service URL (e.g., "http://localhost:8000/cds/cds-services/my-service") + workflow: Workflow specification (required) - determines request type and validation protocol: Communication protocol - "rest" for CDS Hooks, "soap" for CDA timeout: Request timeout in seconds Raises: - ValueError: If api_url or endpoint is invalid + ValueError: If url or workflow-protocol combination is invalid """ try: - self.api = httpx.URL(api_url) + self.url = httpx.URL(url) except Exception as e: - raise ValueError(f"Invalid API URL: {str(e)}") + raise ValueError(f"Invalid URL: {str(e)}") - self.endpoint = endpoint self.workflow = Workflow(workflow) if isinstance(workflow, str) else workflow self.protocol = ApiProtocol.soap if protocol == "soap" else ApiProtocol.rest self.timeout = timeout # Request/response management - self.request_data: List[Union[CDSRequest, Any]] = [] + self.requests: List[Union[CDSRequest, Any]] = [] self.responses: List[Dict] = [] self.sandbox_id = uuid.uuid4() - log.info( - f"Initialized SandboxClient {self.sandbox_id} for {self.api}{self.endpoint}" - ) + # Single validation point - fail fast on incompatible workflow-protocol + self._validate_workflow_protocol() + + log.info(f"Initialized SandboxClient {self.sandbox_id} for {self.url}") + + def _validate_workflow_protocol(self) -> None: + """ + Validate workflow is compatible with protocol. + + Raises: + ValueError: If workflow-protocol combination is invalid + """ + from healthchain.sandbox.workflows import UseCaseMapping + + if self.protocol == ApiProtocol.soap: + # SOAP only works with ClinicalDocumentation workflows + soap_workflows = UseCaseMapping.ClinicalDocumentation.allowed_workflows + if self.workflow.value not in soap_workflows: + raise ValueError( + f"Workflow '{self.workflow.value}' is not compatible with SOAP protocol. " + f"SOAP requires Clinical Documentation workflows: {soap_workflows}" + ) + + elif self.protocol == ApiProtocol.rest: + # REST only works with CDS workflows + rest_workflows = UseCaseMapping.ClinicalDecisionSupport.allowed_workflows + if self.workflow.value not in rest_workflows: + raise ValueError( + f"Workflow '{self.workflow.value}' is not compatible with REST protocol. " + f"REST requires CDS workflows: {rest_workflows}" + ) def load_from_registry( self, source: str, + data_dir: str, **kwargs: Any, ) -> "SandboxClient": """ @@ -120,29 +143,36 @@ def load_from_registry( Args: source: Dataset name (e.g., "mimic-on-fhir", "synthea") - **kwargs: Dataset-specific parameters (e.g., sample_size, num_patients) + data_dir: Path to the dataset directory + **kwargs: Dataset-specific parameters (e.g., resource_types, sample_size) Returns: Self for method chaining Raises: ValueError: If dataset not found in registry + FileNotFoundError: If data_dir doesn't exist Examples: - Discover available datasets: - >>> from healthchain.sandbox import list_available_datasets - >>> print(list_available_datasets()) - Load MIMIC dataset: - >>> client.load_from_registry("mimic-on-fhir", sample_size=10) + >>> client = SandboxClient( + ... url="http://localhost:8000/cds/patient-view", + ... workflow="patient-view", + ... ) + >>> client.load_from_registry( + ... "mimic-on-fhir", + ... data_dir="./data/mimic-fhir", + ... resource_types=["MimicMedication"], + ... sample_size=10 + ... 
) """ from healthchain.sandbox.datasets import DatasetRegistry log.info(f"Loading dataset from registry: {source}") try: - loaded_data = DatasetRegistry.load(source, **kwargs) + loaded_data = DatasetRegistry.load(source, data_dir=data_dir, **kwargs) self._construct_request(loaded_data) - log.info(f"Loaded {source} dataset with {len(self.request_data)} requests") + log.info(f"Loaded {source} dataset with {len(self.requests)} requests") except KeyError: raise ValueError( f"Unknown dataset: {source}. " @@ -154,38 +184,25 @@ def load_from_path( self, path: Union[str, Path], pattern: Optional[str] = None, - workflow: Optional[Union[Workflow, str]] = None, ) -> "SandboxClient": """ - Load data from file system path. + Load data from a file or directory. - Supports loading single files or directories. File type is auto-detected - from extension and protocol: - - .xml files with SOAP protocol → CDA documents - - .json files with REST protocol → Pre-formatted Prefetch data + Supports single files or all matching files in a directory (with optional glob pattern). + For .xml (SOAP protocol) loads CDA; for .json (REST protocol) loads Prefetch. Args: - path: File path or directory path - pattern: Glob pattern for filtering files in directory (e.g., "*.xml") - workflow: Optional workflow override (auto-detected from protocol if not provided) + path: File or directory path. + pattern: Glob pattern for files in directory (e.g., "*.xml"). Returns: - Self for method chaining + Self. Raises: - FileNotFoundError: If path doesn't exist - ValueError: If no matching files found or unsupported file type - - Examples: - Load single CDA file: - >>> client.load_from_path("./data/clinical_note.xml") - - Load directory of CDA files: - >>> client.load_from_path("./data/cda_files/", pattern="*.xml") - - Load with explicit workflow: - >>> client.load_from_path("./data/note.xml", workflow="sign-note-inpatient") + FileNotFoundError: If path does not exist. + ValueError: If no matching files are found or if path is not file/dir. """ + path = Path(path) if not path.exists(): raise FileNotFoundError(f"Path not found: {path}") @@ -214,12 +231,7 @@ def load_from_path( if extension == ".xml": with open(file_path, "r") as f: xml_content = f.read() - workflow_enum = ( - Workflow(workflow) - if isinstance(workflow, str) - else workflow or self.workflow or Workflow.sign_note_inpatient - ) - self._construct_request(xml_content, workflow_enum) + self._construct_request(xml_content) log.info(f"Loaded CDA document from {file_path.name}") elif extension == ".json": @@ -227,34 +239,21 @@ def load_from_path( json_data = json.load(f) try: - # Validate and load as Prefetch object - prefetch_data = Prefetch(**json_data) - - workflow_enum = ( - Workflow(workflow) - if isinstance(workflow, str) - else workflow or self.workflow - ) - if not workflow_enum: - raise ValueError( - "Workflow must be specified when loading JSON Prefetch data. " - "Provide via 'workflow' parameter or set on client initialization." - ) - self._construct_request(prefetch_data, workflow_enum) - log.info(f"Loaded Prefetch data from {file_path.name}") + self._construct_request(json_data) + log.info(f"Loaded prefetch data from {file_path.name}") except Exception as e: - log.error(f"Failed to parse {file_path} as Prefetch: {e}") + log.error(f"Failed to parse {file_path} as prefetch data: {e}") raise ValueError( - f"File {file_path} is not valid Prefetch format. " - f"Expected JSON with 'prefetch' key containing FHIR resources. 
" + f"File {file_path} is not valid prefetch format. " + f"Expected JSON with FHIR resources. " f"Error: {e}" ) else: log.warning(f"Skipping unsupported file type: {file_path}") log.info( - f"Loaded {len(self.request_data)} requests from {len(files_to_load)} file(s)" + f"Loaded {len(self.requests)} requests from {len(files_to_load)} file(s)" ) return self @@ -262,89 +261,167 @@ def load_free_text( self, csv_path: str, column_name: str, - workflow: Union[Workflow, str], + generate_synthetic: bool = True, random_seed: Optional[int] = None, **kwargs: Any, ) -> "SandboxClient": """ - Generates a CDS prefetch from free text notes. - - Reads clinical notes from a CSV file and wraps it in FHIR DocumentReferences - in a CDS prefetch field for CDS Hooks workflows. Generates additional synthetic - FHIR resources as needed based on the specified workflow. + Load free-text notes from a CSV column and create FHIR DocumentReferences for CDS prefetch. + Optionally include synthetic FHIR resources based on the current workflow. Args: - csv_path: Path to CSV file containing clinical notes - column_name: Name of the column containing the text - workflow: CDS workflow type (e.g., "encounter-discharge", "patient-view") - random_seed: Seed for reproducible data generation - **kwargs: Additional parameters for data generation + csv_path: Path to the CSV file + column_name: Name of the text column + generate_synthetic: Whether to add synthetic FHIR resources (default: True) + random_seed: Seed for reproducible results + **kwargs: Extra parameters for data generation Returns: - Self for method chaining + self Raises: - FileNotFoundError: If CSV file doesn't exist - ValueError: If workflow is invalid or column not found - - Examples: - Generate discharge summaries: - >>> client.load_free_text( - ... csv_path="./data/discharge_notes.csv", - ... column_name="text", - ... workflow="encounter-discharge", - ... random_seed=42 - ... ) + FileNotFoundError: If the CSV file does not exist + ValueError: If the column is not found """ from .generators import CdsDataGenerator - workflow_enum = Workflow(workflow) if isinstance(workflow, str) else workflow - generator = CdsDataGenerator() - generator.set_workflow(workflow_enum) + generator.set_workflow(self.workflow) prefetch_data = generator.generate_prefetch( random_seed=random_seed, free_text_path=csv_path, column_name=column_name, + generate_resources=generate_synthetic, **kwargs, ) - self._construct_request(prefetch_data, workflow_enum) - log.info( - f"Generated {len(self.request_data)} requests from free text for workflow {workflow_enum.value}" - ) + self._construct_request(prefetch_data) + + if generate_synthetic: + log.info( + f"Generated {len(self.requests)} requests from free text with synthetic resources for workflow {self.workflow.value}" + ) + else: + log.info( + f"Generated {len(self.requests)} requests from free text only (no synthetic resources)" + ) return self - def _construct_request( - self, data: Union[Prefetch, Any], workflow: Optional[Workflow] = None - ) -> None: + def _construct_request(self, data: Union[Dict[str, Any], Any]) -> None: """ Convert data to request format and add to queue. 
Args: - data: Data to convert (Prefetch for CDS, DocumentReference for CDA) - workflow: Workflow to use for request construction + data: Data to convert (Dict for CDS prefetch, string for CDA) """ - workflow = workflow or self.workflow - if self.protocol == ApiProtocol.rest: - if not workflow: - raise ValueError( - "Workflow must be specified for REST/CDS Hooks requests" - ) constructor = CdsRequestConstructor() - request = constructor.construct_request(data, workflow) + request = constructor.construct_request(data, self.workflow) elif self.protocol == ApiProtocol.soap: constructor = ClinDocRequestConstructor() - request = constructor.construct_request( - data, workflow or Workflow.sign_note_inpatient - ) + request = constructor.construct_request(data, self.workflow) else: raise ValueError(f"Unsupported protocol: {self.protocol}") - self.request_data.append(request) + self.requests.append(request) + + def clear_requests(self) -> "SandboxClient": + """ + Clear all queued requests. + + Useful when you want to start fresh without creating a new client instance. + + Returns: + Self for method chaining + """ + count = len(self.requests) + self.requests.clear() + log.info(f"Cleared {count} queued request(s)") + + return self + + def preview_requests(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: + """ + Get preview of queued requests without sending them. + + Provides a summary view of what requests are queued, useful for debugging + and verifying data was loaded correctly before sending. + + Args: + limit: Maximum number of requests to preview. If None, preview all. + + Returns: + List of request summary dictionaries containing metadata + """ + requests = self.requests[:limit] if limit else self.requests + previews = [] + + for idx, req in enumerate(requests): + preview = { + "index": idx, + "type": req.__class__.__name__, + "protocol": self.protocol.value + if hasattr(self.protocol, "value") + else str(self.protocol), + } + + # Add protocol-specific info + if self.protocol == ApiProtocol.rest and hasattr(req, "hook"): + preview["hook"] = req.hook + preview["hookInstance"] = getattr(req, "hookInstance", None) + elif self.protocol == ApiProtocol.soap: + preview["has_document"] = hasattr(req, "document") + + previews.append(preview) + + return previews + + def get_request_data( + self, format: Literal["raw", "dict", "json"] = "dict" + ) -> Union[List, str]: + """ + Get raw request data for inspection. + + Allows direct access to request data for debugging or custom processing. + + Args: + format: Return format - "raw" for list of request objects, + "dict" for list of dictionaries, "json" for JSON string + + Returns: + Request data in specified format + + Raises: + ValueError: If format is not one of "raw", "dict", or "json" + + Examples: + >>> client.load_from_path("data.xml") + >>> # Get as dictionaries + >>> dicts = client.get_request_data("dict") + >>> # Get as JSON string + >>> json_str = client.get_request_data("json") + >>> print(json_str) + """ + if format == "raw": + return self.requests + elif format == "dict": + result = [] + for req in self.requests: + if hasattr(req, "model_dump"): + result.append(req.model_dump(exclude_none=True)) + elif hasattr(req, "model_dump_xml"): + result.append({"document": req.model_dump_xml()}) + else: + result.append(req) + return result + elif format == "json": + return json.dumps(self.get_request_data("dict"), indent=2) + else: + raise ValueError( + f"Invalid format '{format}'. 
Must be 'raw', 'dict', or 'json'"
+            )

     def send_requests(self) -> List[Dict]:
         """
@@ -352,28 +429,24 @@ def send_requests(self) -> List[Dict]:

         Returns:
             List of response dictionaries
-
-        Raises:
-            RuntimeError: If no requests are queued
         """
-        if not self.request_data:
+        if not self.requests:
             raise RuntimeError(
                 "No requests to send. Load data first using load_from_registry(), load_from_path(), or load_free_text()"
             )

-        url = self.api.join(self.endpoint)
-        log.info(f"Sending {len(self.request_data)} requests to {url}")
+        log.info(f"Sending {len(self.requests)} requests to {self.url}")

         with httpx.Client(follow_redirects=True) as client:
             responses: List[Dict] = []
             timeout = httpx.Timeout(self.timeout, read=None)

-            for request in self.request_data:
+            for request in self.requests:
                 try:
                     if self.protocol == ApiProtocol.soap:
                         headers = {"Content-Type": "text/xml; charset=utf-8"}
                         response = client.post(
-                            url=str(url),
+                            url=str(self.url),
                             data=request.document,
                             headers=headers,
                             timeout=timeout,
@@ -383,19 +456,26 @@
                     else:
                         # REST/CDS Hooks
-                        log.debug(f"Making POST request to: {url}")
+                        log.debug(f"Making POST request to: {self.url}")
                         response = client.post(
-                            url=str(url),
+                            url=str(self.url),
                             json=request.model_dump(exclude_none=True),
                             timeout=timeout,
                         )
                         response.raise_for_status()
-                        response_data = response.json()
-                        cds_response = CDSResponse(**response_data)
-                        responses.append(cds_response.model_dump(exclude_none=True))
-                    except Exception:
-                        # Fallback to raw response if parsing fails
-                        responses.append(response_data)
+                        try:
+                            response_data = response.json()
+                            cds_response = CDSResponse(**response_data)
+                            responses.append(cds_response.model_dump(exclude_none=True))
+                        except json.JSONDecodeError:
+                            log.error(
+                                f"Invalid JSON response from {self.url}. "
+                                f"Response preview: {response.text[:200]}"
+                            )
+                            responses.append({})
+                        except Exception:
+                            # Fallback to raw response if CDSResponse parsing fails
+                            responses.append(response_data)
                 except httpx.HTTPStatusError as exc:
@@ -428,9 +508,6 @@ def save_results(self, directory: Union[str, Path] = "./output/") -> None:

         Args:
             directory: Directory to save data to (default: "./output/")
-
-        Raises:
-            RuntimeError: If no responses are available to save
         """
         if not self.responses:
             raise RuntimeError(
@@ -445,10 +522,10 @@ def save_results(self, directory: Union[str, Path] = "./output/") -> None:

         # Save requests
         if self.protocol == ApiProtocol.soap:
-            request_data = [request.model_dump_xml() for request in self.request_data]
+            request_data = [request.model_dump_xml() for request in self.requests]
         else:
             request_data = [
-                request.model_dump(exclude_none=True) for request in self.request_data
+                request.model_dump(exclude_none=True) for request in self.requests
             ]

         save_data_to_directory(
@@ -480,20 +557,36 @@ def get_status(self) -> Dict[str, Any]:
         """
         return {
             "sandbox_id": str(self.sandbox_id),
-            "api_url": str(self.api),
-            "endpoint": self.endpoint,
+            "url": str(self.url),
             "protocol": self.protocol.value
             if hasattr(self.protocol, "value")
             else str(self.protocol),
-            "workflow": self.workflow.value if self.workflow else None,
-            "requests_queued": len(self.request_data),
+            "workflow": self.workflow.value,
+            "requests_queued": len(self.requests),
             "responses_received": len(self.responses),
         }

+    def __enter__(self) -> "SandboxClient":
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        """
+        Context manager exit - auto-save results if responses exist.
+
+        Only saves if no exception occurred and responses were generated.
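+
+        Illustrative usage (URL and data path are placeholders):
+
+            >>> with SandboxClient(
+            ...     url="http://localhost:8000/cds/cds-services/my-service",
+            ...     workflow="patient-view",
+            ... ) as client:
+            ...     client.load_from_path("./data/prefetch.json")
+            ...     client.send_requests()
+
+        On exit, responses are saved to "./output/" via save_results().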
+ """ + if self.responses and exc_type is None: + try: + self.save_results() + log.info("Auto-saved results on context exit") + except Exception as e: + log.warning(f"Failed to auto-save results: {e}") + def __repr__(self) -> str: """String representation of SandboxClient.""" return ( - f"SandboxClient(api_url='{self.api}', endpoint='{self.endpoint}', " + f"SandboxClient(url='{self.url}', " f"protocol='{self.protocol.value if hasattr(self.protocol, 'value') else self.protocol}', " - f"requests={len(self.request_data)})" + f"requests={len(self.requests)})" ) diff --git a/scripts/healthchainapi_e2e_demo.py b/scripts/healthchainapi_e2e_demo.py index 0b6870d0..27483168 100644 --- a/scripts/healthchainapi_e2e_demo.py +++ b/scripts/healthchainapi_e2e_demo.py @@ -57,6 +57,11 @@ from healthchain.sandbox import SandboxClient from healthchain.fhir import create_document_reference +from dotenv import load_dotenv + +load_dotenv() + + # Configuration CONFIG = { "server": { @@ -409,8 +414,7 @@ def create_sandboxes(): # NoteReader Sandbox notereader_client = SandboxClient( - api_url=base_url, - endpoint="/notereader/fhir/", + url=base_url + "/notereader/fhir/", workflow=CONFIG["workflows"]["notereader"], protocol="soap", ) @@ -420,8 +424,7 @@ def create_sandboxes(): # CDS Hooks Sandbox cds_client = SandboxClient( - api_url=base_url, - endpoint="/cds/cds-services/discharge-summary", + url=base_url + "/cds/cds-services/discharge-summary", workflow=CONFIG["workflows"]["cds"], protocol="rest", ) @@ -430,7 +433,6 @@ def create_sandboxes(): cds_client.load_free_text( csv_path=CONFIG["data"]["discharge_notes_path"], column_name="text", - workflow=CONFIG["workflows"]["cds"], ) print_success("Sandbox environments created") diff --git a/tests/conftest.py b/tests/conftest.py index be60e5a2..ee9e4b47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import tempfile -from healthchain.models.hooks.prefetch import Prefetch from healthchain.models.requests.cdarequest import CdaRequest from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.models.responses.cdaresponse import CdaResponse @@ -353,26 +352,24 @@ def test_empty_document(): @pytest.fixture def valid_prefetch_data(): - """Provides a `Prefetch` model object for CDS Hooks testing. + """Provides a dict of FHIR resources for CDS Hooks testing. Contains a single prefetch key "document" with a DocumentReference resource. Use this for testing services that consume CDS Hooks prefetch data. Example: def test_prefetch_handler(valid_prefetch_data): - request = CDSRequest(prefetch=valid_prefetch_data.prefetch) + request = CDSRequest(prefetch=valid_prefetch_data) # ... test logic Returns: - healthchain.models.hooks.prefetch.Prefetch: A Pydantic model representing valid prefetch data. + dict: A dictionary containing FHIR resources for prefetch data. 
""" - return Prefetch( - prefetch={ - "document": create_document_reference( - content_type="text/plain", data="Test document content" - ) - } - ) + return { + "document": create_document_reference( + content_type="text/plain", data="Test document content" + ) + } # ################################################# diff --git a/tests/fhir/test_bundle_resource_counting.py b/tests/fhir/test_bundle_resource_counting.py new file mode 100644 index 00000000..821a0d32 --- /dev/null +++ b/tests/fhir/test_bundle_resource_counting.py @@ -0,0 +1,57 @@ +"""Tests for bundle resource counting functionality.""" + +from healthchain.fhir import ( + create_bundle, + add_resource, + create_condition, + create_medication_statement, + create_allergy_intolerance, + count_resources, +) + + +def test_count_resources_with_empty_bundle(empty_bundle): + """count_resources returns empty dict for empty bundle.""" + counts = count_resources(empty_bundle) + assert counts == {} + + +def test_count_resources_with_single_resource_type(empty_bundle): + """count_resources counts single resource type correctly.""" + add_resource(empty_bundle, create_condition("Patient/1", "123", "Test")) + add_resource(empty_bundle, create_condition("Patient/1", "456", "Test 2")) + + counts = count_resources(empty_bundle) + assert counts == {"Condition": 2} + + +def test_count_resources_with_mixed_resource_types(empty_bundle): + """count_resources counts multiple resource types correctly.""" + add_resource(empty_bundle, create_condition("Patient/1", "123", "Test")) + add_resource(empty_bundle, create_condition("Patient/1", "456", "Test 2")) + add_resource(empty_bundle, create_medication_statement("Patient/1", "789", "Med")) + add_resource( + empty_bundle, create_allergy_intolerance("Patient/1", "999", "Allergy") + ) + + counts = count_resources(empty_bundle) + assert counts == { + "Condition": 2, + "MedicationStatement": 1, + "AllergyIntolerance": 1, + } + + +def test_count_resources_with_none_bundle(): + """count_resources handles None bundle gracefully.""" + counts = count_resources(None) + assert counts == {} + + +def test_count_resources_with_bundle_no_entry(): + """count_resources handles bundle with None entry.""" + bundle = create_bundle() + bundle.entry = None + + counts = count_resources(bundle) + assert counts == {} diff --git a/tests/gateway/test_fhir_gateway.py b/tests/gateway/test_fhir_gateway.py index 14cf289e..fea4207d 100644 --- a/tests/gateway/test_fhir_gateway.py +++ b/tests/gateway/test_fhir_gateway.py @@ -285,32 +285,26 @@ def test_search_with_empty_bundle(): ) assert result.entry is None + def test_search_with_pagination(fhir_gateway): """Gateway.search fetches all pages when pagination is enabled.""" # Create mock bundles for pagination page1 = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="1"))], - link=[{"relation": "next", "url": "Patient?page=2"}] + link=[{"relation": "next", "url": "Patient?page=2"}], ) page2 = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="2"))], - link=[{"relation": "next", "url": "Patient?page=3"}] - ) - page3 = Bundle( - type="searchset", - entry=[BundleEntry(resource=Patient(id="3"))] + link=[{"relation": "next", "url": "Patient?page=3"}], ) + page3 = Bundle(type="searchset", entry=[BundleEntry(resource=Patient(id="3"))]) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2, page3] ) as mock_execute: - result = fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True - ) + result = 
fhir_gateway.search(Patient, {"name": "Smith"}, follow_pagination=True) assert mock_execute.call_count == 3 assert result.entry is not None @@ -324,22 +318,19 @@ def test_search_with_max_pages(fhir_gateway): page1 = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="1"))], - link=[{"relation": "next", "url": "Patient?page=2"}] + link=[{"relation": "next", "url": "Patient?page=2"}], ) page2 = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="2"))], - link=[{"relation": "next", "url": "Patient?page=3"}] + link=[{"relation": "next", "url": "Patient?page=3"}], ) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2] ) as mock_execute: result = fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True, - max_pages=2 + Patient, {"name": "Smith"}, follow_pagination=True, max_pages=2 ) assert mock_execute.call_count == 2 @@ -354,17 +345,13 @@ def test_search_with_pagination_empty_next_link(fhir_gateway): bundle = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="1"))], - link=[{"relation": "self", "url": "Patient?name=Smith"}] + link=[{"relation": "self", "url": "Patient?name=Smith"}], ) with patch.object( fhir_gateway, "_execute_with_client", return_value=bundle ) as mock_execute: - result = fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True - ) + result = fhir_gateway.search(Patient, {"name": "Smith"}, follow_pagination=True) mock_execute.assert_called_once() assert result.entry is not None @@ -377,12 +364,9 @@ def test_search_with_pagination_and_provenance(fhir_gateway): page1 = Bundle( type="searchset", entry=[BundleEntry(resource=Patient(id="1"))], - link=[{"relation": "next", "url": "Patient?page=2"}] - ) - page2 = Bundle( - type="searchset", - entry=[BundleEntry(resource=Patient(id="2"))] + link=[{"relation": "next", "url": "Patient?page=2"}], ) + page2 = Bundle(type="searchset", entry=[BundleEntry(resource=Patient(id="2"))]) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2] @@ -393,7 +377,7 @@ def test_search_with_pagination_and_provenance(fhir_gateway): source="test_source", follow_pagination=True, add_provenance=True, - provenance_tag="aggregated" + provenance_tag="aggregated", ) assert mock_execute.call_count == 2 @@ -404,4 +388,4 @@ def test_search_with_pagination_and_provenance(fhir_gateway): for entry in result.entry: assert entry.resource.meta is not None assert entry.resource.meta.source == "urn:healthchain:source:test_source" - assert entry.resource.meta.tag[0].code == "aggregated" \ No newline at end of file + assert entry.resource.meta.tag[0].code == "aggregated" diff --git a/tests/gateway/test_fhir_gateway_async.py b/tests/gateway/test_fhir_gateway_async.py index f0e948f9..a64235fc 100644 --- a/tests/gateway/test_fhir_gateway_async.py +++ b/tests/gateway/test_fhir_gateway_async.py @@ -126,7 +126,6 @@ async def test_search_operation_with_parameters(fhir_gateway): assert result == mock_bundle - @pytest.mark.asyncio async def test_search_with_pagination(fhir_gateway): """AsyncFHIRGateway.search fetches all pages when pagination is enabled.""" @@ -134,25 +133,20 @@ async def test_search_with_pagination(fhir_gateway): page1 = Bundle( type="searchset", entry=[{"resource": Patient(id="1")}], - link=[{"relation": "next", "url": "Patient?page=2"}] + link=[{"relation": "next", "url": "Patient?page=2"}], ) page2 = Bundle( type="searchset", entry=[{"resource": Patient(id="2")}], - link=[{"relation": "next", "url": "Patient?page=3"}] - 
) - page3 = Bundle( - type="searchset", - entry=[{"resource": Patient(id="3")}] + link=[{"relation": "next", "url": "Patient?page=3"}], ) + page3 = Bundle(type="searchset", entry=[{"resource": Patient(id="3")}]) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2, page3] ) as mock_execute: result = await fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True + Patient, {"name": "Smith"}, follow_pagination=True ) assert mock_execute.call_count == 3 @@ -167,22 +161,19 @@ async def test_search_with_max_pages(fhir_gateway): page1 = Bundle( type="searchset", entry=[{"resource": Patient(id="1")}], - link=[{"relation": "next", "url": "Patient?page=2"}] + link=[{"relation": "next", "url": "Patient?page=2"}], ) page2 = Bundle( type="searchset", entry=[{"resource": Patient(id="2")}], - link=[{"relation": "next", "url": "Patient?page=3"}] + link=[{"relation": "next", "url": "Patient?page=3"}], ) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2] ) as mock_execute: result = await fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True, - max_pages=2 + Patient, {"name": "Smith"}, follow_pagination=True, max_pages=2 ) assert mock_execute.call_count == 2 @@ -197,16 +188,14 @@ async def test_search_with_pagination_empty_next_link(fhir_gateway): bundle = Bundle( type="searchset", entry=[{"resource": Patient(id="1")}], - link=[{"relation": "self", "url": "Patient?name=Smith"}] + link=[{"relation": "self", "url": "Patient?name=Smith"}], ) with patch.object( fhir_gateway, "_execute_with_client", return_value=bundle ) as mock_execute: result = await fhir_gateway.search( - Patient, - {"name": "Smith"}, - follow_pagination=True + Patient, {"name": "Smith"}, follow_pagination=True ) mock_execute.assert_called_once() @@ -221,12 +210,9 @@ async def test_search_with_pagination_and_provenance(fhir_gateway): page1 = Bundle( type="searchset", entry=[{"resource": Patient(id="1")}], - link=[{"relation": "next", "url": "Patient?page=2"}] - ) - page2 = Bundle( - type="searchset", - entry=[{"resource": Patient(id="2")}] + link=[{"relation": "next", "url": "Patient?page=2"}], ) + page2 = Bundle(type="searchset", entry=[{"resource": Patient(id="2")}]) with patch.object( fhir_gateway, "_execute_with_client", side_effect=[page1, page2] @@ -237,7 +223,7 @@ async def test_search_with_pagination_and_provenance(fhir_gateway): source="test_source", follow_pagination=True, add_provenance=True, - provenance_tag="aggregated" + provenance_tag="aggregated", ) assert mock_execute.call_count == 2 @@ -250,6 +236,7 @@ async def test_search_with_pagination_and_provenance(fhir_gateway): assert entry.resource.meta.source == "urn:healthchain:source:test_source" assert entry.resource.meta.tag[0].code == "aggregated" + @pytest.mark.asyncio async def test_modify_context_for_existing_resource(fhir_gateway, test_patient): """Modify context manager fetches, yields, and updates existing resources.""" diff --git a/tests/sandbox/generators/test_cds_data_generator.py b/tests/sandbox/generators/test_cds_data_generator.py index eecea05a..1597ff67 100644 --- a/tests/sandbox/generators/test_cds_data_generator.py +++ b/tests/sandbox/generators/test_cds_data_generator.py @@ -16,14 +16,14 @@ def test_generator_orchestrator_encounter_discharge(): generator.set_workflow(workflow=workflow) generator.generate_prefetch() - assert len(generator.generated_data.prefetch) == 4 - assert generator.generated_data.prefetch["encounter"] is not None - assert 
isinstance(generator.generated_data.prefetch["encounter"], Encounter) - assert generator.generated_data.prefetch["condition"] is not None - assert isinstance(generator.generated_data.prefetch["condition"], Condition) - assert generator.generated_data.prefetch["procedure"] is not None - assert isinstance(generator.generated_data.prefetch["procedure"], Procedure) - assert generator.generated_data.prefetch["medicationrequest"] is not None + assert len(generator.generated_data) == 4 + assert generator.generated_data["encounter"] is not None + assert isinstance(generator.generated_data["encounter"], Encounter) + assert generator.generated_data["condition"] is not None + assert isinstance(generator.generated_data["condition"], Condition) + assert generator.generated_data["procedure"] is not None + assert isinstance(generator.generated_data["procedure"], Procedure) + assert generator.generated_data["medicationrequest"] is not None def test_generator_orchestrator_patient_view(): @@ -33,13 +33,13 @@ def test_generator_orchestrator_patient_view(): generator.set_workflow(workflow=workflow) generator.generate_prefetch() - assert len(generator.generated_data.prefetch) == 3 - assert generator.generated_data.prefetch["patient"] is not None - assert isinstance(generator.generated_data.prefetch["patient"], Patient) - assert generator.generated_data.prefetch["encounter"] is not None - assert isinstance(generator.generated_data.prefetch["encounter"], Encounter) - assert generator.generated_data.prefetch["condition"] is not None - assert isinstance(generator.generated_data.prefetch["condition"], Condition) + assert len(generator.generated_data) == 3 + assert generator.generated_data["patient"] is not None + assert isinstance(generator.generated_data["patient"], Patient) + assert generator.generated_data["encounter"] is not None + assert isinstance(generator.generated_data["encounter"], Encounter) + assert generator.generated_data["condition"] is not None + assert isinstance(generator.generated_data["condition"], Condition) @pytest.mark.skip() @@ -52,8 +52,8 @@ def test_generator_with_json(): free_text_path="use_cases/my_encounter_data.csv", column_name="free_text" ) - assert len(generator.generated_data.prefetch) == 4 - assert generator.generated_data.prefetch["patient"] is not None - assert generator.generated_data.prefetch["encounter"] is not None - assert generator.generated_data.prefetch["condition"] is not None - assert generator.generated_data.prefetch["document"] is not None + assert len(generator.generated_data) == 4 + assert generator.generated_data["patient"] is not None + assert generator.generated_data["encounter"] is not None + assert generator.generated_data["condition"] is not None + assert generator.generated_data["document"] is not None diff --git a/tests/sandbox/test_cds_sandbox.py b/tests/sandbox/test_cds_sandbox.py index 28450fb8..e997b120 100644 --- a/tests/sandbox/test_cds_sandbox.py +++ b/tests/sandbox/test_cds_sandbox.py @@ -5,7 +5,6 @@ from healthchain.gateway.api import HealthChainAPI from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.models.responses.cdsresponse import CDSResponse, Card -from healthchain.models.hooks.prefetch import Prefetch from healthchain.fhir import create_bundle, create_condition @@ -29,20 +28,19 @@ async def handle_patient_view(request: CDSRequest) -> CDSResponse: # Create SandboxClient client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/cds/cds-services/test-patient-view", + 
url="http://localhost:8000/cds/cds-services/test-patient-view", workflow="patient-view", protocol="rest", ) # Load test data test_bundle = create_bundle() - prefetch_data = Prefetch(prefetch={"patient": test_bundle}) - client._construct_request(prefetch_data, client.workflow) + prefetch_data = {"patient": test_bundle} + client._construct_request(prefetch_data) # Verify request was constructed - assert len(client.request_data) == 1 - assert client.request_data[0].hook == "patient-view" + assert len(client.requests) == 1 + assert client.requests[0].hook == "patient-view" # Mock HTTP response with patch("httpx.Client") as mock_client_class: @@ -74,8 +72,7 @@ def test_cdshooks_workflows(): """Test CDSHooks sandbox with patient-view workflow""" # Create SandboxClient client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/cds/cds-services/patient-view", + url="http://localhost:8000/cds/cds-services/patient-view", workflow="patient-view", protocol="rest", ) @@ -88,11 +85,11 @@ def test_cdshooks_workflows(): patient_bundle.entry = [{"resource": condition}] # Load data into client - prefetch_data = Prefetch(prefetch={"patient": patient_bundle}) - client._construct_request(prefetch_data, client.workflow) + prefetch_data = {"patient": patient_bundle} + client._construct_request(prefetch_data) # Verify request was constructed - assert len(client.request_data) == 1 + assert len(client.requests) == 1 # Mock HTTP response with patch("httpx.Client") as mock_client_class: diff --git a/tests/sandbox/test_clindoc_sandbox.py b/tests/sandbox/test_clindoc_sandbox.py index 9d2b0fab..e1ae80d0 100644 --- a/tests/sandbox/test_clindoc_sandbox.py +++ b/tests/sandbox/test_clindoc_sandbox.py @@ -24,18 +24,17 @@ def process_document(cda_request: CdaRequest) -> CdaResponse: # Create SandboxClient for SOAP/CDA client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/notereader/fhir/", + url="http://localhost:8000/notereader/fhir/", workflow="sign-note-inpatient", protocol="soap", ) # Load test document test_document = "document" - client._construct_request(test_document, client.workflow) + client._construct_request(test_document) # Verify request was constructed - assert len(client.request_data) == 1 + assert len(client.requests) == 1 # Mock HTTP response with proper SOAP envelope structure with patch("httpx.Client") as mock_client_class: @@ -74,18 +73,17 @@ def test_notereader_sandbox_workflow_execution(): """Test executing a NoteReader workflow with SandboxClient""" # Create SandboxClient client = SandboxClient( - api_url="http://localhost:8000", - endpoint="/notereader/fhir/", + url="http://localhost:8000/notereader/fhir/", workflow="sign-note-inpatient", protocol="soap", ) # Load clinical document clinical_document = "Test content" - client._construct_request(clinical_document, client.workflow) + client._construct_request(clinical_document) # Verify request was constructed - assert len(client.request_data) == 1 + assert len(client.requests) == 1 # Mock HTTP response with proper SOAP envelope structure with patch("httpx.Client") as mock_client_class: diff --git a/tests/sandbox/test_mimic_loader.py b/tests/sandbox/test_mimic_loader.py new file mode 100644 index 00000000..0c2614e2 --- /dev/null +++ b/tests/sandbox/test_mimic_loader.py @@ -0,0 +1,318 @@ +"""Tests for MIMIC-on-FHIR dataset loader.""" + +import gzip +import json +import tempfile +from pathlib import Path + +import pytest + +from healthchain.sandbox.loaders.mimic import MimicOnFHIRLoader + + +@pytest.fixture +def 
temp_mimic_data_dir(): + """Create temporary MIMIC-on-FHIR data directory structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + data_path = Path(tmpdir) + fhir_dir = data_path / "fhir" + fhir_dir.mkdir() + yield data_path + + +@pytest.fixture +def mock_medication_resources(): + """Sample MedicationStatement resources for testing.""" + return [ + { + "resourceType": "MedicationStatement", + "id": "med-1", + "status": "recorded", + "medication": { + "concept": { + "coding": [ + { + "system": "http://www.nlm.nih.gov/research/umls/rxnorm", + "code": "313782", + } + ] + } + }, + "subject": {"reference": "Patient/123"}, + }, + { + "resourceType": "MedicationStatement", + "id": "med-2", + "status": "recorded", + "medication": { + "concept": { + "coding": [ + { + "system": "http://www.nlm.nih.gov/research/umls/rxnorm", + "code": "197361", + } + ] + } + }, + "subject": {"reference": "Patient/456"}, + }, + ] + + +@pytest.fixture +def mock_condition_resources(): + """Sample Condition resources for testing.""" + return [ + { + "resourceType": "Condition", + "id": "cond-1", + "clinicalStatus": { + "coding": [ + { + "system": "http://terminology.hl7.org/CodeSystem/condition-clinical", + "code": "active", + } + ] + }, + "code": { + "coding": [{"system": "http://snomed.info/sct", "code": "44054006"}] + }, + "subject": {"reference": "Patient/123"}, + } + ] + + +def create_ndjson_gz_file(file_path: Path, resources: list): + """Helper to create gzipped NDJSON file.""" + with gzip.open(file_path, "wt") as f: + for resource in resources: + f.write(json.dumps(resource) + "\n") + + +def test_mimic_loader_requires_resource_types(temp_mimic_data_dir): + """MimicOnFHIRLoader raises ValueError when resource_types is None.""" + loader = MimicOnFHIRLoader() + + with pytest.raises(ValueError, match="resource_types parameter is required"): + loader.load(data_dir=str(temp_mimic_data_dir)) + + +def test_mimic_loader_raises_error_for_missing_data_path(): + """MimicOnFHIRLoader raises FileNotFoundError when data path doesn't exist.""" + loader = MimicOnFHIRLoader() + + with pytest.raises(FileNotFoundError): + loader.load(data_dir="/nonexistent/path", resource_types=["MimicMedication"]) + + +def test_mimic_loader_raises_error_for_missing_resource_file(temp_mimic_data_dir): + """MimicOnFHIRLoader raises FileNotFoundError when resource file doesn't exist.""" + loader = MimicOnFHIRLoader() + + with pytest.raises(FileNotFoundError, match="Resource file not found"): + loader.load( + data_dir=str(temp_mimic_data_dir), resource_types=["MimicMedication"] + ) + + +def test_mimic_loader_loads_single_resource_type( + temp_mimic_data_dir, mock_medication_resources +): + """MimicOnFHIRLoader loads and validates single resource type.""" + # Create mock data file + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), resource_types=["MimicMedication"] + ) + + assert isinstance(result, dict) + assert "medicationstatement" in result + # Result dict contains a Bundle + bundle = result["medicationstatement"] + assert type(bundle).__name__ == "Bundle" + assert len(bundle.entry) == 2 + assert bundle.entry[0].resource.id == "med-1" + + +def test_mimic_loader_loads_multiple_resource_types( + temp_mimic_data_dir, mock_medication_resources, mock_condition_resources +): + """MimicOnFHIRLoader loads multiple resource types and groups by FHIR type.""" + fhir_dir = 
temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + create_ndjson_gz_file( + fhir_dir / "MimicCondition.ndjson.gz", mock_condition_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication", "MimicCondition"], + ) + + assert "medicationstatement" in result + assert "condition" in result + # Each result value is a Bundle + med_bundle = result["medicationstatement"] + cond_bundle = result["condition"] + assert len(med_bundle.entry) == 2 + assert len(cond_bundle.entry) == 1 + + +@pytest.mark.parametrize("sample_size,expected_count", [(1, 1), (2, 2)]) +def test_mimic_loader_sampling_behavior( + temp_mimic_data_dir, mock_medication_resources, sample_size, expected_count +): + """MimicOnFHIRLoader samples specified number of resources.""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication"], + sample_size=sample_size, + ) + + bundle = result["medicationstatement"] + assert len(bundle.entry) == expected_count + + +def test_mimic_loader_deterministic_sampling_with_seed( + temp_mimic_data_dir, mock_medication_resources +): + """MimicOnFHIRLoader produces consistent results with random_seed.""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + + loader = MimicOnFHIRLoader() + result1 = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication"], + sample_size=1, + random_seed=42, + ) + result2 = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication"], + sample_size=1, + random_seed=42, + ) + + bundle1 = result1["medicationstatement"] + bundle2 = result2["medicationstatement"] + assert bundle1.entry[0].resource.id == bundle2.entry[0].resource.id + + +def test_mimic_loader_handles_malformed_json(temp_mimic_data_dir): + """MimicOnFHIRLoader skips malformed JSON lines and continues processing.""" + fhir_dir = temp_mimic_data_dir / "fhir" + file_path = fhir_dir / "MimicMedication.ndjson.gz" + + # Create file with mix of valid and malformed JSON + with gzip.open(file_path, "wt") as f: + f.write('{"invalid json\n') # Malformed + f.write( + json.dumps( + { + "resourceType": "MedicationStatement", + "id": "med-1", + "status": "recorded", + "medication": { + "concept": { + "coding": [ + { + "system": "http://www.nlm.nih.gov/research/umls/rxnorm", + "code": "313782", + } + ] + } + }, + "subject": {"reference": "Patient/123"}, + } + ) + + "\n" + ) # Valid + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), resource_types=["MimicMedication"] + ) + + # Should load the valid resource despite malformed line + bundle = result["medicationstatement"] + assert len(bundle.entry) == 1 + + +def test_mimic_loader_raises_error_for_invalid_fhir_resources(temp_mimic_data_dir): + """Loader validates FHIR resources and raises error for invalid data.""" + fhir_dir = temp_mimic_data_dir / "fhir" + file_path = fhir_dir / "MimicMedication.ndjson.gz" + + # Create file with invalid FHIR resource (missing required fields) + invalid_resources = [ + { + "resourceType": "MedicationStatement", + "id": "med-1", + }, # Missing required fields + ] + + with gzip.open(file_path, "wt") 
as f: + for resource in invalid_resources: + f.write(json.dumps(resource) + "\n") + + loader = MimicOnFHIRLoader() + + # FHIR validation now catches the invalid resource + with pytest.raises(Exception): + loader.load( + data_dir=str(temp_mimic_data_dir), resource_types=["MimicMedication"] + ) + + +def test_mimic_loader_skips_resources_without_resource_type(temp_mimic_data_dir): + """MimicOnFHIRLoader skips resources missing resourceType field.""" + fhir_dir = temp_mimic_data_dir / "fhir" + file_path = fhir_dir / "MimicMedication.ndjson.gz" + + resources = [ + {"id": "med-1", "status": "recorded"}, # No resourceType + { + "resourceType": "MedicationStatement", + "id": "med-2", + "status": "recorded", + "medication": { + "concept": { + "coding": [ + { + "system": "http://www.nlm.nih.gov/research/umls/rxnorm", + "code": "313782", + } + ] + } + }, + "subject": {"reference": "Patient/123"}, + }, + ] + + create_ndjson_gz_file(file_path, resources) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), resource_types=["MimicMedication"] + ) + + # Should only load the valid resource + bundle = result["medicationstatement"] + assert len(bundle.entry) == 1 diff --git a/tests/sandbox/test_request_constructors.py b/tests/sandbox/test_request_constructors.py index cd9e4ae5..31243bd0 100644 --- a/tests/sandbox/test_request_constructors.py +++ b/tests/sandbox/test_request_constructors.py @@ -6,7 +6,6 @@ ClinDocRequestConstructor, ) from healthchain.sandbox.workflows import Workflow -from healthchain.models.hooks.prefetch import Prefetch from healthchain.sandbox.base import ApiProtocol from healthchain.fhir import create_bundle @@ -29,8 +28,8 @@ def test_cds_request_constructor_validation(): """Test validation of workflows in CdsRequestConstructor""" constructor = CdsRequestConstructor() - # Create a prefetch object - prefetch = Prefetch(prefetch={"patient": create_bundle()}) + # Create a prefetch dict + prefetch = {"patient": create_bundle()} # Test with valid workflow valid_workflow = Workflow.patient_view @@ -46,15 +45,16 @@ def test_cds_request_constructor_validation(): def test_cds_request_constructor_type_error(): - """Test type error handling in CdsRequestConstructor""" + """Test validation error handling in CdsRequestConstructor""" constructor = CdsRequestConstructor() - # Test with invalid prefetch data type - should raise TypeError - with pytest.raises(TypeError): - # Not a Prefetch object - invalid_prefetch = {"patient": create_bundle()} + # Test with invalid workflow - should raise ValueError + with pytest.raises(ValueError): + # Invalid workflow + invalid_workflow = MagicMock() + invalid_workflow.value = "invalid-workflow" constructor.construct_request( - prefetch_data=invalid_prefetch, workflow=Workflow.patient_view + prefetch_data={"patient": create_bundle()}, workflow=invalid_workflow ) @@ -62,9 +62,9 @@ def test_cds_request_construction(): """Test request construction in CdsRequestConstructor""" constructor = CdsRequestConstructor() - # Create a bundle and prefetch + # Create a bundle and prefetch dict bundle = create_bundle() - prefetch = Prefetch(prefetch={"patient": bundle}) + prefetch = {"patient": bundle} # Construct a request request = constructor.construct_request( @@ -76,7 +76,7 @@ def test_cds_request_construction(): # Verify request properties assert request.hook == "patient-view" assert request.context.patientId == "test-patient-123" - assert request.prefetch == prefetch.prefetch + assert request.prefetch == prefetch def 
test_clindoc_request_constructor_init():
@@ -185,7 +185,7 @@ def test_cds_request_construction_with_custom_context():
     """CdsRequestConstructor includes custom context parameters in request."""
     constructor = CdsRequestConstructor()
     bundle = create_bundle()
-    prefetch = Prefetch(prefetch={"patient": bundle})
+    prefetch = {"patient": bundle}
 
     # Test with custom context
     custom_context = {"patientId": "patient-123", "encounterId": "encounter-456"}
@@ -201,7 +201,7 @@ def test_cds_request_validates_workflow_for_clinical_doc():
     """CdsRequestConstructor rejects ClinicalDocumentation workflows."""
     constructor = CdsRequestConstructor()
-    prefetch = Prefetch(prefetch={"patient": create_bundle()})
+    prefetch = {"patient": create_bundle()}
 
     # Should reject sign-note workflows
     with pytest.raises(ValueError, match="Invalid workflow"):
diff --git a/tests/sandbox/test_sandbox_client.py b/tests/sandbox/test_sandbox_client.py
index e0498d79..5dc22fb7 100644
--- a/tests/sandbox/test_sandbox_client.py
+++ b/tests/sandbox/test_sandbox_client.py
@@ -1,15 +1,16 @@
 import pytest
 import json
+from unittest.mock import Mock, patch
 
 from healthchain.sandbox import SandboxClient
 
 
 def test_load_from_registry_unknown_dataset():
     """load_from_registry raises ValueError for unknown datasets."""
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     with pytest.raises(ValueError, match="Unknown dataset"):
-        client.load_from_registry("nonexistent-dataset")
+        client.load_from_registry("nonexistent-dataset", data_dir="/test")
 
 
 def test_load_from_path_single_xml_file(tmp_path):
@@ -19,13 +20,15 @@ def test_load_from_path_single_xml_file(tmp_path):
     cda_file.write_text("Test CDA")
 
     client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/notereader/fhir/", protocol="soap"
+        url="http://localhost:8000/notereader/fhir/",
+        workflow="sign-note-inpatient",
+        protocol="soap",
     )
 
     result = client.load_from_path(str(cda_file))
 
     assert result is client
-    assert len(client.request_data) == 1
+    assert len(client.requests) == 1
 
 
 def test_load_from_path_directory_with_pattern(tmp_path):
@@ -36,12 +39,14 @@ def test_load_from_path_directory_with_pattern(tmp_path):
     (tmp_path / "other.txt").write_text("Not XML")
 
     client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/notereader/fhir/", protocol="soap"
+        url="http://localhost:8000/notereader/fhir/",
+        workflow="sign-note-inpatient",
+        protocol="soap",
     )
 
     client.load_from_path(str(tmp_path), pattern="*.xml")
 
-    assert len(client.request_data) == 2
+    assert len(client.requests) == 2
 
 
 def test_load_from_path_directory_all_files(tmp_path):
@@ -51,18 +56,22 @@ def test_load_from_path_directory_all_files(tmp_path):
     (tmp_path / "note2.xml").write_text("Note 2")
 
     client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/notereader/fhir/", protocol="soap"
+        url="http://localhost:8000/notereader/fhir/",
+        workflow="sign-note-inpatient",
+        protocol="soap",
     )
 
     client.load_from_path(str(tmp_path))
 
-    assert len(client.request_data) == 2
+    assert len(client.requests) == 2
 
 
 def test_load_from_path_error_handling(tmp_path):
     """load_from_path raises FileNotFoundError for nonexistent path."""
     client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/notereader/fhir/", protocol="soap"
+        url="http://localhost:8000/notereader/fhir/",
+        workflow="sign-note-inpatient",
+        protocol="soap",
     )
 
     with pytest.raises(FileNotFoundError):
@@ -78,20 +87,45 @@ def test_load_free_text_generates_data(tmp_path):
     csv_file = tmp_path / "test.csv"
     csv_file.write_text("text\nSample discharge note\n")
 
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
+    client = SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="encounter-discharge",
+    )
 
     client.load_free_text(
         csv_path=str(csv_file),
         column_name="text",
+        random_seed=42,
+    )
+    assert len(client.requests) > 0
+
+
+def test_load_free_text_without_synthetic_data(tmp_path):
+    """load_free_text can generate data without synthetic resources."""
+    # Create test CSV
+    csv_file = tmp_path / "test.csv"
+    csv_file.write_text("text\nSample discharge note\nAnother note\n")
+
+    client = SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="encounter-discharge",
+    )
+
+    client.load_free_text(
+        csv_path=str(csv_file),
+        column_name="text",
+        generate_synthetic=False,
         random_seed=42,
     )
-    assert len(client.request_data) > 0
+
+    assert len(client.requests) > 0
+    # Verify request was created (but without checking prefetch content details)
+    assert client.requests[0].hook == "encounter-discharge"
 
 
 def test_send_requests_without_data():
     """send_requests raises RuntimeError if no data is loaded."""
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     with pytest.raises(RuntimeError, match="No requests to send"):
         client.send_requests()
@@ -99,7 +133,7 @@ def test_send_requests_without_data():
 
 def test_save_results_without_responses():
     """save_results raises RuntimeError if no responses available."""
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     with pytest.raises(RuntimeError, match="No responses to save"):
         client.save_results()
@@ -107,15 +141,12 @@ def test_save_results_without_responses():
 
 def test_get_status():
     """get_status returns client status information."""
-    client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/test", workflow="patient-view"
-    )
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     status = client.get_status()
 
     assert "sandbox_id" in status
-    assert status["api_url"] == "http://localhost:8000"
-    assert status["endpoint"] == "/test"
+    assert status["url"] == "http://localhost:8000/test"
     assert status["protocol"] == "REST"
     assert status["workflow"] == "patient-view"
     assert status["requests_queued"] == 0
@@ -124,13 +155,12 @@ def test_get_status():
 
 def test_repr():
     """__repr__ returns meaningful string representation."""
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     repr_str = repr(client)
 
     assert "SandboxClient" in repr_str
-    assert "http://localhost:8000" in repr_str
-    assert "/test" in repr_str
+    assert "http://localhost:8000/test" in repr_str
 
 
 def test_load_from_path_json_prefetch_file(tmp_path):
@@ -142,52 +172,38 @@ def test_load_from_path_json_prefetch_file(tmp_path):
     prefetch_data = {"prefetch": {"patient": create_bundle().model_dump()}}
     json_file.write_text(json.dumps(prefetch_data))
 
-    client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/test", workflow="patient-view"
-    )
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
     client.load_from_path(str(json_file))
 
-    assert len(client.request_data) == 1
-    assert client.request_data[0].hook == "patient-view"
-
-
-def test_load_from_path_json_without_workflow_fails(tmp_path):
-    """load_from_path requires workflow for JSON Prefetch files."""
-    json_file = tmp_path / "prefetch.json"
-    json_file.write_text('{"prefetch": {}}')
-
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
-
-    with pytest.raises(ValueError, match="Workflow must be specified"):
-        client.load_from_path(str(json_file))
+    assert len(client.requests) == 1
+    assert client.requests[0].hook == "patient-view"
 
 
 def test_load_from_path_invalid_json_prefetch(tmp_path):
-    """load_from_path rejects malformed JSON Prefetch data."""
-    json_file = tmp_path / "invalid.json"
+    """load_from_path processes JSON data for prefetch."""
+    json_file = tmp_path / "data.json"
     json_file.write_text('{"not_prefetch": "data"}')
 
-    client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/test", workflow="patient-view"
-    )
+    client = SandboxClient(url="http://localhost:8000/test", workflow="patient-view")
 
-    with pytest.raises(ValueError, match="not valid Prefetch format"):
-        client.load_from_path(str(json_file))
+    # Should load the JSON data without error since we're using plain dicts now
+    client.load_from_path(str(json_file))
+    assert len(client.requests) == 1
 
 
 def test_save_results_distinguishes_protocols(tmp_path):
     """save_results uses correct file extension based on protocol."""
-    from healthchain.models import Prefetch
     from healthchain.fhir import create_bundle
-    from healthchain.sandbox.workflows import Workflow
 
     # Test REST/JSON protocol
     rest_client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/test", protocol="rest"
+        url="http://localhost:8000/test",
+        workflow="patient-view",
+        protocol="rest",
     )
-    prefetch = Prefetch(prefetch={"patient": create_bundle()})
-    rest_client._construct_request(prefetch, Workflow.patient_view)
+    prefetch = {"patient": create_bundle()}
+    rest_client._construct_request(prefetch)
     rest_client.responses = [{"cards": []}]
 
     rest_dir = tmp_path / "rest"
@@ -198,9 +214,11 @@ def test_save_results_distinguishes_protocols(tmp_path):
 
     # Test SOAP/XML protocol
     soap_client = SandboxClient(
-        api_url="http://localhost:8000", endpoint="/test", protocol="soap"
+        url="http://localhost:8000/test",
+        workflow="sign-note-inpatient",
+        protocol="soap",
     )
-    soap_client._construct_request("test", Workflow.sign_note_inpatient)
+    soap_client._construct_request("test")
     soap_client.responses = ["data"]
 
     soap_dir = tmp_path / "soap"
@@ -210,13 +228,265 @@ def test_save_results_distinguishes_protocols(tmp_path):
     assert len(list(soap_dir.glob("**/*.json"))) == 0
 
 
-def test_construct_request_requires_workflow_for_rest():
-    """_construct_request raises ValueError if workflow missing for REST protocol."""
-    client = SandboxClient(api_url="http://localhost:8000", endpoint="/test")
-    from healthchain.models import Prefetch
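+# (editorial note: the matrix below assumes CDS Hooks workflows such as
+# "patient-view" and "encounter-discharge" are REST-only, while Clinical
+# Documentation workflows such as "sign-note-inpatient" are SOAP-only)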
with pytest.raises(ValueError, match="not compatible"): + SandboxClient( + url="http://localhost:8000/test", + workflow=workflow, + protocol=protocol, + ) + else: + client = SandboxClient( + url="http://localhost:8000/test", + workflow=workflow, + protocol=protocol, + ) + assert client.workflow.value == workflow + + +def test_clear_requests(): + """clear_requests empties the request queue.""" + from healthchain.fhir import create_bundle + + client = SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) + + # Load some data + prefetch = {"patient": create_bundle()} + client._construct_request(prefetch) + assert len(client.requests) == 1 + + # Clear and verify + result = client.clear_requests() + assert result is client # Method chaining + assert len(client.requests) == 0 + + +def test_preview_requests_provides_metadata(): + """preview_requests returns summary metadata without sending requests.""" + from healthchain.fhir import create_bundle + + client = SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) + + # Load data + prefetch = {"patient": create_bundle()} + client._construct_request(prefetch) + client._construct_request(prefetch) + + # Preview without sending + previews = client.preview_requests() + + assert len(previews) == 2 + assert previews[0]["index"] == 0 + assert previews[0]["type"] == "CDSRequest" + assert ( + previews[0]["protocol"] == "REST" + ) # Protocol is uppercase in actual implementation + assert previews[0]["hook"] == "patient-view" + + +def test_preview_requests_respects_limit(): + """preview_requests limits returned results when limit specified.""" + from healthchain.fhir import create_bundle + + client = SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) + + # Load multiple requests + prefetch = {"patient": create_bundle()} + for _ in range(5): + client._construct_request(prefetch) + + previews = client.preview_requests(limit=2) + assert len(previews) == 2 + + +@pytest.mark.parametrize( + "format_type,check", + [ + ("raw", lambda data: isinstance(data, list)), + ("dict", lambda data: isinstance(data, list) and isinstance(data[0], dict)), + ("json", lambda data: isinstance(data, str) and json.loads(data)), + ], +) +def test_get_request_data_formats(format_type, check): + """get_request_data returns data in specified format.""" + from healthchain.fhir import create_bundle + + client = SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) + + prefetch = {"patient": create_bundle()} + client._construct_request(prefetch) + + data = client.get_request_data(format=format_type) + + assert check(data) + + +def test_get_request_data_invalid_format(): + """get_request_data raises ValueError for invalid format.""" + client = SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) + + with pytest.raises(ValueError, match="Invalid format"): + client.get_request_data(format="invalid") + + +def test_context_manager_auto_saves_on_success(tmp_path): + """Context manager auto-saves results when responses exist and no exception.""" from healthchain.fhir import create_bundle - prefetch = Prefetch(prefetch={"patient": create_bundle()}) + with SandboxClient( + url="http://localhost:8000/test", + workflow="patient-view", + ) as client: + prefetch = {"patient": create_bundle()} + client._construct_request(prefetch) + # Simulate responses + client.responses = [{"cards": []}] + + # Auto-save should have been triggered (saves to "./output/" by default) 
+
+
+def test_context_manager_no_save_without_responses(tmp_path):
+    """Context manager does not save if no responses generated."""
+    with SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="patient-view",
+    ) as client:
+        # No requests or responses
+        pass
+
+    # Should exit cleanly without trying to save
+    assert len(client.responses) == 0
+
+
+def test_context_manager_no_save_on_exception():
+    """Context manager does not save if exception occurs."""
+    with pytest.raises(RuntimeError):
+        with SandboxClient(
+            url="http://localhost:8000/test",
+            workflow="patient-view",
+        ) as client:
+            client.responses = [{"cards": []}]
+            raise RuntimeError("Test exception")
+
+    # Should exit without attempting save
+
+
+@patch("httpx.Client")
+def test_send_requests_rest_success(mock_client_class):
+    """send_requests successfully processes REST/CDS Hooks requests."""
+    from healthchain.fhir import create_bundle
+
+    # Mock successful response
+    mock_response = Mock()
+    mock_response.json.return_value = {"cards": [{"summary": "Test card"}]}
+    mock_response.raise_for_status = Mock()
+
+    mock_client = Mock()
+    mock_client.post.return_value = mock_response
+    mock_client.__enter__ = Mock(return_value=mock_client)
+    mock_client.__exit__ = Mock(return_value=None)
+    mock_client_class.return_value = mock_client
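+    # (__enter__/__exit__ are stubbed by hand because plain Mock does not
+    # emulate the context-manager protocol, which send_requests presumably
+    # enters via `with httpx.Client(...)`)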
+
+    client = SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="patient-view",
+    )
+
+    prefetch = {"patient": create_bundle()}
+    client._construct_request(prefetch)
+
+    responses = client.send_requests()
+
+    assert len(responses) == 1
+    assert responses[0]["cards"][0]["summary"] == "Test card"
+    assert mock_client.post.called
+
+
+@patch("httpx.Client")
+def test_send_requests_soap_success(mock_client_class):
+    """send_requests successfully processes SOAP/CDA requests."""
+    # Mock successful response
+    mock_response = Mock()
+    mock_response.text = "Response"
+    mock_response.raise_for_status = Mock()
+
+    mock_client = Mock()
+    mock_client.post.return_value = mock_response
+    mock_client.__enter__ = Mock(return_value=mock_client)
+    mock_client.__exit__ = Mock(return_value=None)
+    mock_client_class.return_value = mock_client
+
+    client = SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="sign-note-inpatient",
+        protocol="soap",
+    )
+
+    client._construct_request("Test")
+
+    responses = client.send_requests()
+
+    assert len(responses) == 1
+    # Response is processed by CdaResponse which may return empty dict if parsing fails
+    assert isinstance(responses[0], (str, dict))
+    assert mock_client.post.called
+
+
+@patch("httpx.Client")
+def test_send_requests_handles_multiple_requests(mock_client_class):
+    """send_requests processes multiple queued requests sequentially."""
+    from healthchain.fhir import create_bundle
+
+    # Mock successful responses
+    mock_response = Mock()
+    mock_response.json.return_value = {"cards": []}
+    mock_response.raise_for_status = Mock()
+
+    mock_client = Mock()
+    mock_client.post.return_value = mock_response
+    mock_client.__enter__ = Mock(return_value=mock_client)
+    mock_client.__exit__ = Mock(return_value=None)
+    mock_client_class.return_value = mock_client
+
+    client = SandboxClient(
+        url="http://localhost:8000/test",
+        workflow="patient-view",
+    )
+
+    # Load multiple requests
+    prefetch = {"patient": create_bundle()}
+    client._construct_request(prefetch)
+    client._construct_request(prefetch)
+    client._construct_request(prefetch)
+
+    responses = client.send_requests()
 
-    with pytest.raises(ValueError, match="Workflow must be specified for REST"):
-        client._construct_request(prefetch, None)
+    assert len(responses) == 3
+    assert mock_client.post.call_count == 3
diff --git a/tests/sandbox/test_synthea_loader.py b/tests/sandbox/test_synthea_loader.py
new file mode 100644
index 00000000..0910b91b
--- /dev/null
+++ b/tests/sandbox/test_synthea_loader.py
@@ -0,0 +1,299 @@
+"""Tests for Synthea FHIR Patient dataset loader."""
+
+import json
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from healthchain.sandbox.loaders.synthea import SyntheaFHIRPatientLoader
+
+
+@pytest.fixture
+def temp_synthea_data_dir():
+    """Create temporary Synthea data directory structure."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir)
+
+
+@pytest.fixture
+def mock_patient_bundle():
+    """Sample Synthea patient Bundle with multiple resource types."""
+    return {
+        "resourceType": "Bundle",
+        "type": "collection",
+        "entry": [
+            {
+                "resource": {
+                    "resourceType": "Patient",
+                    "id": "a969c177-a995-7b89-7b6d-885214dfa253",
+                    "name": [{"given": ["Alton"], "family": "Gutkowski"}],
+                    "gender": "male",
+                    "birthDate": "1980-01-01",
+                }
+            },
+            {
+                "resource": {
+                    "resourceType": "Condition",
+                    "id": "cond-1",
+                    "clinicalStatus": {
+                        "coding": [
+                            {
+                                "system": "http://terminology.hl7.org/CodeSystem/condition-clinical",
+                                "code": "active",
+                            }
+                        ]
+                    },
+                    "code": {
+                        "coding": [
+                            {"system": "http://snomed.info/sct", "code": "44054006"}
+                        ]
+                    },
+                    "subject": {
+                        "reference": "Patient/a969c177-a995-7b89-7b6d-885214dfa253"
+                    },
+                }
+            },
+            {
+                "resource": {
+                    "resourceType": "Condition",
+                    "id": "cond-2",
+                    "clinicalStatus": {
+                        "coding": [
+                            {
+                                "system": "http://terminology.hl7.org/CodeSystem/condition-clinical",
+                                "code": "active",
+                            }
+                        ]
+                    },
+                    "code": {
+                        "coding": [
+                            {"system": "http://snomed.info/sct", "code": "38341003"}
+                        ]
+                    },
+                    "subject": {
+                        "reference": "Patient/a969c177-a995-7b89-7b6d-885214dfa253"
+                    },
+                }
+            },
+            {
+                "resource": {
+                    "resourceType": "MedicationStatement",
+                    "id": "med-1",
+                    "status": "recorded",
+                    "medication": {
+                        "concept": {
+                            "coding": [
+                                {
+                                    "system": "http://www.nlm.nih.gov/research/umls/rxnorm",
+                                    "code": "313782",
+                                }
+                            ]
+                        }
+                    },
+                    "subject": {
+                        "reference": "Patient/a969c177-a995-7b89-7b6d-885214dfa253"
+                    },
+                }
+            },
+        ],
+    }
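+# (editorial note: the fixture above holds exactly 1 Patient, 2 Condition, and
+# 1 MedicationStatement resources; the entry-count assertions in the tests
+# below rely on these numbers)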
len(result["patient"].entry) == 1 + assert len(result["condition"].entry) == 2 + + +def test_synthea_loader_filters_and_groups_resources_by_type( + temp_synthea_data_dir, mock_patient_bundle +): + """SyntheaFHIRPatientLoader filters by resource_types and groups into separate Bundles.""" + filename = "Patient1.json" + create_patient_file(temp_synthea_data_dir, filename, mock_patient_bundle) + + loader = SyntheaFHIRPatientLoader() + result = loader.load( + data_dir=str(temp_synthea_data_dir), + resource_types=["Condition", "MedicationStatement"], + ) + + # Only requested types included + assert set(result.keys()) == {"condition", "medicationstatement"} + assert len(result["condition"].entry) == 2 + assert len(result["medicationstatement"].entry) == 1 + + +@pytest.mark.parametrize("sample_size,expected_count", [(1, 1), (2, 2)]) +def test_synthea_loader_sampling_behavior( + temp_synthea_data_dir, mock_patient_bundle, sample_size, expected_count +): + """SyntheaFHIRPatientLoader samples specified number of resources per type.""" + create_patient_file(temp_synthea_data_dir, "Patient1.json", mock_patient_bundle) + + loader = SyntheaFHIRPatientLoader() + result = loader.load( + data_dir=str(temp_synthea_data_dir), + resource_types=["Condition"], + sample_size=sample_size, + ) + + assert len(result["condition"].entry) == expected_count + + +def test_synthea_loader_deterministic_sampling_with_seed( + temp_synthea_data_dir, mock_patient_bundle +): + """SyntheaFHIRPatientLoader produces consistent results with random_seed.""" + create_patient_file(temp_synthea_data_dir, "Patient1.json", mock_patient_bundle) + + loader = SyntheaFHIRPatientLoader() + result1 = loader.load( + data_dir=str(temp_synthea_data_dir), + resource_types=["Condition"], + sample_size=1, + random_seed=42, + ) + result2 = loader.load( + data_dir=str(temp_synthea_data_dir), + resource_types=["Condition"], + sample_size=1, + random_seed=42, + ) + + assert ( + result1["condition"].entry[0].resource.id + == result2["condition"].entry[0].resource.id + ) + + +@pytest.mark.parametrize( + "error_case,error_match", + [ + ({"data_dir": "/nonexistent"}, "Synthea data directory not found"), + ({"patient_id": "nonexistent-uuid"}, "No patient file found with ID"), + ({"patient_file": "nonexistent.json"}, "Patient file not found"), + ], +) +def test_synthea_loader_error_handling_for_missing_files( + temp_synthea_data_dir, mock_patient_bundle, error_case, error_match +): + """SyntheaFHIRPatientLoader raises clear errors for missing files and directories.""" + if "data_dir" not in error_case: + error_case["data_dir"] = str(temp_synthea_data_dir) + + loader = SyntheaFHIRPatientLoader() + with pytest.raises(FileNotFoundError, match=error_match): + loader.load(**error_case) + + +def test_synthea_loader_raises_error_for_multiple_matching_patient_ids( + temp_synthea_data_dir, mock_patient_bundle +): + """SyntheaFHIRPatientLoader raises ValueError when patient_id matches multiple files.""" + create_patient_file( + temp_synthea_data_dir, "Patient1_a969c177.json", mock_patient_bundle + ) + create_patient_file( + temp_synthea_data_dir, "Patient2_a969c177.json", mock_patient_bundle + ) + + loader = SyntheaFHIRPatientLoader() + with pytest.raises(ValueError, match="Multiple patient files found"): + loader.load(data_dir=str(temp_synthea_data_dir), patient_id="a969c177") + + +@pytest.mark.parametrize( + "invalid_bundle,error_match", + [ + ({"not": "a bundle"}, "is not a FHIR Bundle"), + ({"resourceType": "Patient"}, "is not a FHIR Bundle"), + ({"resourceType": 
"Bundle"}, "has no 'entry' field"), + ], +) +def test_synthea_loader_validates_bundle_structure( + temp_synthea_data_dir, invalid_bundle, error_match +): + """SyntheaFHIRPatientLoader validates Bundle structure and raises errors for invalid data.""" + create_patient_file(temp_synthea_data_dir, "Invalid.json", invalid_bundle) + + loader = SyntheaFHIRPatientLoader() + with pytest.raises(ValueError, match=error_match): + loader.load(data_dir=str(temp_synthea_data_dir)) + + +def test_synthea_loader_raises_error_for_nonexistent_resource_types( + temp_synthea_data_dir, mock_patient_bundle +): + """SyntheaFHIRPatientLoader raises error when requested resource_types don't exist.""" + create_patient_file(temp_synthea_data_dir, "Patient1.json", mock_patient_bundle) + + loader = SyntheaFHIRPatientLoader() + with pytest.raises(ValueError, match="No resources found for requested types"): + loader.load( + data_dir=str(temp_synthea_data_dir), + resource_types=["Observation", "Procedure"], # Not in bundle + ) + + +def test_synthea_loader_skips_resources_without_resource_type(temp_synthea_data_dir): + """SyntheaFHIRPatientLoader skips entries missing resourceType field.""" + invalid_bundle = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + {"resource": {"id": "no-type"}}, # Missing resourceType + { + "resource": { + "resourceType": "Patient", + "id": "patient-1", + "gender": "male", + "birthDate": "1980-01-01", + } + }, + ], + } + create_patient_file(temp_synthea_data_dir, "Patient1.json", invalid_bundle) + + loader = SyntheaFHIRPatientLoader() + result = loader.load(data_dir=str(temp_synthea_data_dir)) + + # Should only load valid Patient resource + assert "patient" in result + assert len(result["patient"].entry) == 1