# OMOP to MEDS ETL Process

## Overview

This notebook documents the ETL (Extract, Transform, Load) process used to convert data from the OMOP Common Data Model (CDM) to the Medical Event Data Standard (MEDS) format.

### What Was Done

The ETL process consisted of four main phases:

1. **Extraction**: Load cleaned OMOP tables and concept mappings
2. **Transformation**: Convert OMOP structure to MEDS time-series format
3. **Temporal Alignment**: Align event times to visit start times
4. **Validation and Cleaning**: Apply temporal consistency checks

### MEDS Format

MEDS represents medical data as time-series events with:
- `PERSON_ID`: Patient identifier
- `time`: Event timestamp (visit start for clinical events, birth date for birth events)
- `code`: Medical code in `VOCABULARY/CODE` format
- `numeric_value`: Numeric result taken from the measurement table; null for all other events
- `datetime_value`: Original event datetime
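
To make the row layout concrete, here is a minimal plain-Python sketch of a single event. The `MedsEvent` name and the sample values are illustrative assumptions; the pipeline itself keeps these rows in a Spark DataFrame.

```python
from datetime import datetime
from typing import NamedTuple, Optional


class MedsEvent(NamedTuple):
    """Hypothetical in-memory mirror of one MEDS row (illustration only)."""
    PERSON_ID: str
    time: datetime                      # visit start for clinical events
    datetime_value: Optional[datetime]  # original event datetime
    code: str                           # VOCABULARY/CODE
    numeric_value: Optional[float]      # only set for measurements


# A made-up measurement recorded two hours into a visit.
event = MedsEvent(
    PERSON_ID="1001",
    time=datetime(2020, 3, 1, 8, 0),
    datetime_value=datetime(2020, 3, 1, 10, 0),
    code="LOINC/2345-7",
    numeric_value=98.0,
)

vocabulary, concept_code = event.code.split("/", 1)
print(vocabulary, concept_code)
```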

---

## 1. Import Libraries

In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, FloatType
import logging
from typing import Dict, List

# Initialize Spark
spark = SparkSession.builder.appName("OMOP_to_MEDS_ETL").getOrCreate()

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

## 2. Define MEDS Schema

The MEDS schema standardizes how medical events are represented.

In [None]:
def get_meds_schema():
    """Define the MEDS data schema."""
    return StructType([
        StructField("PERSON_ID", StringType(), True),
        StructField("time", TimestampType(), True),
        StructField("datetime_value", TimestampType(), True),
        StructField("code", StringType(), True),
        StructField("numeric_value", FloatType(), True)
    ])

MEDS_SCHEMA = get_meds_schema()
print("MEDS schema defined")
print(MEDS_SCHEMA)

## 3. Configuration

Define which OMOP tables to process and the columns required from each.

In [None]:
# Paths (MODIFY THESE)
CLEANED_OMOP_PATH = "/path/to/your/cleaned/omop/data/"
CONCEPT_PATH = "/path/to/your/raw/omop/data/concept"
OUTPUT_PATH = "/path/to/your/meds/output/"

# Tables and required columns
OMOP_TABLES = {
    "person": [
        "PERSON_ID", "YEAR_OF_BIRTH", "MONTH_OF_BIRTH", 
        "DAY_OF_BIRTH", "BIRTH_DATETIME"
    ],
    "visit_occurrence": [
        "PERSON_ID", "VISIT_CONCEPT_ID", "VISIT_OCCURRENCE_ID", 
        "VISIT_START_DATETIME"
    ],
    "condition_occurrence": [
        "PERSON_ID", "CONDITION_CONCEPT_ID", 
        "CONDITION_START_DATETIME", "VISIT_OCCURRENCE_ID"
    ],
    "drug_exposure": [
        "PERSON_ID", "DRUG_CONCEPT_ID", 
        "DRUG_EXPOSURE_START_DATETIME", "VISIT_OCCURRENCE_ID"
    ],
    "procedure_occurrence": [
        "PERSON_ID", "PROCEDURE_CONCEPT_ID", 
        "PROCEDURE_DATETIME", "VISIT_OCCURRENCE_ID"
    ],
    "measurement": [
        "PERSON_ID", "MEASUREMENT_CONCEPT_ID", "MEASUREMENT_DATETIME", 
        "VALUE_AS_NUMBER", "VISIT_OCCURRENCE_ID"
    ]
}

print(f"Configured to process {len(OMOP_TABLES)} tables")
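
The loading step concatenates `CLEANED_OMOP_PATH` and the table name directly, so the base paths above need a trailing slash. As a hedge against a missing slash, a small helper could normalize the join. The `table_path` name and the sample paths are hypothetical and not part of the original notebook.

```python
import posixpath


def table_path(base: str, table_name: str) -> str:
    """Join a base directory and a table name with exactly one slash."""
    return posixpath.join(base, table_name)


# Both spellings of the base path resolve to the same table location.
without_slash = table_path("/data/omop", "person")
with_slash = table_path("/data/omop/", "person")
print(without_slash)
print(with_slash)
```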

## 4. Extract: Load OMOP Data

Load cleaned OMOP tables and the concept vocabulary for code mapping.

In [None]:
def load_omop_tables() -> Dict[str, DataFrame]:
    """Load OMOP tables from Parquet files."""
    tables = {}
    
    for table_name, columns in OMOP_TABLES.items():
        try:
            df = spark.read.parquet(f"{CLEANED_OMOP_PATH}{table_name}")
            tables[table_name] = df.select(*columns)
            logger.info(f"Loaded {table_name}: {df.count():,} records")
        except Exception as e:
            logger.error(f"Error loading {table_name}: {str(e)}")
    
    return tables

print("Loading OMOP tables...")
omop_tables = load_omop_tables()
print(f"\nLoaded {len(omop_tables)} tables")

## 5. Create Concept Mappings

Map OMOP concept IDs to VOCABULARY/CODE format used in MEDS.

Example: Concept ID 313217 (Atrial fibrillation) → "SNOMED/49436004"

In [None]:
def create_concept_map() -> Dict[str, str]:
    """Create mapping from concept IDs to vocabulary codes."""
    concept_df = spark.read.parquet(CONCEPT_PATH)
    
    # Select required columns
    concept_df = concept_df.select(
        "CONCEPT_ID", 
        "VOCABULARY_ID", 
        "CONCEPT_CODE"
    )
    
    # Create code in format: VOCABULARY/CODE
    concept_df = concept_df.withColumn(
        "code",
        F.concat_ws("/", F.col("VOCABULARY_ID"), F.col("CONCEPT_CODE"))
    )
    
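    # Collect to a driver-side dictionary (collectAsMap already returns a dict).
    # The whole concept table must fit in driver memory.
    concept_map = (
        concept_df.select("CONCEPT_ID", "code")
                  .rdd
                  .collectAsMap()
    )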
    
    logger.info(f"Created concept map with {len(concept_map):,} entries")
    return concept_map

print("Creating concept mappings...")
concept_map = create_concept_map()
print(f"Concept map created with {len(concept_map):,} mappings")
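
The per-row string building can be sketched in plain Python. `F.concat_ws` skips null inputs, so a concept with a missing `VOCABULARY_ID` ends up with a bare code and no prefix. The function names and the sample rows below are made up for illustration.

```python
from typing import Dict, Iterable, Optional, Tuple

# (concept_id, vocabulary_id, concept_code)
ConceptRow = Tuple[int, Optional[str], Optional[str]]


def to_meds_code(vocabulary_id: Optional[str], concept_code: Optional[str]) -> str:
    """Mimic concat_ws("/"): join only the non-null parts with a slash."""
    parts = [p for p in (vocabulary_id, concept_code) if p is not None]
    return "/".join(parts)


def build_concept_map(rows: Iterable[ConceptRow]) -> Dict[int, str]:
    """Build the concept_id -> code lookup from concept rows."""
    return {cid: to_meds_code(vocab, code) for cid, vocab, code in rows}


rows = [
    (101, "ICD10CM", "I10"),  # made-up concept ID
    (102, None, "X1"),        # made-up row with a missing vocabulary
]
concept_map_sketch = build_concept_map(rows)
print(concept_map_sketch)
```

A lookup that silently loses its vocabulary prefix is hard to spot downstream, so it may be worth counting codes without a `/` after building the map.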

## 6. Transform: Person Table

Convert person demographics to MEDS birth events.

### Transformation Logic:
- Construct birth date from YEAR/MONTH/DAY components (a missing month or day defaults to 1)
- Assign birth event code: SNOMED/184099003 (Date of birth)
- Set time = birth date
- Set datetime_value = BIRTH_DATETIME if available

In [None]:
def transform_person(person_df: DataFrame) -> DataFrame:
    """Transform person table to MEDS birth events."""
    # Filter valid records
    df = person_df.filter(F.col("YEAR_OF_BIRTH").isNotNull())
    
    # Create birth datetime from components
    df = df.withColumn(
        "time",
        F.to_timestamp(
            F.concat_ws(
                '-',
                F.col("YEAR_OF_BIRTH"),
                F.lpad(F.coalesce(F.col("MONTH_OF_BIRTH"), F.lit(1)), 2, '0'),
                F.lpad(F.coalesce(F.col("DAY_OF_BIRTH"), F.lit(1)), 2, '0')
            ),
            "yyyy-MM-dd"
        )
    )
    
    # Set datetime_value
    df = df.withColumn("datetime_value", F.col("BIRTH_DATETIME").cast("timestamp"))
    
    # Birth event code (SNOMED 184099003: Date of birth)
    df = df.withColumn("code", F.lit("SNOMED/184099003"))
    
    # No numeric value for birth
    df = df.withColumn("numeric_value", F.lit(None).cast("float"))
    
    # Select MEDS columns
    return df.select("PERSON_ID", "time", "datetime_value", "code", "numeric_value")

print("Transforming person table...")
person_meds = transform_person(omop_tables['person'])
print(f"Transformed person: {person_meds.count():,} birth events")
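
The date construction above can be checked on single values with a plain-Python sketch. `birth_timestamp` is a hypothetical helper, and returning `None` for an impossible date is an assumption: how Spark handles such a date depends on its configuration.

```python
from datetime import datetime
from typing import Optional


def birth_timestamp(
    year: Optional[int], month: Optional[int], day: Optional[int]
) -> Optional[datetime]:
    """Build a birth datetime, defaulting a missing month or day to 1."""
    if year is None:
        return None  # the Spark code filters these rows out beforehand
    month = 1 if month is None else month
    day = 1 if day is None else day
    try:
        return datetime(year, month, day)
    except ValueError:
        return None  # impossible dates such as 30 February


year_only = birth_timestamp(1980, None, None)
missing_day = birth_timestamp(1975, 7, None)
impossible = birth_timestamp(1980, 2, 30)
print(year_only, missing_day, impossible)
```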

## 7. Transform: Clinical Event Tables

Convert clinical events (conditions, drugs, procedures, measurements) to MEDS format.

### Key Transformations:

1. **Time Alignment**: Join with visit_occurrence to get visit start time
   - `time` = visit start datetime (for temporal alignment)
   - `datetime_value` = original event datetime (for reference)

2. **Code Mapping**: Map concept IDs to vocabulary codes
   - Look up concept ID in concept_map
   - Convert to VOCABULARY/CODE format

3. **Numeric Values**: Extract measurement values
   - For measurement table: use VALUE_AS_NUMBER
   - For other tables: set to null

In [None]:
def transform_clinical_table(
    df: DataFrame,
    table_name: str,
    visit_df: DataFrame,
    concept_map_dict: Dict[str, str]
) -> DataFrame:
    """Transform clinical event table to MEDS format."""
    
    # Join with visit to get visit start time
    if "VISIT_OCCURRENCE_ID" in df.columns:
        visit_times = visit_df.select(
            "VISIT_OCCURRENCE_ID",
            F.col("VISIT_START_DATETIME").alias("visit_start_time")
        )
        df = df.join(visit_times, on="VISIT_OCCURRENCE_ID", how="left")
        df = df.withColumn("time", F.col("visit_start_time").cast("timestamp"))
    else:
        df = df.withColumn("time", F.lit(None).cast("timestamp"))
    
    # Get original datetime
    datetime_cols = [col for col in df.columns if "_DATETIME" in col]
    if datetime_cols:
        datetime_value = F.col(datetime_cols[0]).cast("timestamp")
        for col_name in datetime_cols[1:]:
            datetime_value = F.coalesce(datetime_value, F.col(col_name).cast("timestamp"))
        df = df.withColumn("datetime_value", datetime_value)
    else:
        df = df.withColumn("datetime_value", F.lit(None).cast("timestamp"))
    
    # Map concept IDs to codes (each configured clinical table has exactly one *_CONCEPT_ID column)
    concept_col = next(col for col in df.columns if col.endswith("_CONCEPT_ID"))
    
    concept_map_df = spark.createDataFrame(
        [(k, v) for k, v in concept_map_dict.items()],
        ["concept_id", "code"]
    )
    
    df = df.join(
        F.broadcast(concept_map_df),
        df[concept_col] == concept_map_df["concept_id"],
        "left"
    )
    
    # Handle numeric values
    if table_name == "measurement":
        df = df.withColumn("numeric_value", F.col("VALUE_AS_NUMBER").cast("float"))
    else:
        df = df.withColumn("numeric_value", F.lit(None).cast("float"))
    
    # Select MEDS columns
    return df.select("PERSON_ID", "time", "datetime_value", "code", "numeric_value")

print("\nTransforming clinical event tables...")
print("=" * 80)

clinical_meds = {}
clinical_tables = [
    'visit_occurrence', 
    'condition_occurrence', 
    'drug_exposure',
    'procedure_occurrence', 
    'measurement'
]

for table_name in clinical_tables:
    if table_name in omop_tables:
        # Handle visit_occurrence separately (no visit join needed)
        if table_name == 'visit_occurrence':
            df = omop_tables[table_name]
            df = df.withColumn("time", F.col("VISIT_START_DATETIME").cast("timestamp"))
            df = df.withColumn("datetime_value", F.col("VISIT_START_DATETIME").cast("timestamp"))
            
            # Map visit concept
            concept_map_df = spark.createDataFrame(
                [(k, v) for k, v in concept_map.items()],
                ["concept_id", "code"]
            )
            df = df.join(
                F.broadcast(concept_map_df),
                df["VISIT_CONCEPT_ID"] == concept_map_df["concept_id"],
                "left"
            )
            df = df.withColumn("numeric_value", F.lit(None).cast("float"))
            clinical_meds[table_name] = df.select("PERSON_ID", "time", "datetime_value", "code", "numeric_value")
        else:
            clinical_meds[table_name] = transform_clinical_table(
                omop_tables[table_name],
                table_name,
                omop_tables['visit_occurrence'],
                concept_map
            )
        
        count = clinical_meds[table_name].count()
        logger.info(f"Transformed {table_name}: {count:,} events")

print(f"\nTransformed {len(clinical_meds)} clinical tables")
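
The time alignment can be illustrated for a single event in plain Python. This is a sketch of the left-join semantics only: `align_to_visit` and the sample visits are hypothetical, and an event with no matching visit gets a null `time`, as in the Spark join.

```python
from datetime import datetime
from typing import Dict, Optional, Tuple


def align_to_visit(
    event_datetime: Optional[datetime],
    visit_id: Optional[int],
    visit_starts: Dict[int, datetime],
) -> Tuple[Optional[datetime], Optional[datetime]]:
    """Return (time, datetime_value) for one event.

    time is the start of the linked visit, or None when no visit matches.
    datetime_value keeps the original event datetime unchanged.
    """
    time = visit_starts.get(visit_id) if visit_id is not None else None
    return time, event_datetime


visit_starts = {10: datetime(2021, 5, 1, 9, 0)}  # made-up visit table
aligned = align_to_visit(datetime(2021, 5, 1, 14, 30), 10, visit_starts)
orphan = align_to_visit(datetime(2021, 6, 2, 8, 0), None, visit_starts)
print(aligned)
print(orphan)
```

Events like `orphan` carry a null `time`, which matters later because comparisons against a null value never evaluate to true in a Spark filter.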

## 8. Combine All Events

Union all transformed tables into a single MEDS dataset.

In [None]:
print("\nCombining all events into MEDS format...")

# Collect all DataFrames
all_events = [person_meds] + list(clinical_meds.values())

# Union all events, matching columns by name rather than by position
from functools import reduce
meds_df = reduce(DataFrame.unionByName, all_events)

initial_count = meds_df.count()
logger.info(f"Combined dataset: {initial_count:,} total events")

print(f"\nTotal events before cleaning: {initial_count:,}")

## 9. Temporal Validation and Cleaning

Apply temporal consistency checks to the MEDS data.

### Validation Rules Applied:

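1. **Birth Date Validation**: Remove events that occur before the patient's birth date. Events with a null `time` (for example, events with no matching visit) fail the same comparison and are removed as well
2. **Datetime Consistency**: Set `datetime_value` to `time` when it is earlier than `time`
3. **Temporal Ordering**: Reset `datetime_value` to `time` when it is later than the patient's next event time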

These rules ensure the data maintains logical temporal consistency for time-series modeling.

In [None]:
def prepare_person_birthdate() -> DataFrame:
    """Prepare person birthdates for validation."""
    person_df = omop_tables['person']
    
    return person_df.withColumn(
        "birthdate",
        F.to_timestamp(
            F.concat_ws(
                '-',
                F.col("YEAR_OF_BIRTH"),
                F.lpad(F.coalesce(F.col("MONTH_OF_BIRTH"), F.lit(1)), 2, '0'),
                F.lpad(F.coalesce(F.col("DAY_OF_BIRTH"), F.lit(1)), 2, '0')
            ),
            "yyyy-MM-dd"
        )
    ).select("PERSON_ID", "birthdate")

def apply_temporal_validations(df: DataFrame, person_birthdate_df: DataFrame) -> DataFrame:
    """Apply temporal consistency checks."""
    # Join with birthdates
    df = df.join(person_birthdate_df, on="PERSON_ID", how="inner")
    
    # Rule 1: Keep only events on or after birth. Rows with a null time or
    # birthdate also fail this comparison and are dropped.
    before_filter = df.count()
    df = df.filter(F.col("time") >= F.col("birthdate"))
    after_filter = df.count()
    logger.info(f"Removed {before_filter - after_filter:,} events before birth or without a time")
    
    # Rule 2: Ensure datetime_value is not before time
    df = df.withColumn(
        "datetime_value",
        F.when(
            F.col("datetime_value") < F.col("time"),
            F.col("time")
        ).otherwise(F.col("datetime_value"))
    )
    
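    # Rule 3: Handle temporal ordering between events.
    # Events of one visit share the same time, so their relative order within
    # the window (and therefore next_event_time) is not deterministic.
    window_spec = Window.partitionBy("PERSON_ID").orderBy("time")
    df = df.withColumn("next_event_time", F.lead("time").over(window_spec))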
    
    df = df.withColumn(
        "datetime_value",
        F.when(
            (F.col("next_event_time").isNotNull()) & 
            (F.col("datetime_value") > F.col("next_event_time")),
            F.col("time")
        ).otherwise(F.col("datetime_value"))
    )
    
    # Remove temporary columns
    return df.drop("next_event_time", "birthdate")

print("\nApplying temporal validations...")
print("=" * 80)

person_birthdate = prepare_person_birthdate()
meds_df = apply_temporal_validations(meds_df, person_birthdate)

final_count = meds_df.count()
records_removed = initial_count - final_count
retention_rate = (final_count / initial_count * 100) if initial_count > 0 else 0

print(f"\nTemporal Validation Results:")
print(f"  Events before: {initial_count:,}")
print(f"  Events after: {final_count:,}")
print(f"  Records removed: {records_removed:,}")
print(f"  Retention rate: {retention_rate:.2f}%")
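
The three rules can be traced on a handful of events with a plain-Python sketch. It mirrors the Spark logic for one patient under stated assumptions: `validate_patient_events` and the sample events are hypothetical, and Python's stable sort fixes the order of ties, whereas Spark gives no such guarantee.

```python
from datetime import datetime
from typing import List, Optional, Tuple

Event = Tuple[Optional[datetime], Optional[datetime]]  # (time, datetime_value)


def validate_patient_events(events: List[Event], birthdate: datetime) -> List[Event]:
    """Apply the three temporal rules to one patient's events."""
    # Rule 1: keep events on or after birth; a missing time is dropped too.
    kept = [(t, dv) for t, dv in events if t is not None and t >= birthdate]

    # Rule 2: datetime_value may not precede time.
    kept = [(t, t if dv is not None and dv < t else dv) for t, dv in kept]

    # Rule 3: order by time, then reset datetime_value to time when it is
    # later than the next event's time.
    kept.sort(key=lambda e: e[0])
    result = []
    for i, (t, dv) in enumerate(kept):
        next_time = kept[i + 1][0] if i + 1 < len(kept) else None
        if next_time is not None and dv is not None and dv > next_time:
            dv = t
        result.append((t, dv))
    return result


birth = datetime(1990, 1, 1)
events = [
    (datetime(1989, 12, 31), datetime(1989, 12, 31)),     # before birth
    (None, datetime(2020, 1, 1)),                         # no visit time
    (datetime(2020, 1, 1, 8), datetime(2020, 1, 1, 7)),   # value before time
    (datetime(2020, 1, 5, 8), datetime(2020, 1, 5, 9)),   # already consistent
    (datetime(2020, 1, 2, 8), datetime(2020, 1, 6, 8)),   # value after next event
]
cleaned = validate_patient_events(events, birth)
for row in cleaned:
    print(row)
```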

## 10. Export MEDS Data

Save the final MEDS dataset to Parquet format.

In [None]:
print("\nExporting MEDS data...")

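# Cast each column to its MEDS_SCHEMA type so the written file matches the
# schema used when reading it back (PERSON_ID is typically an integer in OMOP).
meds_df = meds_df.select(
    [F.col(field.name).cast(field.dataType) for field in MEDS_SCHEMA.fields]
)

# Write to Parquet
output_file = f"{OUTPUT_PATH}meds_data"
meds_df.write.mode("overwrite").parquet(output_file, compression='snappy')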

print(f"\n✓ MEDS data exported to: {output_file}")
print(f"  Total events: {final_count:,}")
print(f"  Unique patients: {meds_df.select('PERSON_ID').distinct().count():,}")

## 11. Verify Output

Load and inspect a sample of the MEDS data.

In [None]:
print("\nVerifying MEDS output...")

# Load back the data
verification_df = spark.read.schema(MEDS_SCHEMA).parquet(output_file)

print("\nMEDS Schema:")
verification_df.printSchema()

print("\nSample Records:")
verification_df.show(10, truncate=False)

print("\nData Statistics:")
print(f"  Total events: {verification_df.count():,}")
print(f"  Unique patients: {verification_df.select('PERSON_ID').distinct().count():,}")
print(f"  Unique codes: {verification_df.select('code').distinct().count():,}")
# Compute both ends of the time range in a single aggregation
time_range = verification_df.agg(F.min('time'), F.max('time')).collect()[0]
print(f"  Date range: {time_range[0]} to {time_range[1]}")

---

## Summary of ETL Process

### Transformation Logic

**Person Table → Birth Events**
- Constructed birth date from year/month/day components
- Assigned standardized birth code (SNOMED/184099003)
- Set time = birth date

**Clinical Tables → Medical Events**
- Joined with visit_occurrence for temporal alignment
- Set time = visit start datetime
- Preserved original event datetime in datetime_value, unless a validation rule reset it to time
- Mapped concept IDs to VOCABULARY/CODE format
- Extracted numeric values for measurements

### Temporal Alignment Rationale

All clinical events were aligned to their associated visit's start time. This approach:
- Creates consistent temporal ordering
- Simplifies sequential modeling
- Reduces temporal granularity for privacy
- Preserves original timing in the datetime_value field, except where the validation rules reset it

### Validation Rules

Three temporal consistency rules were applied:

1. **Events must occur on or after birth**: Removed events before the patient's birth date, along with events that have no time
2. **Datetime consistency**: Set datetime_value to time when it was earlier than time
3. **Event ordering**: Reset datetime_value to time when it exceeded the next event time

These rules ensure temporal consistency for time-series analysis.

### Output Format

The MEDS dataset contains:
- Time-series events for each patient
- Standardized medical codes (VOCABULARY/CODE format)
- Temporal alignment to visit times
- Numeric measurements preserved
- Valid temporal ordering

### Key Decisions

**What Was Transformed:**
- OMOP table structure → MEDS time-series format
- Concept IDs → Vocabulary codes
- Event datetimes → Visit-aligned times

**What Was Preserved:**
- Original event datetimes (in datetime_value)
- Numeric measurement values
- All concept mappings
- Patient identifiers

**What Was Removed:**
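- Events before patient birth
- Events with no time (for example, no matching visit)
- Events whose PERSON_ID does not appear in the person table

Events with inconsistent datetime values were corrected rather than removed.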

---