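# Define Fever Cohort

This notebook builds a cohort of All of Us inpatient/ER admissions with fever early in the stay:

1. Restrict to participants with medication, condition-code (ICD9CM/ICD10CM/SNOMED) and LOINC lab records.
2. Construct macrovisits from `visit_occurrence` with the N3C merging-intervals algorithm, and keep those anchored by an inpatient or ER visit.
3. Unify temperature measurements to a standard unit with `omop_unifier`.
4. Flag admissions with ≥2 temperature measurements >38 °C in the first 24 hours after the macrovisit start date.
5. Keep admissions with ≥45 days of potential follow-up before 2023-10-01 and with either a later healthcare visit or a recorded death.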

## Setup

### Bucket and CDR

In [9]:
# %pip install polars

In [5]:
# %pip install google-cloud-storage

In [None]:
from google.cloud import bigquery, storage
import pandas as pd
import polars as pl
import os
import subprocess
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import psutil
from concurrent.futures import ThreadPoolExecutor, as_completed

In [6]:
from omop_unifier import Explorer, Mapper, Unifier

In [9]:
from aou_helpers import load_aou_env

# load_aou_env() reads the workspace configuration and also populates os.environ.
env = load_aou_env()

# Workspace locations used throughout the notebook (KeyError if any is missing).
dataset, bucket, temp_bucket, gcproject = (
    os.environ[key]
    for key in ('WORKSPACE_CDR', 'WORKSPACE_BUCKET', 'WORKSPACE_TEMP_BUCKET', 'GOOGLE_CLOUD_PROJECT')
)

# Echo the values so it is obvious which CDR / buckets this run is pointed at.
print("Environment variables loaded successfully:")
for label, value in (("dataset", dataset), ("bucket", bucket), ("temp_bucket", temp_bucket)):
    print(f"  {label}: {value}")

Environment variables loaded successfully:
  dataset: wb-silky-artichoke-2408.C2024Q3R8
  bucket: gs://workspace-bucket-wb-jaunty-date-2788
  temp_bucket: gs://temporary-workspace-bucket-wb-jaunty-date-2788


In [10]:
# This line allows for the plots to be displayed inline in the Jupyter notebook
%matplotlib inline

sns.set(style="ticks",font_scale=1)

In [11]:
# --- polars display options ---
pl.Config.set_fmt_str_lengths(128)   # show up to 128 characters of each string cell
pl.Config.set_tbl_rows(50)           # print up to 50 rows per table

# --- pandas display options ---
pd.set_option("display.max_columns", None)   # never elide columns
pd.set_option("display.max_colwidth", 100)   # truncate cell text at 100 characters

## Helper Functions

In [27]:
def print_resource_usage():
    """Print a snapshot of system memory and CPU utilisation.

    CPU utilisation is sampled over a 1-second interval, so the call blocks
    for about one second. Memory is reported in decimal gigabytes (1e9 bytes).
    """
    vm = psutil.virtual_memory()
    cpu_pct = psutil.cpu_percent(interval=1)
    used_gb = vm.used / 1e9
    total_gb = vm.total / 1e9

    print("\n")
    print("=" * 50)
    print(
        f"Memory: {used_gb:.1f}GB / {total_gb:.1f}GB ({vm.percent:.1f}% used.) "
        f"CPU: {cpu_pct:.1f}% used across {os.cpu_count()} cores."
    )

In [13]:
def polars_gbq(query):
    """Run a BigQuery SQL statement and return the result as a polars DataFrame.

    A new BigQuery client is created on every call, so concurrent callers
    (e.g. worker threads) do not share a client object. Results are pulled
    as an Arrow table and converted directly to polars.
    """
    bq = bigquery.Client()
    arrow_table = bq.query(query).result().to_arrow()
    return pl.from_arrow(arrow_table)

## Get Total Cohort

In [16]:
# Measurement concepts recorded for every All of Us participant at program registration
# (list carried over from the sarcoid project). They do not reflect clinical care, so they
# are excluded when deciding whether a participant has lab data.
#   Concepts: heart rate rhythm, heart rate, blood pressure panel, adult waist circumference
#   protocol, PhenX hip circumference protocol, body height, body weight, diastolic BP,
#   systolic BP, BMI, pre-pregnancy weight.
VITALS_EXCLUSION = [
    3022318,
    3027018,
    3031203,
    40759207,
    40765148,
    3036277,
    3025315,
    3012888,
    3004249,
    3038553,
    3022281,
]

In [28]:
# Count Participants with Relevant EHR Data
# Requirement: at least one record in EACH of the three modalities queried below:
#   - medications      (drug_exposure rows whose concept is in the Drug domain)
#   - condition codes  (ICD9CM / ICD10CM / SNOMED codes in observation or condition_occurrence)
#   - labs             (LOINC measurements, excluding the registration vitals in VITALS_EXCLUSION)
# NOTE(review): no vitals query is run here, so vitals are NOT part of the requirement.
# Add a vitals query and join it below if that was intended.

# Participants with medications
meds_query = f"""
SELECT DISTINCT person_id
FROM `{dataset}`.drug_exposure AS de
JOIN `{dataset}`.concept AS c
    ON de.drug_concept_id = c.concept_id
WHERE c.domain_id = 'Drug'
"""
meds_persons = polars_gbq(meds_query)
print(f"Participants with medications: {len(meds_persons):,}")

# Participants with ICD codes (SNOMED-coded conditions are counted too)
# Checks both observation and condition_occurrence tables
# Checks both source_value and source_concept_id paths
icd_query = f"""
SELECT DISTINCT person_id FROM (
  -- Codes from observation (source_value)
  SELECT o.person_id
  FROM `{dataset}`.observation AS o
  JOIN `{dataset}`.concept AS c ON o.observation_source_value = c.concept_code
  WHERE c.vocabulary_id IN ('ICD9CM', 'ICD10CM', 'SNOMED')

  UNION ALL

  -- Codes from observation (source_concept_id)
  SELECT o.person_id
  FROM `{dataset}`.observation AS o
  JOIN `{dataset}`.concept AS c ON o.observation_source_concept_id = c.concept_id
  WHERE c.vocabulary_id IN ('ICD9CM', 'ICD10CM', 'SNOMED')

  UNION ALL

  -- Codes from condition_occurrence (source_value)
  SELECT co.person_id
  FROM `{dataset}`.condition_occurrence AS co
  JOIN `{dataset}`.concept AS c ON co.condition_source_value = c.concept_code
  WHERE c.vocabulary_id IN ('ICD9CM', 'ICD10CM', 'SNOMED')

  UNION ALL

  -- Codes from condition_occurrence (source_concept_id)
  SELECT co.person_id
  FROM `{dataset}`.condition_occurrence AS co
  JOIN `{dataset}`.concept AS c ON co.condition_source_concept_id = c.concept_id
  WHERE c.vocabulary_id IN ('ICD9CM', 'ICD10CM', 'SNOMED')
)
"""
icd_persons = polars_gbq(icd_query)
print(f"Participants with ICD codes: {len(icd_persons):,}")

# Participants with lab measurements (LOINC codes excluding registration vitals)
vitals_exclusion_str = ', '.join(map(str, VITALS_EXCLUSION))
labs_query = f"""
SELECT DISTINCT m.person_id
FROM `{dataset}`.measurement AS m
JOIN `{dataset}`.concept AS c ON m.measurement_concept_id = c.concept_id
WHERE c.vocabulary_id = 'LOINC'
  AND m.measurement_concept_id NOT IN ({vitals_exclusion_str})
"""
labs_persons = polars_gbq(labs_query)
print(f"Participants with lab measurements: {len(labs_persons):,}")
print()

# Intersection: participants present in all three person_id sets
all_modalities = meds_persons.join(icd_persons, on='person_id', how='inner')
all_modalities = all_modalities.join(labs_persons, on='person_id', how='inner')

n_with_all_ehr = len(all_modalities)

def format_count(n):
    """Format a count for display, applying AoU small-cell suppression (<20)."""
    return '<20' if n < 20 else f"{n:,}"

print(f"N = {format_count(n_with_all_ehr)} participants with data from ALL modalities")
print("  (medications AND ICD/SNOMED codes AND labs)")

print_resource_usage()

Participants with medications: 334,643
Participants with ICD codes: 356,363
Participants with lab measurements: 348,126

N = 315,945 participants with data from ALL modalities
  (vitals AND medications AND ICD codes AND labs)


Memory: 3.2GB / 135.2GB (2.4% used.) CPU: 0.0% used across 32 cores.


## Step 3: Construct Macrovisits Using N3C Algorithm

In [29]:
# Visit concept IDs for macrovisit construction
# IP, Inpatient Hospital, ER+Inpatient, ICU, Inpatient Critical Care
INPATIENT_CONCEPTS = [9201, 8717, 262, 32037, 581379]
OUTPATIENT_CONCEPT = 9202
ER_CONCEPT = 9203

# Per visit concept: number of persons, number of visits, and the highest-priority root it
# descends from (IP > OP > ER, otherwise 'Other').
# The label is computed ONCE per concept in the visit_labels CTE and then LEFT JOINed, so
# each visit row is counted exactly once. Joining concept_ancestor directly with an IN-list
# fans rows out whenever a concept has several matching ancestors (concept_ancestor also
# lists every concept as its own ancestor), which inflates visit_count.
inpatient_concepts_sql = ', '.join(map(str, INPATIENT_CONCEPTS))
visit_concepts_query = f"""
WITH visit_labels AS (
    SELECT
        ca.descendant_concept_id AS visit_concept_id,
        CASE
            WHEN COUNTIF(ca.ancestor_concept_id IN ({inpatient_concepts_sql})) > 0 THEN '9201 (IP)'
            WHEN COUNTIF(ca.ancestor_concept_id = {OUTPATIENT_CONCEPT}) > 0 THEN '9202 (OP)'
            WHEN COUNTIF(ca.ancestor_concept_id = {ER_CONCEPT}) > 0 THEN '9203 (ER)'
        END AS visit_ancestor
    FROM `{dataset}.concept_ancestor` ca
    WHERE ca.ancestor_concept_id IN ({inpatient_concepts_sql}, {OUTPATIENT_CONCEPT}, {ER_CONCEPT})
    GROUP BY ca.descendant_concept_id
)
SELECT
    v.visit_concept_id,
    c.concept_name AS visit_type,
    COUNT(DISTINCT v.person_id) AS person_count,
    COUNT(*) AS visit_count,
    COALESCE(vl.visit_ancestor, 'Other') AS visit_ancestor
FROM `{dataset}.visit_occurrence` v
JOIN `{dataset}.concept` c
    ON v.visit_concept_id = c.concept_id
LEFT JOIN visit_labels vl
    ON v.visit_concept_id = vl.visit_concept_id
GROUP BY v.visit_concept_id, c.concept_name, vl.visit_ancestor
ORDER BY visit_count DESC
"""

visit_descendents_df = polars_gbq(visit_concepts_query)

In [31]:
visit_descendents_df.head(5)  # most frequent visit concepts and their IP/OP/ER root

visit_concept_id,visit_type,person_count,visit_count,visit_ancestor
i64,str,i64,i64,str
9202,"""Outpatient Visit""",514097,38997910,"""9202 (OP)"""
581477,"""Office Visit""",136227,7550613,"""9202 (OP)"""
32036,"""Laboratory Visit""",115638,3900001,"""Other"""
5083,"""Telehealth""",73101,2430598,"""Other"""
9201,"""Inpatient Visit""",135018,2090555,"""9201 (IP)"""


In [None]:
# Step 1: Load visit_occurrence rows for every participant in `all_modalities`.
# Instead of one massive SQL query, the person_ids are split into batches (each batch is
# inlined into an IN (...) list) and the batches are queried in parallel.

print("Step 1: Loading visit_occurrence records for relevant participants...")
print(f"  - Total participants to query: {all_modalities['person_id'].n_unique():,}")

# Batch the person_ids to stay well under BigQuery's maximum query length (~1MB of SQL).
# Each person_id is ~8 chars, so 50K IDs = ~450KB per query (safe margin)
batch_size = 50000
person_ids = all_modalities['person_id'].to_list()
batches = [person_ids[i:i + batch_size] for i in range(0, len(person_ids), batch_size)]
if not batches:
    raise ValueError("all_modalities contains no person_ids; nothing to load from visit_occurrence")

total_batches = len(batches)
print(f"  - Created {total_batches} batches of ~{batch_size:,} IDs each")


def fetch_batch(batch_idx, batch_ids):
    """Fetch visit_occurrence rows for one batch of person_ids.

    Only visits with both dates present and visit_start_date <= visit_end_date are
    returned; visit_occurrence_id is cast to STRING in SQL.

    Returns:
        (batch_idx, polars DataFrame of visits, number of person_ids in the batch)
    """
    person_ids_str = ','.join(map(str, batch_ids))

    batch_query = f"""
    SELECT
        person_id,
        CAST(visit_occurrence_id AS STRING) as visit_occurrence_id,
        visit_concept_id,
        visit_start_date,
        visit_end_date
    FROM `{dataset}`.visit_occurrence
    WHERE person_id IN ({person_ids_str})
      AND visit_start_date IS NOT NULL
      AND visit_end_date IS NOT NULL
      AND visit_start_date <= visit_end_date
    """

    batch_visits = polars_gbq(batch_query)
    return batch_idx, batch_visits, len(batch_ids)


# At most 16 concurrent BigQuery jobs; os.cpu_count() may return None, so fall back to 1.
max_workers = min(16, os.cpu_count() or 1)
print(f"  - Running {total_batches} batches in parallel with {max_workers} workers...")

all_visits_batches = [None] * total_batches  # indexed by batch to keep the original order

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    future_to_batch = {
        executor.submit(fetch_batch, i, batch): i
        for i, batch in enumerate(batches)
    }

    # Collect results as they complete (any query error is re-raised by future.result()).
    completed = 0
    for future in as_completed(future_to_batch):
        batch_idx, batch_visits, batch_size_actual = future.result()
        all_visits_batches[batch_idx] = batch_visits
        completed += 1
        print(f"  - Completed {completed}/{total_batches}: Batch {batch_idx+1} - {len(batch_visits):,} visits for {batch_size_actual:,} participants")

all_visits = pl.concat(all_visits_batches)
# Report participants that actually have >=1 qualifying visit (can be fewer than queried).
print(f"\n✓ Total visits loaded: {len(all_visits):,} for {all_visits['person_id'].n_unique():,} participants")
print_resource_usage()

In [None]:
all_visits.head(5)  # first rows of the loaded visit_occurrence extract

In [None]:
# Step 2: Fetch the concept_ancestor rows placing visit concepts under the IP / OP / ER roots,
# then build one Python set of descendant concept IDs per root.
print("Step 2: Getting visit type hierarchies...")

inpatient_concepts_str = ', '.join(str(cid) for cid in INPATIENT_CONCEPTS)

visit_ancestor_query = f"""
SELECT DISTINCT
    ca.descendant_concept_id,
    ca.ancestor_concept_id
FROM `{dataset}`.concept_ancestor ca
WHERE ca.ancestor_concept_id IN ({inpatient_concepts_str}, {OUTPATIENT_CONCEPT}, {ER_CONCEPT})
"""

visit_ancestors = polars_gbq(visit_ancestor_query)
print(f"Loaded {len(visit_ancestors):,} concept relationships")


def _descendants_of(ancestor_filter):
    """Set of descendant_concept_id values whose ancestor row matches `ancestor_filter`."""
    matching = visit_ancestors.filter(ancestor_filter)
    return set(matching['descendant_concept_id'].to_list())


ancestor_id = pl.col('ancestor_concept_id')
inpatient_descendants = _descendants_of(ancestor_id.is_in(INPATIENT_CONCEPTS))
outpatient_descendants = _descendants_of(ancestor_id == OUTPATIENT_CONCEPT)
er_descendants = _descendants_of(ancestor_id == ER_CONCEPT)

for label, ids in (
    ("Inpatient", inpatient_descendants),
    ("Outpatient", outpatient_descendants),
    ("ER", er_descendants),
):
    print(f"  - {label} descendants: {len(ids)}")

In [None]:
# Step 3: Keep only the visits that can seed a macrovisit ("qualifying microvisits"):
#   - any inpatient-descendant visit,
#   - ER-descendant visits lasting >= 1 day,
#   - OP-descendant visits lasting exactly 1 day.
print("Step 3: Filtering to qualifying microvisits...")

# Length of stay in whole days (0 for a same-day visit).
all_visits = all_visits.with_columns(
    (pl.col('visit_end_date') - pl.col('visit_start_date')).dt.total_days().alias('los_days')
)

visit_cid = pl.col('visit_concept_id')
los = pl.col('los_days')
is_inpatient = visit_cid.is_in(list(inpatient_descendants))
is_er_ge_1_day = visit_cid.is_in(list(er_descendants)) & (los >= 1)
is_op_1_day = visit_cid.is_in(list(outpatient_descendants)) & (los == 1)

inpatient_microvisits = all_visits.filter(is_inpatient | is_er_ge_1_day | is_op_1_day)

n_micro = len(inpatient_microvisits)
n_micro_persons = inpatient_microvisits['person_id'].n_unique()
print(f"Qualifying microvisits: {n_micro:,} ({n_micro/len(all_visits):.2%})")
print(f"  - Unique persons: {n_micro_persons:,} ({n_micro_persons/all_visits['person_id'].n_unique():.2%})")
print_resource_usage()

In [None]:
# Step 4: N3C Merging Intervals Algorithm (SQL CTEs adapted in Polars)
#
# Qualifying microvisits of one person are merged into a single macrovisit whenever a visit
# starts on or before the latest end date seen so far for that person (overlapping or
# same-day-touching intervals, compared at day resolution). Each merged group becomes one
# macrovisit spanning [earliest start, latest end].

print("Step 4: Running N3C merging intervals algorithm...")

# Step 4a: Add max_end_date and rank
# Rows are sorted by person and start date first, so within each person the cumulative max of
# visit_end_date is the latest end date among this visit and all earlier-sorted ones.
# rank_value uses polars' default rank method ("average", float on ties) and is not used below.
merging_a = inpatient_microvisits.sort(['person_id', 'visit_start_date']).with_columns([
    pl.col('visit_end_date').cum_max().over('person_id').alias('max_end_date'),
    pl.col('visit_start_date').rank().over('person_id').alias('rank_value')
])

print(f"\n  - Added max_end_date and rank:")
display(merging_a.head())

# Step 4b: Identify gaps (new macrovisit starts)
# gap = 0 when this visit starts on/before the running max end of the person's PREVIOUS rows
# (shift(1) within person), i.e. it continues the current macrovisit; otherwise gap = 1.
# The first row of each person compares against null, which falls through to otherwise(1).
merging_b = merging_a.with_columns([
    pl.when(
        pl.col('visit_start_date') <= pl.col('max_end_date').shift(1).over('person_id')
    ).then(0).otherwise(1).alias('gap')
])

print(f"\n  - Identified gaps between visits:")
display(merging_b.head())

# Step 4c: Create group numbers (cumulative sum of gaps)
# Running count of gaps per person; consecutive rows sharing a number form one macrovisit.
# Numbering starts at 1 because every person's first row has gap = 1.
merging_c = merging_b.with_columns([
    pl.col('gap').cum_sum().over('person_id').alias('group_number')
])

print(f"\n  - Assigned group numbers:")
display(merging_c.head())

# Step 4d: Create macrovisit_id and date ranges
# macrovisit_id = "<person_id>_<group_number>_<digit>"; person_id + group_number alone already
# make it unique. NOTE(review): the trailing digit comes from Expr.hash(), which polars documents
# as stable only within a single polars version, so saved IDs may differ after an upgrade.
macrovisits = merging_c.group_by(['person_id', 'group_number']).agg([
    pl.col('visit_start_date').min().alias('macrovisit_start_date'),
    pl.col('visit_end_date').max().alias('macrovisit_end_date')
]).with_columns([
    # Create macrovisit_id matching N3C format
    (
        pl.col('person_id').cast(pl.Utf8) + '_' +
        pl.col('group_number').cast(pl.Utf8) + '_' +
        (pl.col('macrovisit_start_date').hash() % 10).abs().cast(pl.Utf8)
    ).alias('macrovisit_id')
])

print(f"\n  - Created {len(macrovisits):,} macrovisits from {len(inpatient_microvisits):,} microvisits ({len(macrovisits)/len(inpatient_microvisits):.2%})")
print(f"  - Unique persons with macrovisits: {macrovisits['person_id'].n_unique():,}\n")
# NOTE(review): spot check on a hard-coded participant; this displays individual-level rows,
# so clear the output before sharing the notebook.
display(macrovisits.filter(pl.col('person_id')==1000039).head())

print_resource_usage()

In [None]:
# Step 5: Map all visits to macrovisits
print("Step 5: Mapping all visits to macrovisits...")

# A visit belongs to a macrovisit when its visit_start_date falls inside
# [macrovisit_start_date, macrovisit_end_date]. A person's macrovisits never overlap (Step 4
# merges overlapping intervals), so the only candidate is the latest macrovisit that starts on
# or before the visit. A backward as-of join finds that candidate per person without building a
# visits x macrovisits cross product; the candidate is then discarded if it already ended.
# Every visit is kept exactly once, and visits outside all macrovisits (including visits of
# people who do have macrovisits at other times) get nulls in the macrovisit columns.
# Both sides are sorted globally on the as-of key, which also leaves them sorted within each
# person_id group as join_asof requires. The key is copied to a helper column so the original
# macrovisit_start_date survives the join.
macrovisit_cols = ['group_number', 'macrovisit_start_date', 'macrovisit_end_date', 'macrovisit_id']

macrovisit_candidates = (
    all_visits
    .sort('visit_start_date')
    .join_asof(
        macrovisits
        .with_columns(pl.col('macrovisit_start_date').alias('_mv_start_key'))
        .sort('_mv_start_key'),
        left_on='visit_start_date',
        right_on='_mv_start_key',
        by='person_id',
        strategy='backward',
    )
)

inside_window = pl.col('visit_start_date') <= pl.col('macrovisit_end_date')
microvisits_to_macrovisits = macrovisit_candidates.select(
    *all_visits.columns,
    *[pl.when(inside_window).then(pl.col(c)).alias(c) for c in macrovisit_cols],
)

mapped_visits = microvisits_to_macrovisits.filter(pl.col('macrovisit_id').is_not_null())
print(f"Mapped {len(microvisits_to_macrovisits):,} visits")
print(f"  - Visits with macrovisit_id: {mapped_visits.height:,} ({mapped_visits.height/len(microvisits_to_macrovisits):.2%})")
display(mapped_visits.head())

print_resource_usage()

In [None]:
# Step 6: A macrovisit is kept only if at least one of its visits is an inpatient or ER visit.
print("Step 6: Filtering to macrovisits with IP or ER anchor...")

ip_er_concepts = list(inpatient_descendants) + list(er_descendants)
has_macrovisit = pl.col('macrovisit_id').is_not_null()

ip_er_macrovisits = (
    microvisits_to_macrovisits
    .filter(has_macrovisit & pl.col('visit_concept_id').is_in(ip_er_concepts))
    .select('macrovisit_id')
    .unique()
)
print(f"Macrovisits with IP/ER anchor: {len(ip_er_macrovisits):,}")

# Retain visits in anchored macrovisits, plus visits that are not in any macrovisit.
microvisits_to_macrovisits = microvisits_to_macrovisits.filter(
    ~has_macrovisit
    | pl.col('macrovisit_id').is_in(ip_er_macrovisits['macrovisit_id'].implode())
)

mapped = microvisits_to_macrovisits.filter(has_macrovisit)
print(f"\nFinal result:")
print(f"  - Total visits (macrovisits): {len(microvisits_to_macrovisits):,}")
print(f"  - Visits with macrovisit_id: {mapped.height:,}")
print(f"  - Unique macrovisits: {microvisits_to_macrovisits['macrovisit_id'].n_unique():,}")
print(f"  - Unique persons with macrovisits: {mapped['person_id'].n_unique():,}")

print_resource_usage()

## Step 4: Extract Macrovisits Summary Table

In [None]:
# One row per anchored macrovisit, with the number of microvisits mapped into it.
macrovisit_keys = ['person_id', 'macrovisit_id', 'macrovisit_start_date', 'macrovisit_end_date']
ip_er_macrovisits_df = (
    microvisits_to_macrovisits
    .filter(pl.col('macrovisit_id').is_not_null())
    .group_by(macrovisit_keys)
    .agg(pl.len().alias('n_microvisits'))
)

print(f"Macrovisits summary:")
print(f"  - Total macrovisits: {len(ip_er_macrovisits_df):,}")
print(f"  - Median microvisits per macrovisit: {ip_er_macrovisits_df['n_microvisits'].median():.0f}")

display(ip_er_macrovisits_df.sort(['person_id', 'macrovisit_start_date']).head())

In [None]:
# Persist the anchored-macrovisit table to the workspace bucket as tab-separated text.
macrovisits_tsv_path = f'{bucket}/cohorts/ip_er_macrovisits.tsv'
ip_er_macrovisits_df.write_csv(macrovisits_tsv_path, separator='\t')

## Step 5: Validation Checks

In [None]:
print("Running validation checks...\n")

# Check 1: macrovisits of the same person must not overlap in time.
# Self-join on person_id and keep distinct pairs whose [start, end] intervals intersect.
macrovisit_pairs = ip_er_macrovisits_df.join(ip_er_macrovisits_df, on='person_id', suffix='_other')
intervals_intersect = (
    (pl.col('macrovisit_start_date') <= pl.col('macrovisit_end_date_other'))
    & (pl.col('macrovisit_start_date_other') <= pl.col('macrovisit_end_date'))
)
overlapping = macrovisit_pairs.filter(
    (pl.col('macrovisit_id') != pl.col('macrovisit_id_other')) & intervals_intersect
)

print(f"Check 1 - Overlapping macrovisits: {len(overlapping)} (expected: 0)")
if overlapping.height:
    print("WARNING: Found overlapping macrovisits!")
    print(overlapping.head())

In [None]:
# Check 2: each microvisit (visit_occurrence_id) maps to at most one macrovisit.
mapped_rows = microvisits_to_macrovisits.filter(pl.col('macrovisit_id').is_not_null())
macrovisits_per_visit = mapped_rows.group_by('visit_occurrence_id').agg(
    pl.col('macrovisit_id').n_unique().alias('n_macrovisits')
)
multi_mapped = macrovisits_per_visit.filter(pl.col('n_macrovisits') > 1)

print(f"Check 2 - Microvisits mapped to >1 macrovisit: {len(multi_mapped)} (expected: 0)")
if multi_mapped.height:
    print("WARNING: Found microvisits mapped to multiple macrovisits!")
    print(multi_mapped.head())

In [None]:
# Check 3: every mapped visit must start inside its macrovisit's [start, end] window.
starts_too_early = pl.col('visit_start_date') < pl.col('macrovisit_start_date')
starts_too_late = pl.col('visit_start_date') > pl.col('macrovisit_end_date')
out_of_bounds = microvisits_to_macrovisits.filter(
    pl.col('macrovisit_id').is_not_null() & (starts_too_early | starts_too_late)
)

print(f"Check 3 - Microvisits outside macrovisit boundaries: {len(out_of_bounds)} (expected: 0)")
if out_of_bounds.height:
    print("WARNING: Found microvisits outside macrovisit date boundaries!")
    print(out_of_bounds.head())

In [None]:
# Check 4: every retained macrovisit contains at least one IP or ER visit.
# Membership uses the full descendant concept sets, not only the root concept IDs.
ip_er_descendants_combined = inpatient_descendants | er_descendants

anchor_counts = (
    microvisits_to_macrovisits
    .filter(pl.col('macrovisit_id').is_not_null())
    .group_by('macrovisit_id')
    .agg(pl.col('visit_concept_id').is_in(list(ip_er_descendants_combined)).sum().alias('n_ip_er'))
)
missing_anchor = anchor_counts.filter(pl.col('n_ip_er') == 0)

print(f"Check 4 - Macrovisits without IP/ER anchor: {len(missing_anchor)} (expected: 0)")
if missing_anchor.height:
    print("ERROR: This should not happen - filtering logic may be broken!")
    print(missing_anchor.head())

## Step 6: Generate Validation Visualizations

In [None]:
# Length of stay per macrovisit, in whole days (end date minus start date).
macrovisits_with_los = ip_er_macrovisits_df.with_columns(
    (pl.col('macrovisit_end_date') - pl.col('macrovisit_start_date')).dt.total_days().alias('los_days')
)

# Number of macrovisits at each LOS value.
los_summary = (
    macrovisits_with_los
    .group_by('los_days')
    .agg(pl.len().alias('n'))
    .sort('los_days')
)

# Plot LOS up to 100 days; bars with fewer than 20 macrovisits are omitted
# (AoU small-count suppression).
los_plot_data = los_summary.filter(pl.col('los_days') <= 100)
plottable = los_plot_data.filter(pl.col('n') >= 20)

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(plottable['los_days'], plottable['n'])
ax.set(
    yscale='log',
    xlabel='Length of Stay (days)',
    ylabel='Number of Macrovisits',
    title='Macrovisit Length of Stay Distribution (counts <20 suppressed)',
)
plt.tight_layout()
plt.show()

los_days = macrovisits_with_los['los_days']
print(f"Total macrovisits: {len(ip_er_macrovisits_df):,}")
print(f"Median LOS: {los_days.median():.1f} days")
print(f"Mean LOS: {los_days.mean():.1f} days")
print()

In [None]:
# Distribution of microvisits per macrovisit, with AoU small-count (<20) suppression.
microvisit_counts = ip_er_macrovisits_df.select(['macrovisit_id', 'n_microvisits'])

count_dist = (
    microvisit_counts
    .group_by('n_microvisits')
    .agg(pl.len().alias('n_macrovisits'))
    .sort('n_microvisits')
)
plottable = count_dist.filter(pl.col('n_macrovisits') >= 20)

fig, ax = plt.subplots(figsize=(10, 6))
if plottable.is_empty():
    ax.text(0.5, 0.5, 'All counts <20 (suppressed per AoU policy)',
            ha='center', va='center', transform=ax.transAxes)
    ax.set_title('Microvisits per Macrovisit Distribution')
else:
    ax.bar(plottable['n_microvisits'], plottable['n_macrovisits'])
    ax.set(
        yscale='log',
        xlabel='Number of Microvisits per Macrovisit',
        ylabel='Number of Macrovisits',
        title='Microvisits per Macrovisit Distribution (counts <20 suppressed)',
    )
plt.tight_layout()
plt.show()

n_micro_col = microvisit_counts['n_microvisits']
print(f"Median microvisits per macrovisit: {n_micro_col.median()}")
print(f"Max microvisits in a single macrovisit: {n_micro_col.max()}")

## Step 7: Extract and Unify Temperature Measurements

In [None]:
# Settings shared by the omop_unifier Explorer / Mapper / Unifier calls below.
cohort = "allofus"
version = dataset.split('.')[1]                   # CDR release: second dot-separated part of dataset
file_type = "parquet"                             # csv also possible
file_path = f"{bucket}/data/{cohort}/{version}"   # passed as save_dir to the Unifier loop

In [None]:
# Explorer over the measurement variables already annotated for this cohort / CDR version.
explorer = Explorer(variable_type="measurements", cohort=cohort, version=version)

explorer.cohorts()  # display the available cohorts

metadata = explorer.variables()  # annotated measurement variables
print(f"Loaded {len(metadata)} annotated measurement variables")

In [None]:
# LOINC temperature concepts (excluding environmental temperature) that actually occur in the
# measurement table, ranked by number of measurements (top 50).
# Table references are backtick-quoted like every other query in this notebook.
temp_concept_query = f"""
WITH filtered_concepts AS (
    SELECT DISTINCT concept_id, concept_code, concept_name
    FROM `{dataset}`.concept c
    WHERE vocabulary_id = 'LOINC'
        AND LOWER(concept_name) LIKE '%temperature%'
        AND LOWER(concept_name) NOT LIKE '%environmental%'
)
SELECT fc.concept_name, fc.concept_id, COUNT(m.measurement_id) AS count
FROM `{dataset}`.measurement m
JOIN filtered_concepts fc ON fc.concept_id = m.measurement_concept_id
GROUP BY fc.concept_name, fc.concept_id
ORDER BY count DESC
LIMIT 50
"""

polars_gbq(temp_concept_query)

In [None]:
measurement_cid = 3020891                    # OMOP measurement_concept_id (Body temperature)

# Unit distribution for one measurement concept: number of rows per unit.
# standard_unit is NULL when unit_concept_id is NULL or has no match in concept.
# Table references are backtick-quoted like every other query in this notebook.
test_concept_query = f"""
SELECT
    c.concept_name AS lab_concept_name,
    c.concept_id AS lab_concept_id,
    c.concept_code AS lab_concept_code,
    c.vocabulary_id AS lab_vocab,
    c1.concept_name AS standard_unit,
    COUNT(measurement_id) AS count
FROM `{dataset}`.measurement m
JOIN `{dataset}`.concept c ON m.measurement_concept_id = c.concept_id
LEFT JOIN `{dataset}`.concept c1 ON m.unit_concept_id = c1.concept_id
WHERE c.concept_id = {measurement_cid}
GROUP BY lab_concept_name, lab_concept_id, lab_concept_code, lab_vocab, standard_unit
ORDER BY count DESC
LIMIT 10
"""

polars_gbq(test_concept_query)

In [None]:
# Temperature concepts to map: name -> [measurement_concept_id, standard unit name].
# Excluded after inspection:
#   "Skin temperature --in microenvironment" [21490591, "no value"]  -- all values missing
#   "Blood temperature" [21490586, None]                             -- integer values only
temperature_concepts = {
    "Body temperature": [3020891, "degree Celsius"],
    "Oral temperature": [3006322, "degree Fahrenheit"],
    "Body temperature|Temperature|Moment in time|Without specimen": [40654162, None],
    "Body temperature - Temporal artery": [46235152, "degree Fahrenheit"],
    "Tympanic membrane temperature": [3025163, "degree Fahrenheit"],
    "Axillary temperature": [3025085, "degree Fahrenheit"],
    "Body temperature - Core": [3025926, "degree Fahrenheit"],
    "Temperature of Skin": [3039856, "degree Fahrenheit"],
    "Rectal temperature": [3022060, "degree Fahrenheit"],
    "Esophageal temperature": [21490588, "degree Fahrenheit"],
}

In [None]:
# Initialize Mapper for temperature measurements
# (standard unit chosen from the unit distribution inspected above)
# ============================= CYCLE STARTING HERE
# Change the [0] index to step through the entries of temperature_concepts.
measurement_name, (measurement_cid, standard_unit) = list(temperature_concepts.items())[0]

# Plausible-value bounds handed to the Mapper.
# NOTE(review): 30 / 45 read like degrees Celsius; for concepts whose standard unit in
# temperature_concepts is "degree Fahrenheit" these bounds likely need changing — confirm.
min_value, max_value = 30, 45
group = "vitals"   # metadata group; the Step 8 loop unifies concepts whose group is "vitals"

units = Mapper(
    measurement_cid=measurement_cid,
    measurement_name=measurement_name,
    standard_unit=standard_unit,
    min_value=min_value,
    max_value=max_value,
    group=group,
    cohort=cohort,
    version=version,
)

In [None]:
# Show what the Mapper recorded for this concept, followed by current resource usage.
header = "Mapper initialized with the following metadata:"
print(header)
print(units.metadata)
print_resource_usage()

In [None]:
# Start the interactive unit-mapping workflow for this concept,
# then show the histogram of its predominant unit.
units.init_map()
units.predominant_unit_histogram()

In [None]:
current_unit = units.previous_unit()  # step back one unit (counterpart of .next_unit() below)

In [None]:
# Advance to the next unit (.previous_unit() goes back) and show its value histogram.
current_unit = units.next_unit()
units.unit_histogram()


In [None]:
# Record the mapping for current_unit:
#   conversion_factor -> factor that converts this unit to the standard unit
#   multimodal        -> 1 if the value distribution is multimodal, else 0
conversion_factor, multimodal = 1, 0

units.unit_update(current_unit, conversion_factor=conversion_factor, multimodal=multimodal)

In [None]:
# Record current_unit with a conversion factor of 0.
# NOTE(review): 0 presumably marks this unit as unusable / excluded — confirm in omop_unifier.
conversion_factor, multimodal = 0, 0

units.unit_update(current_unit, conversion_factor=conversion_factor, multimodal=multimodal)

In [None]:
units.save()  # save the unit mappings configured in this cycle
# ============================= CYCLE ENDING HERE

## Step 8: Clean Temperature Data

In [None]:
explorer = Explorer(variable_type="measurements", cohort=cohort, version=version)  # re-created after saving mappings (presumably to pick them up — confirm)

In [None]:
metadata = explorer.variables()  # annotated measurement variables, including the "group" column used below

In [None]:
# Unify every annotated measurement concept in the "vitals" group (save_dir=file_path).
# NOTE: the loop variable rebinds the notebook-level `measurement_cid`; after the loop it
# holds the LAST vitals concept processed.
vitals_cids = metadata.filter(pl.col('group') == "vitals")['measurement_concept_id']
for measurement_cid in vitals_cids:
    measurement = Unifier(
        measurement_cid=measurement_cid,
        save_dir=file_path,
        cohort=cohort,
        version=version,
        file_type=file_type,
        save_annotation=True,
        drop_sites=True,
    )
    measurement.unify()

In [None]:
# Load unified temperature data from Unifier output.
# The Unifier has already converted values to the concept's standard unit
# ("degree Celsius" for Body temperature in temperature_concepts).

# The Unifier loop above leaves `measurement_cid` bound to the LAST "vitals" concept it processed.
# That concept need not be a temperature concept, nor match `measurement_name`. Re-bind both
# explicitly to the first entry of temperature_concepts (the concept the Mapper cell is configured
# with) so this cell always loads a consistent name/ID pair.
# TODO(review): only this one temperature concept is loaded; the other mapped temperature concepts
# (oral, tympanic, axillary, ...) are not included in temp_df_clean.
measurement_name, (measurement_cid, _temp_standard_unit) = list(temperature_concepts.items())[0]

# Create Unifier instance to load the processed data
# NOTE(review): the unify loop passed save_dir/file_type; confirm load_data() reads the same
# location without them.
unifier = Unifier(
    measurement_cid=measurement_cid,
    measurement_name=measurement_name,
    cohort=cohort,
    version=version
)

# Load the unified temperature data
temp_df_unified = unifier.load_data()
if temp_df_unified.is_empty():
    raise ValueError(
        f"No unified data loaded for measurement_cid={measurement_cid} ({measurement_name})"
    )

print(f"Loaded unified temperature data: {len(temp_df_unified):,} measurements")
print(f"  - Unique persons: {temp_df_unified['person_id'].n_unique():,}")

# The Unifier should have already (presumably — confirm against omop_unifier):
# 1. Converted all units to the standard unit (Celsius)
# 2. Applied the min/max physiologic bounds configured on the Mapper (min_value=30, max_value=45)
# 3. Removed multimodal/erroneous measurements

# Check the distribution
temp_values = temp_df_unified['value_as_number']
print("\nTemperature distribution (unified, Celsius):")
print(f"  Min: {temp_values.min():.1f}°C")
print(f"  1st percentile: {temp_values.quantile(0.01):.1f}°C")
print(f"  Median: {temp_values.median():.1f}°C")
print(f"  99th percentile: {temp_values.quantile(0.99):.1f}°C")
print(f"  Max: {temp_values.max():.1f}°C")

# Rename column for clarity
temp_df_clean = temp_df_unified.rename({'value_as_number': 'temp_celsius'})

print(f"\n✓ Temperature data unified and ready for analysis")
print(f"  Final N: {len(temp_df_clean):,} measurements")
print(f"  Unique persons: {temp_df_clean['person_id'].n_unique():,}")

print_resource_usage()

## Step 9: Identify Fever Cohort

In [None]:
# Admissions = IP/ER-anchored macrovisits from the Step 4 summary table (one row per macrovisit).
# (The previous code read `macrovisits_df`, which is not defined anywhere in this notebook.)
inpatient_macrovisits = ip_er_macrovisits_df.select([
    'person_id',
    'macrovisit_id',
    'macrovisit_start_date',
    'macrovisit_end_date'
])

# Convert dates to datetime for comparison.
# Only the admission *date* is available, so the window is [00:00 on macrovisit_start_date, +24h).
# NOTE(review): assumes measurement_datetime is timezone-naive; polars refuses to compare a
# timezone-aware column with these naive datetimes — confirm the Unifier output dtype.
inpatient_macrovisits = inpatient_macrovisits.with_columns([
    pl.col('macrovisit_start_date').cast(pl.Datetime).alias('start_datetime'),
    (pl.col('macrovisit_start_date').cast(pl.Datetime) + pl.duration(hours=24)).alias('end_24hr')
])

# Join temperatures to macrovisits and keep measurements inside each admission's 24h window.
first_24hr_temps = temp_df_clean.join(
    inpatient_macrovisits,
    on='person_id',
    how='inner'
).filter(
    (pl.col('measurement_datetime') >= pl.col('start_datetime')) &
    (pl.col('measurement_datetime') < pl.col('end_24hr'))
)

print(f"Temperature measurements in first 24hr of admissions: {len(first_24hr_temps):,}")

# Calculate fever episodes
# Fever definition: ≥2 measurements >38°C within the window.
# macrovisit_end_date is carried through so later filters can tell which visits occur
# after the admission.
fever_episodes = first_24hr_temps.with_columns([
    (pl.col('temp_celsius') > 38.0).cast(pl.Int32).alias('is_fever')
]).group_by(['person_id', 'macrovisit_id', 'macrovisit_start_date', 'macrovisit_end_date']).agg([
    pl.col('is_fever').sum().alias('fever_measurements'),
    pl.col('temp_celsius').max().alias('peak_temp'),
    pl.col('measurement_datetime').min().alias('first_temp_time'),
    pl.col('measurement_datetime').max().alias('last_temp_time'),
    pl.len().alias('total_measurements')
]).filter(
    pl.col('fever_measurements') >= 2  # At least 2 measurements >38°C
)

print(f"\nFever episodes identified: {len(fever_episodes):,}")
print(f"  - Unique persons with fever: {fever_episodes['person_id'].n_unique():,}")
print(f"\n✓ Fever cohort construction complete!")

In [None]:
# Filter 2: Post-Fever Data or Death Requirement
# Keep fever admissions followed either by at least one later healthcare visit or by a
# recorded death. This avoids keeping patients who simply drop out of the healthcare system.

# --- Input: episodes that pass Filter 1 ------------------------------------------------
# NOTE: the Filter 1 cell (which defines `fever_episodes_filtered`) currently sits BELOW this
# cell, so on Restart & Run All it has not run yet. The same rule is therefore applied here:
# >=45 days of potential follow-up before the study end date 2023-10-01.
# TODO: move the Filter 1 cell above this one and delete this duplicate.
study_end_date = pl.date(2023, 10, 1)
fever_episodes_filtered = fever_episodes.with_columns([
    (study_end_date - pl.col('macrovisit_start_date')).dt.total_days().alias('days_potential_followup')
]).filter(pl.col('days_potential_followup') >= 45)

# macrovisit_end_date defines "after the admission" below. Attach it from the macrovisit
# summary table (unique per macrovisit_id) if the fever episodes do not already carry it.
if 'macrovisit_end_date' not in fever_episodes_filtered.columns:
    fever_episodes_filtered = fever_episodes_filtered.join(
        ip_er_macrovisits_df.select(['macrovisit_id', 'macrovisit_end_date']),
        on='macrovisit_id',
        how='left',
    )

# --- Death records ---------------------------------------------------------------------
# Collapse to one row per person (earliest death_date) so the left join below can never
# duplicate a fever episode, even if the death table holds several rows for one person.
death_query = f"""
SELECT person_id, MIN(death_date) AS death_date
FROM `{dataset}`.death
GROUP BY person_id
"""
death_df = polars_gbq(death_query)

print(f"Loaded death records for {len(death_df):,} persons")

fever_episodes_with_death = fever_episodes_filtered.join(
    death_df,
    on='person_id',
    how='left'
)

# --- Visits after each fever admission -------------------------------------------------
# person_ids are passed as an ARRAY query parameter instead of being inlined into the SQL
# text, so the query stays small regardless of cohort size.
fever_person_ids = fever_episodes_filtered['person_id'].unique().to_list()
post_fever_visits_query = f"""
SELECT DISTINCT
    v.person_id,
    v.visit_start_date
FROM `{dataset}`.visit_occurrence v
WHERE v.person_id IN UNNEST(@person_ids)
"""
post_fever_job_config = bigquery.QueryJobConfig(
    query_parameters=[bigquery.ArrayQueryParameter('person_ids', 'INT64', fever_person_ids)]
)

print("Checking for post-fever visits...")
post_fever_visits_df = pl.from_arrow(
    bigquery.Client()
    .query(post_fever_visits_query, job_config=post_fever_job_config)
    .result()
    .to_arrow()
)

# Admissions with at least one visit that starts after the macrovisit end date.
fever_with_post_visits = fever_episodes_with_death.join(
    post_fever_visits_df,
    on='person_id',
    how='left'
).filter(
    pl.col('visit_start_date') > pl.col('macrovisit_end_date')
).select(['person_id', 'macrovisit_id']).unique()

print(f"Patients with post-fever visits: {fever_with_post_visits['person_id'].n_unique():,}")

# Indicators; .implode() turns the Series into a single list value for is_in.
fever_episodes_final = fever_episodes_with_death.with_columns([
    pl.col('macrovisit_id').is_in(fever_with_post_visits['macrovisit_id'].implode()).alias('has_post_fever_visit'),
    pl.col('death_date').is_not_null().alias('has_death_record')
]).with_columns([
    # Pass filter if either condition is true
    (pl.col('has_post_fever_visit') | pl.col('has_death_record')).alias('has_followup_data')
])

n_before_data_filter = len(fever_episodes_final)
n_with_followup = fever_episodes_final.filter(pl.col('has_followup_data')).height
n_without_followup = n_before_data_filter - n_with_followup
pct_without_followup = 100 * n_without_followup / n_before_data_filter if n_before_data_filter else 0.0

print(f"\nFilter 2 - Post-Fever Data or Death Requirement:")
print(f"  - Episodes before filter: {n_before_data_filter:,}")
print(f"  - Episodes with post-fever data or death: {n_with_followup:,}")
print(f"  - Excluded (no follow-up data): {n_without_followup:,} ({pct_without_followup:.1f}%)")
print(f"  - Reason: No healthcare visits after fever AND no death record")
print()

# Apply the filter
fever_episodes_final = fever_episodes_final.filter(pl.col('has_followup_data'))

print(f"✓ Final fever cohort after bias protection filters: {len(fever_episodes_final):,} episodes")
print(f"  - Unique persons: {fever_episodes_final['person_id'].n_unique():,}")

In [None]:
# Filter 1: Longitudinal Data Requirement
# Drop admissions that start fewer than 45 days before the study end date (2023-10-01), so
# every retained admission has enough potential follow-up time to observe outcomes.
# NOTE: the "Filter 2" cell above consumes `fever_episodes_filtered`; this cell must run first.
study_end_date = pl.date(2023, 10, 1)
min_followup_days = 45

fever_episodes_with_followup = fever_episodes.with_columns(
    (study_end_date - pl.col('macrovisit_start_date')).dt.total_days().alias('days_potential_followup')
)
fever_episodes_filtered = fever_episodes_with_followup.filter(
    pl.col('days_potential_followup') >= min_followup_days
)

n_before_filter = fever_episodes_with_followup.height
n_after_followup_filter = fever_episodes_filtered.height
n_excluded_followup = n_before_filter - n_after_followup_filter

print(f"Filter 1 - Longitudinal Data Requirement:")
print(f"  - Episodes before filter: {n_before_filter:,}")
print(f"  - Episodes after filter (≥{min_followup_days} days potential follow-up): {n_after_followup_filter:,}")
print(f"  - Excluded: {n_excluded_followup:,} ({100*n_excluded_followup/n_before_filter:.1f}%)")
print(f"  - Reason: Admission within {min_followup_days} days of study end date (insufficient follow-up time)")
print()

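## Step 9b: Apply Bias Protection Filters

To protect against biases such as immortal time bias and surveillance bias, two additional filters based on the N3C RECOVER methodology are applied to the fever episodes:

- **Filter 1 – longitudinal data:** keep admissions that start at least 45 days before the study end date (2023-10-01), so every admission has enough potential follow-up time.
- **Filter 2 – post-fever data or death:** keep admissions followed by at least one later healthcare visit or by a recorded death. This excludes participants who simply leave the healthcare system.

Filter 1 must run before Filter 2, which consumes its output (`fever_episodes_filtered`). The two filter cells currently appear above this heading.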

## Step 10: Generate Final Cohort Summary

In [None]:
# Final cohort summary. Counts are displayed through format_count (defined in the
# "Get Total Cohort" cell), which applies AoU <20 small-cell suppression.
n_participants = fever_episodes_final['person_id'].n_unique()
n_admissions = fever_episodes_final.height

banner = "=" * 60
print(banner)
print("FINAL FEVER COHORT SUMMARY")
print(banner)
print(f"N = {format_count(n_participants)} participants with fever in first 24 hours of admission")
print(f"N = {format_count(n_admissions)} inpatient admissions with fever")
print()
print("Inclusion criteria:")
for criterion in (
    "≥2 temperature measurements >38°C in first 24 hours",
    "≥45 days potential follow-up before study end (2023-10-01)",
    "Post-fever healthcare data OR death record",
):
    print(f"  ✓ {criterion}")
print()

# Temperature statistics over admissions (not counts, so no suppression applied)
peak = fever_episodes_final['peak_temp']
print("Temperature Statistics:")
print(f"  - Median peak temperature: {peak.median():.1f}°C")
print(f"  - Mean peak temperature: {peak.mean():.1f}°C")
print(f"  - Max peak temperature: {peak.max():.1f}°C")
print(f"  - Median fever measurements per admission: {fever_episodes_final['fever_measurements'].median():.0f}")
print()

# Number of admissions per count of fever (>38°C) measurements
fever_count_dist = (
    fever_episodes_final
    .group_by('fever_measurements')
    .agg(pl.len().alias('n_admissions'))
    .sort('fever_measurements')
)

print("Distribution of fever measurements (>38°C) per admission:")
for n_fever, n_adm in fever_count_dist.iter_rows():
    print(f"  - {n_fever} fever measurements: {format_count(n_adm)} admissions")

print()

# Follow-up data breakdown
post_visit = pl.col('has_post_fever_visit')
n_with_post_visit = fever_episodes_final.filter(post_visit).height
n_with_death_only = fever_episodes_final.filter(~post_visit & pl.col('has_death_record')).height

print("Follow-up Data Breakdown:")
print(f"  - Admissions with post-fever visits: {format_count(n_with_post_visit)}")
print(f"  - Admissions with death (no post-fever visit): {format_count(n_with_death_only)}")

print()
print(banner)
print(f"Ready for downstream analysis!")
print("=" * 60)