# OMOP to MEDS ETL Pipeline

## Overview

This notebook implements a comprehensive Extract, Transform, Load (ETL) pipeline that converts OMOP Common Data Model (CDM) format data into the Medical Event Data Standard (MEDS) format. The pipeline is designed with robust error handling, validation, and data quality checks.

### What is MEDS?

MEDS (Medical Event Data Standard) is a standardized format for representing longitudinal electronic health record (EHR) data. It organizes each patient's medical events as a time-ordered sequence, which makes it well suited to:
- Machine learning applications
- Temporal analysis
- Patient trajectory modeling
- Predictive analytics

### Pipeline Features

- **Automated Format Detection**: Handles both Parquet and Delta Lake formats
- **Comprehensive Validation**: Schema validation, null checks, and temporal consistency
- **Robust Error Handling**: Graceful failure recovery with detailed logging
- **Concept Mapping**: Automatic conversion from OMOP concept IDs to vocabulary codes
- **Temporal Validation**: Ensures events occur after birth and in logical sequence
- **Metadata Generation**: Creates MEDS-compliant metadata files

### OMOP Tables Processed

1. **person**: Patient demographics and birth information
2. **visit_occurrence**: Healthcare visits and encounters
3. **condition_occurrence**: Diagnoses and conditions
4. **drug_exposure**: Medication prescriptions and administrations
5. **procedure_occurrence**: Medical procedures
6. **measurement**: Laboratory results and vital signs

### Output Format

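The pipeline produces MEDS-style data with the following schema. Two columns follow this pipeline's own conventions rather than the core MEDS schema. `PERSON_ID` keeps the OMOP name (recent MEDS releases call the subject column `subject_id`), and `datetime_value` is an additional column: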
- `PERSON_ID`: Patient identifier
- `time`: Event timestamp (aligned with visit start time when applicable)
- `code`: Standardized medical code (vocabulary/code format)
- `numeric_value`: Numeric measurement value (for lab results)
- `datetime_value`: Original event datetime
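
As a plain-Python illustration of the schema above, each output row can be thought of as the record below. This is only a sketch. The `MedsEvent` class is hypothetical, since the pipeline itself works with Spark DataFrames. The example values are made up. Fields with defaults are listed last, so the field order differs from `get_meds_schema()`.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass(frozen=True)
class MedsEvent:
    """One row of this pipeline's MEDS-style output (hypothetical helper class)."""
    PERSON_ID: str                              # patient identifier, stored as a string
    time: Optional[datetime]                    # visit start for clinical events, birth date for births
    code: Optional[str]                         # "VOCABULARY/CODE", e.g. "SNOMED/184099003"
    numeric_value: Optional[float] = None       # only populated for measurements
    datetime_value: Optional[datetime] = None   # original event datetime


# Illustrative values only: a birth event and a lab result for one patient
events = [
    MedsEvent("42", datetime(1980, 1, 1), "SNOMED/184099003"),
    MedsEvent("42", datetime(2020, 5, 3, 8, 0), "LOINC/2345-7", 98.0,
              datetime(2020, 5, 3, 9, 30)),
]
```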

---

## 1. Import Required Libraries

Import all necessary libraries for data processing, logging, and configuration management.

In [None]:
import logging
import time
import os
import json
from typing import Dict, List, Optional
from dataclasses import dataclass

from pyspark.sql import SparkSession, DataFrame, Column
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType, 
    StructField, 
    StringType, 
    ArrayType,
    TimestampType,
    FloatType
)

## 2. Define Configuration and Utility Functions

This section defines:
- **MEDS Schema**: The standardized schema for MEDS output
- **ETLConfig**: Configuration dataclass for managing paths and table definitions
- **Logging Setup**: Structured logging for monitoring pipeline execution
- **create_config**: Builds and validates the configuration (edit the placeholder paths here before running)

In [None]:
def get_meds_schema():
    """Define the standardized schema for MEDS data.
    
    Returns:
        StructType: PySpark schema definition for MEDS format
    """
    return StructType([
        StructField("PERSON_ID", StringType(), True),
        StructField("time", TimestampType(), True),
        StructField("datetime_value", TimestampType(), True),
        StructField("code", StringType(), True),
        StructField("numeric_value", FloatType(), True)
    ])

def setup_notebook_logging():
    """Configure logging for the notebook.
    
    Returns:
        logging.Logger: Configured logger instance
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(__name__)

logger = setup_notebook_logging()

@dataclass
class ETLConfig:
    """Configuration class for the ETL pipeline.
    
    Attributes:
        base_path: Path to cleaned OMOP data
        source_path: Path to raw OMOP source data (for concept table)
        output_path: Base path for MEDS output
        patient_tables: Dictionary mapping table names to required columns
    """
    base_path: str
    source_path: str
    output_path: str
    patient_tables: Dict[str, List[str]]
    
    @property
    def unsorted_dir(self) -> str:
        """Directory for intermediate unsorted data."""
        return os.path.join(self.output_path, "unsorted_data")
    
    @property
    def cleaned_dir(self) -> str:
        """Directory for final cleaned MEDS data."""
        return os.path.join(self.output_path, "cleaned_data")
    
    @property
    def metadata_dir(self) -> str:
        """Directory for MEDS metadata files."""
        return os.path.join(self.output_path, "metadata")
    
    def validate(self) -> None:
        """Validate that required paths exist and create output directories.
        
        Raises:
            ValueError: If required paths don't exist
        """
        try:
            # Note: these checks use the local file system API. On Databricks, DBFS is
            # mounted at /dbfs, so "dbfs:" URIs are rewritten to "/dbfs" paths first.
            # In other environments, swap in the appropriate file system calls.
            
            # Validate source paths exist
            if not os.path.exists(self.base_path.replace('dbfs:', '/dbfs')):
                raise ValueError(f"Base path does not exist: {self.base_path}")
            
            # Create output directories
            for path in [self.unsorted_dir, self.cleaned_dir, self.metadata_dir]:
                os.makedirs(path.replace('dbfs:', '/dbfs'), exist_ok=True)
                
        except Exception as e:
            raise ValueError(f"Path validation failed: {str(e)}") from e

def create_config() -> ETLConfig:
    """Create and validate ETL configuration.
    
    Modify the paths below to match your environment.
    
    Returns:
        ETLConfig: Validated configuration object
    """
    config = ETLConfig(
        # REPLACE THESE PATHS WITH YOUR ACTUAL DATA LOCATIONS
        base_path="/path/to/your/cleaned/omop/data/",
        source_path="/path/to/your/raw/omop/data/",
        output_path="/path/to/your/meds/output/",
        
        # Define which columns to extract from each OMOP table
        patient_tables={
            "person": [
                "PERSON_ID", 
                "YEAR_OF_BIRTH", 
                "MONTH_OF_BIRTH", 
                "DAY_OF_BIRTH", 
                "BIRTH_DATETIME"
            ],
            "condition_occurrence": [
                "PERSON_ID", 
                "CONDITION_CONCEPT_ID", 
                "CONDITION_START_DATETIME", 
                "VISIT_OCCURRENCE_ID"
            ],
            "drug_exposure": [
                "PERSON_ID", 
                "DRUG_CONCEPT_ID", 
                "DRUG_EXPOSURE_START_DATETIME", 
                "VISIT_OCCURRENCE_ID"
            ],
            "procedure_occurrence": [
                "PERSON_ID", 
                "PROCEDURE_CONCEPT_ID", 
                "PROCEDURE_DATETIME", 
                "VISIT_OCCURRENCE_ID"
            ],
            "visit_occurrence": [
                "PERSON_ID", 
                "VISIT_CONCEPT_ID", 
                "VISIT_OCCURRENCE_ID", 
                "VISIT_START_DATETIME"
            ],
            "measurement": [
                "PERSON_ID", 
                "MEASUREMENT_CONCEPT_ID", 
                "MEASUREMENT_DATETIME", 
                "VALUE_AS_NUMBER", 
                "VISIT_OCCURRENCE_ID"
            ]
        }
    )
    config.validate()
    return config

# Initialize configuration
config = create_config()
logger.info("Configuration initialized successfully")
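
The configuration code maps Databricks `dbfs:` URIs to local FUSE paths with `str.replace('dbfs:', '/dbfs')`. That call rewrites every occurrence in the string, not just the prefix. If you want the mapping defined in one place, a prefix-only helper is one option; `to_local_path` below is a hypothetical sketch. It assumes DBFS is mounted at `/dbfs`, as on Databricks.

```python
def to_local_path(path: str, scheme: str = "dbfs:", mount_point: str = "/dbfs") -> str:
    """Map a 'dbfs:/...' URI to its local FUSE path; leave other paths unchanged.

    Hypothetical helper: assumes DBFS is mounted at /dbfs (the Databricks default).
    """
    if path.startswith(scheme):
        # Only rewrite the leading scheme, never a later occurrence in the path
        return mount_point + path[len(scheme):]
    return path
```

For example, `to_local_path("dbfs:/mnt/omop/person")` gives `/dbfs/mnt/omop/person`, while a plain local path is returned unchanged.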

## 3. Data Validation Functions

These functions ensure data quality throughout the ETL process by:
- Validating DataFrame schemas
- Checking for required columns
- Computing and logging data quality metrics
- Detecting empty datasets
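
The checks above can be sketched in plain Python over a list of row dictionaries. This mirrors the logic of `validate_dataframe`: missing columns, empty input, row count, and per-column null counts. It is an illustration, not a replacement. `quality_metrics` is a hypothetical name, and `columns` stands in for `df.columns`.

```python
from typing import Any, Dict, List


def quality_metrics(columns: List[str], rows: List[Dict[str, Any]],
                    table_name: str, required_columns: List[str]) -> Dict[str, Any]:
    """Mirror validate_dataframe: check columns, reject empty input, count nulls."""
    missing = [c for c in required_columns if c not in columns]
    if missing:
        raise ValueError(f"Missing required columns in {table_name}: {missing}")
    if not rows:
        raise ValueError(f"Empty DataFrame for {table_name}")
    # A single pass over the rows accumulates every column's null count
    null_counts = {c: 0 for c in required_columns}
    for row in rows:
        for c in required_columns:
            if row.get(c) is None:
                null_counts[c] += 1
    return {"row_count": len(rows), "null_counts": null_counts}
```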

In [None]:
def validate_dataframe(df: DataFrame, table_name: str, required_columns: List[str]) -> None:
    """Validate DataFrame schema and contents.
    
    Args:
        df: DataFrame to validate
        table_name: Name of the table being validated
        required_columns: List of columns that must be present
        
    Raises:
        ValueError: If validation fails
    """
    # Check for required columns
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns in {table_name}: {missing_columns}")
    
    # Check for empty DataFrame
    if df.rdd.isEmpty():
        raise ValueError(f"Empty DataFrame for {table_name}")
    
    # Log data quality metrics (null counts for all columns in one aggregation pass)
    row_count = df.count()
    null_counts = df.select([
        F.sum(F.col(c).isNull().cast("int")).alias(c) for c in required_columns
    ]).first().asDict()
    
    logger.info(f"Table {table_name} statistics:")
    logger.info(f"  Total rows: {row_count:,}")
    logger.info(f"  Null counts: {null_counts}")

def verify_path_exists(spark: SparkSession, path: str) -> bool:
    """Verify if a path exists in the file system.
    
    Args:
        spark: SparkSession instance
        path: Path to verify
        
    Returns:
        bool: True if path exists, False otherwise
    """
    try:
        # For Databricks: use dbutils.fs.ls(path)
        # For standard Spark: use os.path.exists()
        return os.path.exists(path.replace('dbfs:', '/dbfs'))
    except Exception:
        return False

## 4. Concept Mapping Functions

OMOP uses concept IDs to represent medical terminology. These functions:
- Load the OMOP concept vocabulary
- Create mappings from concept IDs to standardized codes (e.g., SNOMED/123456)
- Enable conversion to MEDS format with vocabulary-prefixed codes rather than opaque integer concept IDs
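
In plain Python, the mapping built by `create_concept_maps` amounts to the sketch below. Spark's `concat_ws` skips null inputs, so a row with a null `CONCEPT_CODE` yields just the vocabulary name rather than a string like `"SNOMED/null"`. The helper mimics that. Concept IDs missing from the map come back as `None`, matching the left join in `process_clinical_table`. Both function names are hypothetical, and the example IDs are illustrative.

```python
from typing import Dict, Iterable, Optional, Tuple


def build_concept_map(
    rows: Iterable[Tuple[int, Optional[str], Optional[str]]]
) -> Dict[int, str]:
    """Map CONCEPT_ID -> "VOCABULARY_ID/CONCEPT_CODE" (hypothetical helper)."""
    concept_map = {}
    for concept_id, vocabulary_id, concept_code in rows:
        # Like F.concat_ws("/", ...), null parts are skipped rather than rendered
        parts = [p for p in (vocabulary_id, concept_code) if p is not None]
        concept_map[concept_id] = "/".join(parts)
    return concept_map


def lookup_code(concept_map: Dict[int, str], concept_id: int) -> Optional[str]:
    """Left-join semantics: unknown concept IDs map to None (a null code)."""
    return concept_map.get(concept_id)
```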

In [None]:
def create_concept_maps(spark: SparkSession, config: ETLConfig) -> Dict[str, str]:
    """Create concept ID to code mappings with error handling.
    
    This function reads the OMOP concept table and creates a dictionary
    mapping concept IDs to vocabulary-prefixed codes (e.g., "SNOMED/12345").
    
    Args:
        spark: SparkSession instance
        config: ETL configuration
        
    Returns:
        Dict[str, str]: Dictionary mapping concept IDs to codes
        
    Raises:
        ValueError: If concept table cannot be read
    """
    try:
        concept_df = read_table_safely(
            spark,
            "concept", 
            config.source_path,
            ["CONCEPT_ID", "VOCABULARY_ID", "CONCEPT_CODE"]
        )
        
        if concept_df is None:
            raise ValueError("Failed to read concept table")

        # Create codes in format "VOCABULARY/CODE"
        concept_df = concept_df.withColumn(
            "code", 
            F.concat_ws("/", F.col("VOCABULARY_ID"), F.col("CONCEPT_CODE"))
        )
        
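        # Collect the (CONCEPT_ID, code) pairs to the driver as a Python dict.
        # Note: the full OMOP concept table can contain millions of rows, so the
        # driver needs enough memory to hold the entire mapping.
        concept_map = (
            concept_df.select("CONCEPT_ID", "code")
                      .rdd
                      .collectAsMap()
        )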
        
        logger.info(f"Created concept map with {len(concept_map):,} entries")
        return concept_map
        
    except Exception as e:
        logger.error(f"Error creating concept maps: {str(e)}")
        raise

## 5. Metadata Generation

Generate MEDS-compliant metadata files that document:
- Dataset name and version
- ETL pipeline version
- MEDS format version
- Creation timestamp

This metadata is essential for data provenance and reproducibility.
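
One variation on `generate_metadata` is to pass the names and versions in rather than hard-coding them. The hypothetical `build_dataset_metadata` below produces the same keys and the same UTC timestamp format as the notebook code. Check the field values against the MEDS specification release you target; this sketch does not verify them.

```python
import json
import time
from typing import Dict


def build_dataset_metadata(dataset_name: str, dataset_version: str,
                           etl_name: str = "omop_to_meds_etl", etl_version: str = "1.0",
                           meds_version: str = "1.0") -> Dict[str, str]:
    """Assemble the dataset.json payload written by generate_metadata."""
    return {
        "dataset_name": dataset_name,
        "dataset_version": dataset_version,
        "etl_name": etl_name,
        "etl_version": etl_version,
        "meds_version": meds_version,
        # Same UTC ISO-8601 format as the notebook, e.g. "2024-01-31T12:00:00Z"
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }


payload = json.dumps(build_dataset_metadata("MEDS Dataset", "1.0"), indent=2)
```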

In [None]:
def generate_metadata(spark: SparkSession, config: ETLConfig):
    """Generate metadata files for the MEDS dataset.
    
    Creates a JSON file with dataset information following MEDS standards.
    
    Args:
        spark: SparkSession instance
        config: ETL configuration
    """
    try:
        logger.info("Generating metadata...")
        
        # Create dataset metadata
        dataset_metadata = {
            "dataset_name": "MEDS Dataset",
            "dataset_version": "1.0",
            "etl_name": "omop_to_meds_etl",
            "etl_version": "1.0",
            "meds_version": "1.0",
            "created_at": time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
        }
        
        # Write metadata file
        metadata_path = os.path.join(config.metadata_dir, "dataset.json")
        with open(metadata_path.replace('dbfs:', '/dbfs'), 'w') as f:
            json.dump(dataset_metadata, f, indent=2)
        
        logger.info(f"Metadata written to {metadata_path}")
        
    except Exception as e:
        logger.error(f"Error generating metadata: {str(e)}")
        raise

## 6. Data Reading Functions

These functions provide robust data reading capabilities:
- **Format Detection**: Automatically detect Parquet or Delta Lake formats
- **Safe Reading**: Handle missing files and invalid formats gracefully
- **Column Validation**: Ensure required columns are present
- **Fallback Reads**: Try each supported format in turn before reporting a failure
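
The fallback strategy in `read_table_safely` can be written generically. It tries one reader, falls back to the next, logs every error, and returns `None` if all fail. The sketch below uses hypothetical names and plain callables in place of `spark.read` calls, so it runs without Spark.

```python
import logging
from typing import Any, Callable, List, Optional, Tuple

logger = logging.getLogger(__name__)


def read_with_fallback(
    path: str, readers: List[Tuple[str, Callable[[str], Any]]]
) -> Optional[Tuple[str, Any]]:
    """Try each (format_name, reader) in order; return (format_name, result) or None."""
    errors = []
    for format_name, reader in readers:
        try:
            return format_name, reader(path)
        except Exception as exc:  # record the failure and try the next format
            errors.append((format_name, exc))
    for format_name, exc in errors:
        logger.error("  %s error: %s", format_name, exc)
    return None
```

With Spark, the readers would be callables such as `lambda p: spark.read.format("delta").load(p)` and `spark.read.parquet`.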

In [None]:
def detect_table_format(spark: SparkSession, path: str) -> str:
    """Detect the format of data at the given path.
    
    Args:
        spark: SparkSession instance
        path: Path to the data
        
    Returns:
        str: Format type ('delta' or 'parquet')
        
    Raises:
        ValueError: If format cannot be detected
    """
    try:
        # Try to read as Delta first
        spark.read.format("delta").load(path).limit(1).count()
        return "delta"
    except Exception as e:
        if "Incompatible format" in str(e) or "not a Delta table" in str(e):
            # Try to read as Parquet
            try:
                spark.read.parquet(path).limit(1).count()
                return "parquet"
            except Exception:
                pass
        raise ValueError(f"Unable to detect readable format for {path}")

def read_table_safely(
    spark: SparkSession,
    table_name: str,
    base_path: str,
    required_columns: List[str]
) -> Optional[DataFrame]:
    """Safely read a table with validation and error handling.
    
    Attempts to read data as Delta first, then as plain Parquet, and
    validates that all required columns are present.
    
    Args:
        spark: SparkSession instance
        table_name: Name of the table to read
        base_path: Base directory containing tables
        required_columns: List of columns that must be present
    
    Returns:
        Optional[DataFrame]: DataFrame with required columns, or None if read fails
    """
    table_path = os.path.join(base_path, table_name)
    
    if not verify_path_exists(spark, table_path):
        logger.warning(f"Table path does not exist: {table_path}")
        return None
        
    try:
        # Try Delta first. A Delta table directory also holds ordinary Parquet files,
        # so reading it as plain Parquet would bypass the transaction log and could
        # include data files that are no longer part of the current table version.
        df = spark.read.format("delta").load(table_path)
        logger.info(f"Successfully read {table_name} as delta")
    except Exception as e1:
        try:
            # Fall back to plain Parquet
            df = spark.read.parquet(table_path)
            logger.info(f"Successfully read {table_name} as parquet")
        except Exception as e2:
            logger.error(f"Failed to read {table_name} in both formats")
            logger.error(f"  Delta error: {str(e1)}")
            logger.error(f"  Parquet error: {str(e2)}")
            return None
    
    # Validate columns
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        logger.error(f"Missing required columns in {table_name}: {missing_columns}")
        return None
        
    return df.select(*required_columns)

## 7. Data Transformation Functions

These functions handle the core transformation logic:

### Person Table Processing
- Creates birth datetime from year/month/day components
- Assigns birth event code (SNOMED/184099003 = "Date of birth")
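
The birth-date rule can be written in plain Python: a missing month or day defaults to 1, and rows without a year are dropped. `birth_date` is a hypothetical helper. Returning `None` for an impossible date such as February 30 is an assumption of this sketch; it does not describe Spark's behavior.

```python
from datetime import datetime
from typing import Optional


def birth_date(year: Optional[int], month: Optional[int] = None,
               day: Optional[int] = None) -> Optional[datetime]:
    """Build a birth timestamp from OMOP year/month/day parts (hypothetical helper)."""
    if year is None:
        return None  # process_person_table filters these rows out
    # Missing month/day default to 1, like coalesce(..., lit(1)) in the notebook
    month = month if month is not None else 1
    day = day if day is not None else 1
    try:
        return datetime(year, month, day)
    except ValueError:
        return None  # assumption: treat impossible dates as missing
```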

### Clinical Table Processing
- Maps OMOP concept IDs to vocabulary codes
- Aligns event times with visit start times
- Preserves original datetime values
- Handles numeric measurements (lab values)

### Temporal Alignment
All clinical events are aligned to their associated visit's start time; events whose visit cannot be matched get a null `time`. This alignment provides:
- Consistent temporal ordering
- Simplified sequential modeling
- Reduced temporal granularity for privacy
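
For a single table, the alignment amounts to a left join against a lookup from visit ID to start time. The plain-Python sketch below shows the three possible outcomes. The function `align_to_visits` is hypothetical, and dictionaries stand in for DataFrame rows.
- An event with a matching visit gets the visit start as `time`.
- An event with no visit ID gets `time = None`.
- An event with an unknown visit ID also gets `time = None`.

In all three cases, the event's own datetime is kept in `datetime_value`.

```python
from datetime import datetime
from typing import Dict, List, Optional


def align_to_visits(events: List[dict], visit_start: Dict[int, datetime],
                    datetime_key: str) -> List[dict]:
    """Set 'time' from the linked visit and keep the event's own datetime."""
    aligned = []
    for event in events:
        visit_id: Optional[int] = event.get("VISIT_OCCURRENCE_ID")
        aligned.append({
            **event,
            # Left-join semantics: no visit ID or no matching visit gives a null time
            "time": visit_start.get(visit_id) if visit_id is not None else None,
            "datetime_value": event.get(datetime_key),
        })
    return aligned
```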

In [None]:
def cast_to_datetime(df: DataFrame, time_columns: List[str]) -> "Column":
    """Build a timestamp expression from one or more candidate columns.
    
    Each candidate column that exists in df is cast to timestamp, and the
    first non-null value (in the order given) is used.
    
    Args:
        df: Input DataFrame
        time_columns: List of column names to try, in priority order
        
    Returns:
        Column: Timestamp expression (a null timestamp literal if none of
        the columns exist in df)
    """
    valid_time_columns = [col_name for col_name in time_columns if col_name in df.columns]
    
    if not valid_time_columns:
        return F.lit(None).cast("timestamp")
    
    # coalesce picks the first non-null cast value across the candidate columns
    return F.coalesce(*[F.col(col_name).cast("timestamp") for col_name in valid_time_columns])

def process_person_table(df: DataFrame) -> DataFrame:
    """Process person table to create birth events.
    
    Transforms person demographics into MEDS format with birth events.
    Birth date is constructed from year, month, and day components.
    
    Args:
        df: Person DataFrame
        
    Returns:
        DataFrame in MEDS format with birth events
    """
    # Filter records with valid birth year
    df = df.filter(F.col("YEAR_OF_BIRTH").isNotNull())
    
    # Create birth datetime from components
    df = df.withColumn(
        "time",
        F.to_timestamp(
            F.concat_ws(
                '-',
                F.col("YEAR_OF_BIRTH"),
                F.lpad(F.coalesce(F.col("MONTH_OF_BIRTH"), F.lit(1)), 2, '0'),
                F.lpad(F.coalesce(F.col("DAY_OF_BIRTH"), F.lit(1)), 2, '0')
            ),
            "yyyy-MM-dd"
        )
    )
    
    # Set datetime_value from birth datetime if available
    df = df.withColumn("datetime_value", F.col("BIRTH_DATETIME").cast("timestamp"))
    
    # Assign birth event code (SNOMED 184099003: Date of birth)
    df = df.withColumn("code", F.lit("SNOMED/184099003"))
    
    # No numeric value for birth events
    df = df.withColumn("numeric_value", F.lit(None).cast("float"))
    
    return df

def process_clinical_table(
    df: DataFrame,
    table_name: str,
    visit_occurrence_df: DataFrame,
    concept_id_map: Dict[str, str]
) -> DataFrame:
    """Process clinical tables (conditions, drugs, procedures, measurements).
    
    Transforms clinical event tables into MEDS format by:
    1. Joining with visit occurrence to get visit start times
    2. Mapping concept IDs to vocabulary codes
    3. Preserving original datetime values
    4. Handling numeric measurements
    
    Args:
        df: Input clinical DataFrame
        table_name: Name of the table being processed
        visit_occurrence_df: Visit occurrence DataFrame for time alignment
        concept_id_map: Dictionary mapping concept IDs to codes
        
    Returns:
        DataFrame in MEDS format
    """
    
    # Handle visit occurrence table differently
    if table_name == "visit_occurrence":
        time_col = F.col("VISIT_START_DATETIME").cast("timestamp")
        datetime_value_col = time_col
    else:
        # Join with visit occurrence to align event times
        if "VISIT_OCCURRENCE_ID" in df.columns:
            visit_df = visit_occurrence_df.withColumnRenamed(
                "VISIT_START_DATETIME", 
                "visit_start_time"
            )
            df = df.join(
                visit_df,
                on="VISIT_OCCURRENCE_ID",
                how="left"
            )
            time_col = F.col("visit_start_time").cast("timestamp")
        else:
            time_col = F.lit(None).cast("timestamp")
        
        # Preserve original datetime for clinical events
        datetime_columns = [col for col in df.columns if "_DATETIME" in col]
        datetime_value_col = cast_to_datetime(df, datetime_columns)
    
    # Add time columns
    df = df.withColumn("time", time_col)
    df = df.withColumn("datetime_value", datetime_value_col)
    
    # Map concept IDs to vocabulary codes
    concept_col = next(col for col in df.columns if "CONCEPT_ID" in col)
    # Get the session explicitly instead of relying on a notebook-global `spark`
    spark = SparkSession.builder.getOrCreate()
    concept_map_df = spark.createDataFrame(
        list(concept_id_map.items()),
        ["concept_id", "code"]
    )
    
    df = df.join(
        F.broadcast(concept_map_df),
        df[concept_col] == concept_map_df["concept_id"],
        "left"
    )
    
    # Handle numeric values (measurements have VALUE_AS_NUMBER)
    if table_name == "measurement":
        df = df.withColumn(
            "numeric_value",
            F.col("VALUE_AS_NUMBER").cast("float")
        )
    else:
        df = df.withColumn(
            "numeric_value",
            F.lit(None).cast("float")
        )
    
    # Clean up temporary columns
    if "visit_start_time" in df.columns:
        df = df.drop("visit_start_time")
    
    return df

## 8. Table Processing Pipeline

This function orchestrates the transformation of individual OMOP tables into MEDS format.

### Processing Steps:
1. **Read**: Load table with required columns
2. **Transform**: Apply table-specific transformation logic
3. **Validate**: Check data quality and completeness
4. **Write**: Save to intermediate storage with retry logic

### Error Handling:
- Graceful failure: Individual table failures don't stop the pipeline
- Retry logic: Up to 3 attempts for write operations
- Detailed logging: All errors are logged for debugging
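
The write-retry loop used in `process_table` (and again in `clean_meds_data`) could be factored into one helper. The sketch below is hypothetical; `with_retries` is not part of the notebook. It keeps the same behavior:
- up to `max_retries` attempts;
- a fixed pause between attempts;
- the last exception re-raised.

```python
import logging
import time
from typing import Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)


def with_retries(action: Callable[[], T], max_retries: int = 3,
                 delay_seconds: float = 5.0) -> T:
    """Run `action`, retrying on any exception; re-raise after the final attempt."""
    for attempt in range(max_retries):
        try:
            return action()
        except Exception:
            if attempt == max_retries - 1:
                raise
            logger.warning("Attempt %d failed, retrying...", attempt + 1)
            time.sleep(delay_seconds)
    raise ValueError("max_retries must be at least 1")
```

In the notebook, a call might look like `with_retries(lambda: df.write.mode("overwrite").parquet(output_path, compression="snappy"))`.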

In [None]:
def process_table(
    spark: SparkSession,
    table_name: str,
    config: ETLConfig,
    concept_id_map: Dict[str, str],
    visit_occurrence_df: DataFrame
) -> None:
    """Process a patient-related table with enhanced error handling.
    
    Args:
        spark: SparkSession instance
        table_name: Name of the table to process
        config: ETL configuration
        concept_id_map: Mapping of concept IDs to codes
        visit_occurrence_df: Visit occurrence DataFrame for joins
    """
    try:
        logger.info(f"Processing {table_name}...")
        
        # Read table with safety checks
        required_columns = config.patient_tables[table_name]
        df = read_table_safely(spark, table_name, config.base_path, required_columns)
        
        if df is None:
            # Raise so that main() records this table as failed rather than successful
            raise ValueError(f"Skipping {table_name} due to read/validation errors")
            
        # Process based on table type
        if table_name == "person":
            df = process_person_table(df)
        else:
            df = process_clinical_table(df, table_name, visit_occurrence_df, concept_id_map)
        
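        # Keep only the MEDS columns, cast to the MEDS schema, so that every
        # intermediate file can be read back with get_meds_schema()
        meds_schema = get_meds_schema()
        df = df.select(*[F.col(f.name).cast(f.dataType) for f in meds_schema.fields])
        
        # Prepare output path
        output_path = os.path.join(config.unsorted_dir, f"{table_name}.parquet")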
        
        # Write with retry logic
        max_retries = 3
        for attempt in range(max_retries):
            try:
                df.write.mode("overwrite").parquet(output_path, compression='snappy')
                logger.info(f"Successfully wrote {table_name} to {output_path}")
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise
                logger.warning(f"Write attempt {attempt + 1} failed, retrying...")
                time.sleep(5)
        
        logger.info(f"Successfully processed {table_name}")
        
    except Exception as e:
        logger.error(f"Error processing {table_name}: {str(e)}")
        raise

## 9. Data Cleaning and Validation

After initial transformation, this function applies comprehensive cleaning:

### Cleaning Operations:
1. **Person Validation**: Remove events for non-existent patients
2. **Temporal Validation**: Ensure events occur after birth
3. **Datetime Consistency**: Fix datetime values that violate temporal logic
4. **Quality Metrics**: Calculate and log retention rates
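
The quality metric logged at the end of cleaning is a simple ratio. The hypothetical `retention_stats` below reproduces the calculation in `clean_meds_data`. That includes the guard that reports 0% when there were no input events.

```python
from typing import Tuple


def retention_stats(initial_count: int, final_count: int) -> Tuple[int, float]:
    """Return (records_removed, retention_rate_percent) as computed in clean_meds_data."""
    records_removed = initial_count - final_count
    retention_rate = (final_count / initial_count * 100) if initial_count > 0 else 0.0
    return records_removed, retention_rate
```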

### Temporal Validation Rules:
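- Events must occur on or after the patient's birth date; events with a null `time` (e.g. no matching visit) are also dropped by this filter
- `datetime_value` must not be before `time`; earlier values are reset to `time`
- `datetime_value` must not be later than the `time` of the patient's next event (ordered by `time`); later values are reset to the event's own `time`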

These validations ensure the data is suitable for time-series modeling.
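
The rules above can be traced on a single patient's timeline with the plain-Python sketch below. The function `apply_temporal_rules` is hypothetical, and dictionaries stand in for rows. It follows `apply_temporal_validations` step by step:
1. Drop events before birth or with no `time`.
2. Pull `datetime_value` up to `time` when it is earlier.
3. Order by `time`, then reset `datetime_value` to `time` when it is later than the next event's `time`.

One consequence affects events that share a visit. Their `time` is the same, so for all but the last of them a `datetime_value` after the visit start is reset to the visit start. Here ties keep input order; Spark's window ordering over equal `time` values is not guaranteed.

```python
from datetime import datetime
from typing import List


def apply_temporal_rules(events: List[dict], birthdate: datetime) -> List[dict]:
    """Plain-Python mirror of apply_temporal_validations for one patient."""
    # 1. Keep events on or after birth; a missing time fails the comparison
    kept = [dict(e) for e in events if e["time"] is not None and e["time"] >= birthdate]
    # 2. datetime_value may not precede time (None values pass through unchanged)
    for e in kept:
        if e["datetime_value"] is not None and e["datetime_value"] < e["time"]:
            e["datetime_value"] = e["time"]
    # 3. Order by time (stable sort keeps input order for ties), compare to next event
    kept.sort(key=lambda e: e["time"])
    for current, following in zip(kept, kept[1:]):
        if current["datetime_value"] is not None and current["datetime_value"] > following["time"]:
            current["datetime_value"] = current["time"]
    return kept
```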

In [None]:
def prepare_person_data(spark: SparkSession, config: ETLConfig) -> DataFrame:
    """Prepare person data for temporal validation.
    
    Creates a DataFrame with person IDs and birth dates for filtering.
    
    Args:
        spark: SparkSession instance
        config: ETL configuration
        
    Returns:
        DataFrame with person IDs and birth dates
        
    Raises:
        ValueError: If person data cannot be read
    """
    person_df = read_table_safely(
        spark,
        "person",
        config.base_path,
        ["PERSON_ID", "YEAR_OF_BIRTH", "MONTH_OF_BIRTH", "DAY_OF_BIRTH"]
    )
    
    if person_df is None:
        raise ValueError("Failed to read person data")
    
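    birth_df = person_df.withColumn(
        "birthdate",
        F.to_timestamp(
            F.concat_ws(
                '-',
                F.col("YEAR_OF_BIRTH"),
                F.lpad(F.coalesce(F.col("MONTH_OF_BIRTH"), F.lit(1)), 2, '0'),
                F.lpad(F.coalesce(F.col("DAY_OF_BIRTH"), F.lit(1)), 2, '0')
            ),
            "yyyy-MM-dd"
        )
    )
    
    # Keep only the join key and birth date so person columns don't leak into the MEDS output
    return birth_df.select("PERSON_ID", "birthdate")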

def apply_temporal_validations(df: DataFrame) -> DataFrame:
    """Apply temporal validations to MEDS data.
    
    Ensures temporal consistency by:
    1. Filtering events that occur before birth
    2. Correcting datetime_value when it precedes time
    3. Maintaining event ordering within patient timelines
    
    Args:
        df: Input MEDS DataFrame
        
    Returns:
        Temporally validated DataFrame
    """
    # Filter events before birth
    df = df.filter(F.col("time") >= F.col("birthdate"))
    
    # Ensure datetime_value is not before time
    df = df.withColumn(
        "datetime_value",
        F.when(
            F.col("datetime_value") < F.col("time"), 
            F.col("time")
        ).otherwise(F.col("datetime_value"))
    )
    
    # Handle temporal ordering between events
    window_spec = Window.partitionBy("PERSON_ID").orderBy("time")
    df = df.withColumn("next_event_time", F.lead("time").over(window_spec))
    
    # Correct datetime_value if it exceeds next event time
    df = df.withColumn(
        "datetime_value",
        F.when(
            (F.col("next_event_time").isNotNull()) & 
            (F.col("datetime_value") > F.col("next_event_time")),
            F.col("time")
        ).otherwise(F.col("datetime_value"))
    )
    
    # Remove temporary columns
    return df.drop("next_event_time", "birthdate")

def clean_meds_data(spark: SparkSession, config: ETLConfig) -> int:
    """Clean MEDS data with comprehensive validation.
    
    Final cleaning step that:
    1. Loads all transformed data
    2. Joins with person data to filter valid patients
    3. Applies temporal validations
    4. Writes final cleaned dataset
    5. Validates written data
    
    Args:
        spark: SparkSession instance
        config: ETL configuration
        
    Returns:
        int: Final record count
        
    Raises:
        ValueError: If cleaning fails
    """
    try:
        logger.info("Starting MEDS data cleaning...")
        
        # Load unsorted data
        unsorted_path = config.unsorted_dir
        if not verify_path_exists(spark, unsorted_path):
            raise ValueError(f"Unsorted data directory does not exist: {unsorted_path}")
        
        # Read with explicit schema
        schema = get_meds_schema()
        meds_df = spark.read.schema(schema).parquet(unsorted_path)
        
        initial_count = meds_df.count()
        logger.info(f"Initial event count: {initial_count:,}")
        
        # Load person data for validation
        person_df = prepare_person_data(spark, config)
        
        # Join with person to keep only valid patients
        meds_df = (meds_df.join(
            person_df,
            meds_df.PERSON_ID == person_df.PERSON_ID,
            "inner"
        ).drop(person_df.PERSON_ID))
        
        # Apply temporal validations
        meds_df = apply_temporal_validations(meds_df)
        
        # Calculate statistics
        final_count = meds_df.count()
        records_removed = initial_count - final_count
        retention_rate = (final_count / initial_count * 100) if initial_count > 0 else 0
        
        logger.info(f"Records removed: {records_removed:,}")
        logger.info(f"Retention rate: {retention_rate:.2f}%")
        
        # Write cleaned data with retry logic
        max_retries = 3
        for attempt in range(max_retries):
            try:
                meds_df.write.mode("overwrite").parquet(
                    config.cleaned_dir,
                    compression='snappy'
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise
                logger.warning(f"Write attempt {attempt + 1} failed, retrying...")
                time.sleep(5)
        
        logger.info(f"Successfully wrote {final_count:,} records to {config.cleaned_dir}")
        
        # Validate written data
        validation_df = spark.read.schema(schema).parquet(config.cleaned_dir)
        validation_count = validation_df.count()
        
        if validation_count != final_count:
            logger.error(
                f"Validation failed: Written records ({validation_count:,}) != "
                f"Expected records ({final_count:,})"
            )
        else:
            logger.info("Data validation successful")
        
        return final_count
        
    except Exception as e:
        logger.error(f"Error during MEDS data cleaning: {str(e)}")
        raise

## 10. Main Pipeline Execution

This is the main orchestration function that executes the complete ETL pipeline.

### Execution Flow:

1. **Initialization**
   - Load configuration
   - Validate paths
   - Create output directories

2. **Setup Phase**
   - Create concept ID mappings
   - Load and cache visit occurrence data

3. **Transformation Phase**
   - Process each OMOP table sequentially
   - Track successful and failed tables

4. **Cleaning Phase**
   - Consolidate all transformed data
   - Apply validation rules
   - Write final output

5. **Finalization**
   - Generate metadata
   - Log statistics
   - Clean up resources

### Error Handling:
- Individual table failures don't stop the pipeline
- Failed tables are logged for review
- Pipeline requires at least one successful table to proceed
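
The per-table error handling can be summarized by the plain-Python sketch below. `run_tables` is hypothetical, and callables stand in for `process_table`. Each failure is logged and recorded, and the loop continues. The run only aborts when no table succeeded. This accounting depends on each processor raising on failure; a processor that logs an error and returns normally would be counted as a success.

```python
import logging
from typing import Callable, Dict, List, Tuple

logger = logging.getLogger(__name__)


def run_tables(processors: Dict[str, Callable[[], None]]) -> Tuple[List[str], List[str]]:
    """Run each table's processor, isolating failures; return (succeeded, failed)."""
    successful, failed = [], []
    for table_name, process in processors.items():
        try:
            process()
            successful.append(table_name)
        except Exception as exc:
            logger.error("Failed to process %s: %s", table_name, exc)
            failed.append(table_name)
    if not successful:
        raise ValueError("No tables were processed successfully")
    return successful, failed
```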

In [None]:
def main():
    """Main execution function with comprehensive error handling.
    
    Orchestrates the complete OMOP to MEDS ETL pipeline.
    """
    try:
        logger.info("=" * 80)
        logger.info("Starting OMOP to MEDS ETL Pipeline")
        logger.info("=" * 80)
        start_time = time.time()
        
        # Initialize Spark session
        spark = SparkSession.builder.appName("OMOP_to_MEDS_ETL").getOrCreate()
        
        # Initialize and validate config
        config = create_config()
        
        # Verify all required paths exist
        required_paths = [config.base_path, config.source_path]
        for path in required_paths:
            if not verify_path_exists(spark, path):
                raise ValueError(f"Required path does not exist: {path}")
        
        logger.info("Configuration validated successfully")
        
        # Phase 1: Create concept mappings
        logger.info("\n" + "-" * 80)
        logger.info("Phase 1: Creating concept mappings")
        logger.info("-" * 80)
        concept_id_map = create_concept_maps(spark, config)
        
        # Phase 2: Load and cache visit occurrence
        logger.info("\n" + "-" * 80)
        logger.info("Phase 2: Loading visit occurrence data")
        logger.info("-" * 80)
        visit_occurrence_df = read_table_safely(
            spark,
            "visit_occurrence",
            config.base_path,
            ["VISIT_OCCURRENCE_ID", "VISIT_START_DATETIME"]
        )
        
        if visit_occurrence_df is None:
            raise ValueError("Failed to read visit_occurrence table - cannot proceed")
            
        visit_occurrence_df = visit_occurrence_df.cache()
        logger.info(f"Visit occurrence cached: {visit_occurrence_df.count():,} records")
        
        # Phase 3: Process all tables
        logger.info("\n" + "-" * 80)
        logger.info("Phase 3: Processing OMOP tables")
        logger.info("-" * 80)
        
        successful_tables = []
        failed_tables = []
        
        for table_name in config.patient_tables.keys():
            try:
                process_table(
                    spark,
                    table_name,
                    config,
                    concept_id_map,
                    visit_occurrence_df
                )
                successful_tables.append(table_name)
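            except Exception:
                # logger.exception also records the traceback, which would
                # otherwise be lost because the loop continues with the next table
                logger.exception(f"Failed to process {table_name}")
                failed_tables.append(table_name)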
        
        if not successful_tables:
            raise ValueError("No tables were processed successfully")
        
        logger.info(f"\nSuccessfully processed: {len(successful_tables)} tables")
        logger.info(f"Failed to process: {len(failed_tables)} tables")
        
        # Phase 4: Clean and validate data
        logger.info("\n" + "-" * 80)
        logger.info("Phase 4: Cleaning and validating MEDS data")
        logger.info("-" * 80)
        final_count = clean_meds_data(spark, config)
        
        # Phase 5: Generate metadata
        logger.info("\n" + "-" * 80)
        logger.info("Phase 5: Generating metadata")
        logger.info("-" * 80)
        generate_metadata(spark, config)
        
        # Clean up
        visit_occurrence_df.unpersist()
        
        # Final statistics
        end_time = time.time()
        duration = end_time - start_time
        
        logger.info("\n" + "=" * 80)
        logger.info("Pipeline Completed Successfully!")
        logger.info("=" * 80)
        logger.info(f"Duration: {duration:.2f} seconds")
        logger.info(f"Successfully processed tables: {successful_tables}")
        if failed_tables:
            logger.warning(f"Failed tables: {failed_tables}")
        logger.info(f"Final event count: {final_count:,}")
        logger.info(f"Output location: {config.cleaned_dir}")
        logger.info("=" * 80)
        
    except Exception as e:
        logger.error("\n" + "=" * 80)
        logger.error("Pipeline Failed!")
        logger.error("=" * 80)
        logger.error(f"Error: {str(e)}")
        logger.error("=" * 80)
        raise
    finally:
        logger.info("Pipeline execution finished")

## 11. Execute Pipeline

Run the complete ETL pipeline. Monitor the logs for progress and any errors.

In [None]:
if __name__ == "__main__":
    main()

## 12. Verify Output (Optional)

After the pipeline completes, you can verify the output data.
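Beyond the counts and sample rows in the cell below, it can help to compare the written schema against the documented MEDS schema. PySpark's `DataFrame.dtypes` returns `(column, type)` pairs, so a plain-Python comparison is enough. `EXPECTED_MEDS_DTYPES` and `check_meds_dtypes` are hypothetical, and the expected type strings are taken from the Output Format section; if the pipeline actually writes `numeric_value` as a double, change `"float"` to `"double"`.

```python
from typing import Dict, List, Tuple

# Expected column types, taken from the documented MEDS schema.
# Assumption: numeric_value is FloatType ("float"); use "double" for DoubleType.
EXPECTED_MEDS_DTYPES: Dict[str, str] = {
    "PERSON_ID": "string",
    "time": "timestamp",
    "datetime_value": "timestamp",
    "code": "string",
    "numeric_value": "float",
}


def check_meds_dtypes(dtypes: List[Tuple[str, str]]) -> List[str]:
    """Return a list of problems; an empty list means the schema matches."""
    actual = dict(dtypes)
    problems: List[str] = []
    for column, expected in EXPECTED_MEDS_DTYPES.items():
        if column not in actual:
            problems.append(f"missing column: {column}")
        elif actual[column] != expected:
            problems.append(f"{column}: expected {expected}, found {actual[column]}")
    return problems


# With Spark: check_meds_dtypes(meds_df.dtypes)
sample = [("PERSON_ID", "string"), ("time", "timestamp"), ("code", "string"),
          ("numeric_value", "double")]
print(check_meds_dtypes(sample))
# ['missing column: datetime_value', 'numeric_value: expected float, found double']
```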

In [None]:
# Uncomment to verify output

# # Load the cleaned MEDS data
# spark = SparkSession.builder.appName("MEDS_Verification").getOrCreate()
# config = create_config()

# meds_df = spark.read.parquet(config.cleaned_dir)

# # Display basic statistics
# print("\nMEDS Data Statistics:")
# print(f"Total events: {meds_df.count():,}")
# print(f"Unique patients: {meds_df.select('PERSON_ID').distinct().count():,}")
# print(f"Unique codes: {meds_df.select('code').distinct().count():,}")

# # Show sample records
# print("\nSample records:")
# meds_df.show(5, truncate=False)

# # Check schema
# print("\nSchema:")
# meds_df.printSchema()

---

## Notes and Best Practices

### Data Quality Considerations

**What the Pipeline Handles:**
- Temporal alignment with visit start times
- Birth date validation
- Concept ID to code mapping
- Null value handling
- Format detection (Parquet/Delta)
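Concept mapping turns an OMOP `concept_id` into a `VOCABULARY/CODE` string using the OMOP `concept` table's `vocabulary_id` and `concept_code` columns. The plain-Python sketch below shows that lookup and also collects IDs that have no mapping, which is what the review list below asks you to check. The function names and concept IDs are hypothetical; the notebook builds `concept_id_map` in Spark with `create_concept_maps`, whose internals may differ.

```python
from typing import Dict, Iterable, List, Optional, Tuple


def build_concept_map(rows: Iterable[Tuple[int, str, str]]) -> Dict[int, str]:
    """Map concept_id -> "VOCABULARY/CODE" from (concept_id, vocabulary_id, concept_code) rows."""
    return {concept_id: f"{vocab}/{code}" for concept_id, vocab, code in rows}


def map_concept_ids(
    concept_ids: Iterable[int], concept_map: Dict[int, str]
) -> Tuple[List[Optional[str]], List[int]]:
    """Translate concept IDs to codes and collect the IDs that have no mapping."""
    codes: List[Optional[str]] = []
    unmapped: List[int] = []
    for concept_id in concept_ids:
        code = concept_map.get(concept_id)
        if code is None:
            unmapped.append(concept_id)
        codes.append(code)
    return codes, unmapped


# Illustrative concept IDs (not real OMOP IDs); the codes match the example records
concept_map = build_concept_map([(1001, "SNOMED", "38341003"), (1002, "LOINC", "2093-3")])
codes, unmapped = map_concept_ids([1001, 1002, 9999], concept_map)
print(codes)     # ['SNOMED/38341003', 'LOINC/2093-3', None]
print(unmapped)  # [9999]
```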

**What You Should Review:**
- Failed table processing (check logs)
- Data retention rate after cleaning (ideally above 95%; below 90% indicates a problem, see Troubleshooting)
- Missing concept mappings
- Unusual temporal patterns
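Retention rate here means the share of transformed events that survive cleaning. A minimal sketch, assuming the input count is the number of events before cleaning and the output count is the `final_count` returned by `clean_meds_data`. The helper names are hypothetical, and treating the 90 to 95% range as a "review" band is an interpretation of the two thresholds quoted in these notes.

```python
def retention_rate(input_events: int, output_events: int) -> float:
    """Fraction of input events that survive cleaning and validation."""
    if input_events <= 0:
        raise ValueError("input_events must be positive")
    if not 0 <= output_events <= input_events:
        raise ValueError("output_events must be between 0 and input_events")
    return output_events / input_events


def retention_status(rate: float, target: float = 0.95, alarm: float = 0.90) -> str:
    """Classify a retention rate against the thresholds used in these notes."""
    if rate >= target:
        return "ok"
    if rate >= alarm:
        return "review"
    return "investigate"


rate = retention_rate(1_000_000, 930_000)
print(f"{rate:.1%} -> {retention_status(rate)}")  # 93.0% -> review
```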

### Performance Optimization

1. **Caching**: Visit occurrence data is cached since it's used for all clinical tables
2. **Broadcasting**: Concept maps use broadcast joins for efficiency
3. **Compression**: Output uses Snappy compression for a balance between speed and size
4. **Partitioning**: Consider repartitioning large datasets before the final write

### Customization Points

**Adding New Tables:**
1. Add the table name and its required columns to `patient_tables` in `ETLConfig`
2. Update `process_clinical_table` if the table needs special handling
3. Ensure the table's concept ID column follows the expected naming convention
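The naming convention is not spelled out in this section. A minimal sketch, assuming it follows OMOP's `<DOMAIN>_CONCEPT_ID` pattern (e.g. `CONDITION_CONCEPT_ID` in `condition_occurrence`) and the uppercase column names used in `main()`. The `patient_tables` entries and both helpers are hypothetical; `person` is an exception, since it has several concept columns rather than one domain column.

```python
from typing import Dict, List

# Hypothetical excerpt of a `patient_tables` config: table name -> required columns
patient_tables: Dict[str, List[str]] = {
    "condition_occurrence": ["PERSON_ID", "CONDITION_CONCEPT_ID", "CONDITION_START_DATETIME"],
    "device_exposure": ["PERSON_ID", "DEVICE_CONCEPT_ID", "DEVICE_EXPOSURE_START_DATETIME"],
}


def expected_concept_column(table_name: str) -> str:
    """Assumed convention: domain prefix of the table name + _CONCEPT_ID."""
    return table_name.split("_")[0].upper() + "_CONCEPT_ID"


def check_concept_columns(tables: Dict[str, List[str]]) -> List[str]:
    """Return tables whose configured columns lack the conventional concept ID column."""
    return [name for name, cols in tables.items() if expected_concept_column(name) not in cols]


print(expected_concept_column("drug_exposure"))  # DRUG_CONCEPT_ID
print(check_concept_columns(patient_tables))     # []
```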

**Modifying Time Alignment:**
- Current: events linked to a visit are aligned to the visit start time
- Alternative: use the original event datetime (modify `process_clinical_table`)
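The two alignment options can be sketched per event in plain Python. `event_time` and its arguments are hypothetical. In particular, falling back to the original datetime when an event has no matching visit is an assumption; `process_clinical_table` may handle such events differently. In Spark, the same choice roughly corresponds to a left join on `VISIT_OCCURRENCE_ID` followed by `F.coalesce` over the visit start and event datetime columns. The timestamps below mirror the Example Records.

```python
from datetime import datetime
from typing import Dict, Optional


def event_time(
    event_datetime: datetime,
    visit_occurrence_id: Optional[int],
    visit_starts: Dict[int, datetime],
    align_to_visit: bool = True,
) -> datetime:
    """Choose the MEDS `time` for one event.

    With align_to_visit=True, an event linked to a known visit takes the visit
    start; otherwise (or if the visit is unknown) the event datetime is kept.
    """
    if align_to_visit and visit_occurrence_id is not None:
        return visit_starts.get(visit_occurrence_id, event_datetime)
    return event_datetime


visits = {501: datetime(2020, 1, 10, 9, 0)}
lab = datetime(2020, 1, 10, 10, 30)
print(event_time(lab, 501, visits))                        # 2020-01-10 09:00:00
print(event_time(lab, 501, visits, align_to_visit=False))  # 2020-01-10 10:30:00
```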

**Custom Validation Rules:**
- Add them to the `apply_temporal_validations` function
- Consider domain-specific constraints (e.g., age limits)
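Validation rules are easiest to extend when each one is a small, independent check. The plain-Python sketch below expresses the birth-date check and a domain-specific age limit that way. `not_before_birth`, `max_age`, `validate`, and the 120-year limit are hypothetical; inside `apply_temporal_validations` each rule would become a Spark column expression rather than a Python function.

```python
from datetime import datetime
from typing import Callable, List, Optional

# A rule takes (event_time, birth_datetime) and returns an error message or None
Rule = Callable[[datetime, datetime], Optional[str]]


def not_before_birth(event: datetime, birth: datetime) -> Optional[str]:
    return "event precedes birth" if event < birth else None


def max_age(years: int) -> Rule:
    """Build a rule rejecting events at an implausible age (365.25-day years)."""
    def rule(event: datetime, birth: datetime) -> Optional[str]:
        age_years = (event - birth).days / 365.25
        return f"age {age_years:.0f} exceeds {years}" if age_years > years else None
    return rule


def validate(event: datetime, birth: datetime, rules: List[Rule]) -> List[str]:
    """Apply every rule and collect the messages of those that fail."""
    errors: List[str] = []
    for rule in rules:
        message = rule(event, birth)
        if message is not None:
            errors.append(message)
    return errors


rules = [not_before_birth, max_age(120)]
birth = datetime(1980, 6, 15)
print(validate(datetime(2020, 1, 10), birth, rules))  # []
print(validate(datetime(1975, 1, 1), birth, rules))   # ['event precedes birth']
```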

### Troubleshooting

**Common Issues:**

1. **"Table path does not exist"**
   - Verify paths in configuration
   - Check file permissions
   - Ensure the OMOP data is in the expected format (Parquet or Delta Lake)

2. **"Missing required columns"**
   - Verify OMOP schema matches expected structure
   - Update `patient_tables` configuration
   - Check for case sensitivity in column names

3. **Low retention rate (<90%)**
   - Review temporal validation logic
   - Check for invalid birth dates
   - Verify person table completeness

4. **Out of memory errors**
   - Increase Spark executor memory
   - Process tables in smaller batches
   - Add intermediate caching/checkpointing
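For the "Missing required columns" case, a frequent cause is that OMOP exports use lowercase column names while the configuration uses uppercase ones, as in `main()`. Spark resolves column references case-insensitively by default (`spark.sql.caseSensitive` is false), but a Python-side check such as `"PERSON_ID" in df.columns` is case-sensitive. `find_missing_columns` is a hypothetical helper that compares names case-insensitively.

```python
from typing import Iterable, List


def find_missing_columns(required: Iterable[str], actual: Iterable[str]) -> List[str]:
    """Return required columns absent from `actual`, comparing names case-insensitively."""
    present = {name.lower() for name in actual}
    return [name for name in required if name.lower() not in present]


# With Spark: find_missing_columns(required_columns, df.columns)
print(find_missing_columns(
    ["VISIT_OCCURRENCE_ID", "VISIT_START_DATETIME"],
    ["visit_occurrence_id", "person_id", "visit_start_date"],
))  # ['VISIT_START_DATETIME']
```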

### Environment Considerations

**Databricks:**
- Uses `dbutils` for file operations
- DBFS paths prefixed with `dbfs:/`
- Delta Lake format supported natively

**Standard Spark:**
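- Replace `dbutils` calls with equivalent filesystem operations (e.g., `os.path` for local paths, or the Hadoop FileSystem API for HDFS)
- Use HDFS or local file paths
- May require adding the Delta Lake package (`delta-spark`)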
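Outside Databricks, a path check like the one the notebook performs with `verify_path_exists` can fall back to the standard library for local paths. `local_path_exists` is a hypothetical helper: unlike the original it takes no `spark` argument, and it deliberately rejects non-local schemes such as `dbfs:` or `hdfs:`, which need a real filesystem client.

```python
import os
from urllib.parse import urlparse


def local_path_exists(path: str) -> bool:
    """Check a plain or file:// path without dbutils (local filesystems only)."""
    parsed = urlparse(path)
    if parsed.scheme == "":
        return os.path.exists(path)
    if parsed.scheme == "file":
        return os.path.exists(parsed.path)
    # dbfs:/, hdfs://, s3a:// and similar need a filesystem client instead
    raise ValueError(f"Unsupported scheme for a local check: {parsed.scheme!r}")


print(local_path_exists("/nonexistent/omop/person"))  # False
```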

**Local Development:**
- Use small sample datasets
- Set Spark to local mode
- Consider using Parquet instead of Delta

---

## Output Format

### MEDS Schema

```
PERSON_ID: string - Patient identifier
time: timestamp - Event time (aligned to visit start)
datetime_value: timestamp - Original event datetime
code: string - Medical code (format: VOCABULARY/CODE)
numeric_value: float - Numeric measurement value (null for non-measurements)
```

### Example Records

```
PERSON_ID | time                | datetime_value      | code             | numeric_value
----------|---------------------|---------------------|------------------|--------------
12345     | 1980-06-15 00:00:00 | 1980-06-15 08:30:00 | SNOMED/184099003 | null
12345     | 2020-01-10 09:00:00 | 2020-01-10 09:15:00 | SNOMED/38341003  | null
12345     | 2020-01-10 09:00:00 | 2020-01-10 10:30:00 | LOINC/2093-3     | 7.2
```

---

## Citation

If you use this ETL pipeline in your research, please cite:

```
[Your Name]. (2025). OMOP to MEDS ETL Pipeline.
GitHub: [Your Repository URL]
```

---

## License

Specify your license here (e.g., MIT or Apache 2.0).

---