# 05 — Census Forecasting + Staffing Auto-Optimizer

This notebook:
1. Generates historical census patterns (synthetic but realistic)
2. Trains a lightweight ML model to predict next-day census per unit
3. Runs a staffing optimizer to recommend optimal nurse mix
4. Outputs predictions and recommendations to gold tables

**Serverless-compatible**: Uses pandas/sklearn instead of Spark ML.

In [None]:
%pip install scikit-learn==1.4.2 pandas numpy

In [None]:
dbutils.library.restartPython()

In [None]:
# Parameters (set by DAB or widget defaults)
# Each widget is registered with its default; values supplied by a DAB job run take precedence.
WIDGET_DEFAULTS = {
    "catalog": "rtpa_catalog",
    "schema_ref": "credentialing_ref",
    "schema_gold": "credentialing_gold",
    "seed": "42",
}
for widget_name, widget_default in WIDGET_DEFAULTS.items():
    dbutils.widgets.text(widget_name, widget_default)

catalog, schema_ref, schema_gold = (
    dbutils.widgets.get(name) for name in ("catalog", "schema_ref", "schema_gold")
)
seed = int(dbutils.widgets.get("seed"))

print(f"Catalog: {catalog}, Gold Schema: {schema_gold}, Seed: {seed}")

## 1. Generate Historical Census Data

Creates 90 days of realistic census patterns per unit with:
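- Day-of-week seasonality (weekends lower)
- Unit-type patterns (different base occupancy and volatility; ICU more stable, ED more volatile)
- A slight upward trend (about 5% growth across the window)
- Random noise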

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

# Seed NumPy's global RNG so the synthetic history generated below is reproducible
np.random.seed(seed)

# Units joined to their facility; unmapped facilities fall back to the raw facility_id
units_sql = f"""
    SELECT u.unit_id, u.unit_name, u.unit_type, u.facility_id, u.bed_count, u.target_ratio,
           COALESCE(f.facility_name, u.facility_id) as facility_name
    FROM {catalog}.{schema_ref}.unit u
    LEFT JOIN {catalog}.{schema_ref}.facility f ON u.facility_id = f.facility_id
"""
units_df = spark.sql(units_sql).toPandas()
unit_count = len(units_df)
print(f"Loaded {unit_count} units")
display(units_df)

In [None]:
# Generate 90 days of historical census data
HISTORY_DAYS = 90
today = datetime.now().date()
# Oldest first, ending yesterday: today-90, today-89, ..., today-1
dates = [today - timedelta(days=days_back) for days_back in reversed(range(1, HISTORY_DAYS + 1))]

# Unit-type volatility (how much census fluctuates)
VOLATILITY = {
    "ICU": 0.10, "NICU": 0.12, "STEP_DOWN": 0.15,
    "MED_SURG": 0.20, "TELEMETRY": 0.18,
    "ED": 0.30, "OR": 0.25, "PACU": 0.20,
    "L_AND_D": 0.35, "PSYCH": 0.15
}

# Day-of-week multipliers (0=Mon, 6=Sun)
DOW_MULT = {0: 1.0, 1: 1.02, 2: 1.05, 3: 1.03, 4: 0.98, 5: 0.85, 6: 0.80}

# Baseline occupancy by unit type; unlisted types default to 0.75
BASE_OCCUPANCY = {"ICU": 0.85, "NICU": 0.80, "ED": 0.75, "MED_SURG": 0.70}


def _simulate_unit_census(unit_row):
    """Return one synthetic census record per history date for a single unit row.

    Draws exactly one normal noise sample per day from NumPy's global RNG, in date
    order, so the sequence of draws matches a unit-by-unit, day-by-day loop.
    """
    bed_count = unit_row["bed_count"]
    unit_type = unit_row["unit_type"]
    target_ratio = unit_row["target_ratio"]
    base = BASE_OCCUPANCY.get(unit_type, 0.75)
    sigma = VOLATILITY.get(unit_type, 0.20)

    rows = []
    # `dates` are consecutive days, so the enumerate index equals days since dates[0]
    for day_idx, dt in enumerate(dates):
        dow = dt.weekday()
        # Linear growth from 1.00 to just under 1.05 across the window
        trend = 1.0 + (day_idx / HISTORY_DAYS) * 0.05
        noise = np.random.normal(0, sigma)
        occupancy = base * DOW_MULT[dow] * trend * (1 + noise)
        census = int(np.clip(occupancy * bed_count, 1, bed_count))
        rows.append({
            "census_date": dt,
            "unit_id": unit_row["unit_id"],
            "facility_id": unit_row["facility_id"],
            "unit_type": unit_type,
            "bed_count": bed_count,
            "target_ratio": target_ratio,
            "census": census,
            "occupancy_pct": round(census / bed_count * 100, 1),
            "nurses_required": max(1, int(np.ceil(census / target_ratio))),
            "day_of_week": dow,
            "is_weekend": dow >= 5
        })
    return rows


census_records = [
    record
    for _, unit_row in units_df.iterrows()
    for record in _simulate_unit_census(unit_row)
]

census_history = pd.DataFrame(census_records)
print(f"Generated {len(census_history)} census records")
display(census_history.head(20))

In [None]:
# Persist the synthetic history to the gold schema, replacing any previous run's table
census_history_sdf = spark.createDataFrame(census_history)
census_history_sdf.write.mode("overwrite").saveAsTable(f"{catalog}.{schema_gold}.census_history")
print(f"Wrote census_history to {catalog}.{schema_gold}.census_history")

## 2. Train Census Forecasting Model

Uses a simple Random Forest to predict next-day census based on:
- Unit characteristics (type, bed count)
- Recent census trends (lag features)
- Day of week

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.preprocessing import LabelEncoder
import warnings
# NOTE(review): with no category/module filter this silences *every* warning for the rest of
# the Python session (not just this cell) — e.g. library deprecation or feature-name warnings.
# Consider narrowing it with `category=` once the noisy warning is identified.
warnings.filterwarnings('ignore')

# Create lag features (census from previous days)
def create_features(df):
    """Build per-unit lag and rolling-mean features for the census model.

    Every feature for row *t* uses only census values from rows strictly before *t*
    within the same unit, so the model never sees the value it is asked to predict.
    (Previously the rolling means included day *t* itself, leaking the target.)

    Lags are row-based (``shift``), so they equal "k days ago" only when each unit
    has one row per consecutive day.

    Args:
        df: frame with at least ``unit_id``, ``census_date`` and ``census`` columns.
            It is not modified.

    Returns:
        A copy sorted by ``unit_id``/``census_date`` with ``census_lag_{1,2,3,7}``,
        ``census_rolling_3d`` and ``census_rolling_7d`` added. Rows with any missing
        value are dropped, e.g. each unit's first 7 rows, which lack a 7-row lag.
    """
    df = df.sort_values(["unit_id", "census_date"]).copy()
    by_unit = df.groupby("unit_id")["census"]

    # Lag features per unit: yesterday, 2 days ago, 3 days ago, same weekday last week
    for lag in (1, 2, 3, 7):
        df[f"census_lag_{lag}"] = by_unit.shift(lag)

    # Rolling averages over the *preceding* days; shift(1) excludes the current (target) day
    for window, column in ((3, "census_rolling_3d"), (7, "census_rolling_7d")):
        df[column] = by_unit.transform(
            lambda x, w=window: x.shift(1).rolling(w, min_periods=1).mean()
        )

    return df.dropna()

# Prepare features
df_features = create_features(census_history)
print(f"Training samples after feature engineering: {len(df_features)}")

# Encode categorical variables
le_unit_type = LabelEncoder()
le_unit = LabelEncoder()
df_features["unit_type_enc"] = le_unit_type.fit_transform(df_features["unit_type"])
df_features["unit_id_enc"] = le_unit.fit_transform(df_features["unit_id"])

# Feature columns (prediction rows must be built with the same columns in the same order)
feature_cols = [
    "unit_type_enc", "unit_id_enc", "bed_count", "target_ratio",
    "day_of_week", "is_weekend",
    "census_lag_1", "census_lag_2", "census_lag_3", "census_lag_7",
    "census_rolling_3d", "census_rolling_7d"
]

X = df_features[feature_cols]
y = df_features["census"]

# Chronological train/test split: hold out the most recent ~20% of dates.
# A random split would train on days that come *after* test days (with overlapping
# lag windows), which overstates accuracy for a forecasting task.
TEST_FRACTION = 0.2
unique_dates = sorted(df_features["census_date"].unique())
if len(unique_dates) < 2:
    raise ValueError("Need at least two distinct census dates to build a train/test split")
# Clamp so both the train and the test side get at least one date
cutoff_idx = min(len(unique_dates) - 1, max(1, round(len(unique_dates) * (1 - TEST_FRACTION))))
cutoff_date = unique_dates[cutoff_idx]
train_mask = df_features["census_date"] < cutoff_date
X_train, X_test = X.loc[train_mask], X.loc[~train_mask]
y_train, y_test = y.loc[train_mask], y.loc[~train_mask]

# Train model
model = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=seed, n_jobs=-1)
model.fit(X_train, y_train)

# Evaluate on the held-out (most recent) days
y_pred = model.predict(X_test)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Model Performance:")
print(f"  MAE: {mae:.2f} patients")
print(f"  R²: {r2:.3f}")
print(f"  (trained on dates before {cutoff_date}; tested on {len(X_test)} rows from {cutoff_date} onward)")

# Refit on the full history so the forecast also learns from the most recent days
model.fit(X, y)

In [None]:
# Feature importance: one row per model feature, most influential first
importance = (
    pd.DataFrame(
        list(zip(feature_cols, model.feature_importances_)),
        columns=["feature", "importance"],
    )
    .sort_values(by="importance", ascending=False)
)
display(importance)

## 3. Generate 7-Day Census Forecast

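Predicts census for the next 7 days per unit. The forecast is recursive: each day's prediction is fed back in as the most recent lag value when predicting the following day.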

In [None]:
# Generate forecast for next 7 days
FORECAST_DAYS = 7
FALLBACK_OCCUPANCY = 0.7  # census fraction of beds assumed when a unit lacks recent history
forecast_dates = [today + timedelta(days=i) for i in range(1, FORECAST_DAYS + 1)]
forecast_date_set = set(forecast_dates)


def _recent_census_features(recent, fallback):
    """Lag/rolling features for the day after the last entry of ``recent``.

    ``recent`` holds up to 7 prior daily census values, oldest first. Lag k is the
    value k days before the target day. The rolling means cover the preceding 3 and
    7 days. Any lag without enough history uses ``fallback``.
    """
    def lag(k):
        return recent[-k] if len(recent) >= k else fallback

    return {
        "census_lag_1": lag(1),
        "census_lag_2": lag(2),
        "census_lag_3": lag(3),
        "census_lag_7": lag(7),
        "census_rolling_3d": np.mean(recent[-3:]) if recent else fallback,
        "census_rolling_7d": np.mean(recent) if recent else fallback,
    }


run_timestamp = datetime.now()  # one timestamp for every row of this forecast run
forecasts = []
for _, unit in units_df.iterrows():
    unit_id = unit["unit_id"]
    bed_count = unit["bed_count"]
    fallback = bed_count * FALLBACK_OCCUPANCY

    # Loop-invariant per unit: label encodings and forecast confidence
    unit_type_enc = le_unit_type.transform([unit["unit_type"]])[0]
    unit_id_enc = le_unit.transform([unit_id])[0]
    volatility = VOLATILITY.get(unit["unit_type"], 0.20)
    confidence = round(max(0.6, 1 - volatility) * 100)

    # Get recent history for this unit
    unit_history = census_history[census_history["unit_id"] == unit_id].sort_values("census_date")
    recent = unit_history.tail(7)["census"].tolist()

    # The synthetic history ends at today - 1, but the forecast window starts at today + 1.
    # Step the recursive forecast from the day after the last observation so lag_1 is
    # always "yesterday". Days before forecast_dates[0] are predicted only to keep the
    # lags aligned and are not emitted.
    if len(unit_history):
        last_observed = pd.Timestamp(unit_history["census_date"].max()).date()
    else:
        last_observed = today
    day = last_observed + timedelta(days=1)

    while day <= forecast_dates[-1]:
        dow = day.weekday()

        # Build features for prediction (same column order as the training feature_cols)
        features = {
            "unit_type_enc": unit_type_enc,
            "unit_id_enc": unit_id_enc,
            "bed_count": bed_count,
            "target_ratio": unit["target_ratio"],
            "day_of_week": dow,
            "is_weekend": dow >= 5,
            **_recent_census_features(recent, fallback),
        }

        # Predict
        X_pred = pd.DataFrame([features])
        predicted_census = int(np.clip(model.predict(X_pred)[0], 1, bed_count))

        if day in forecast_date_set:
            nurses_required = max(1, int(np.ceil(predicted_census / unit["target_ratio"])))
            forecasts.append({
                "forecast_date": day,
                "unit_id": unit_id,
                "facility_id": unit["facility_id"],
                "facility_name": unit["facility_name"],
                "unit_name": unit["unit_name"],
                "unit_type": unit["unit_type"],
                "bed_count": bed_count,
                "predicted_census": predicted_census,
                "predicted_occupancy_pct": round(predicted_census / bed_count * 100, 1),
                "nurses_required": nurses_required,
                "confidence_pct": confidence,
                "day_of_week": dow,
                "is_weekend": dow >= 5,
                "generated_at": run_timestamp
            })

        # Feed the prediction back in as the newest observation (rolling forecast)
        recent = (recent + [predicted_census])[-7:]
        day += timedelta(days=1)

forecast_df = pd.DataFrame(forecasts)
print(f"Generated {len(forecast_df)} forecasts")
display(forecast_df)

In [None]:
# Persist the 7-day forecast to the gold schema, replacing any previous run's table
forecast_sdf = spark.createDataFrame(forecast_df)
forecast_sdf.write.mode("overwrite").saveAsTable(f"{catalog}.{schema_gold}.census_forecast")
print(f"Wrote forecasts to {catalog}.{schema_gold}.census_forecast")

## 4. Staffing Auto-Optimizer

Generates optimal staffing recommendations that minimize cost while meeting demand.

**Optimization logic:**
- Prioritize internal staff (cheapest)
- Use contract nurses for predictable gaps
- Agency only for urgent/unexpected needs
- Factor in credential requirements *(planned — not yet implemented by the optimizer below)*

In [None]:
# Hourly rates by employment type
HOURLY_RATES = {"INTERNAL": 50, "CONTRACT": 75, "AGENCY": 95}
SHIFT_HOURS = 12

# Current staffing from nurse_staffing_summary (latest summary_date only)
try:
    current_staffing = spark.sql(f"""
        SELECT unit_id, nurses_internal, nurses_contract, nurses_agency, nurses_assigned
        FROM {catalog}.{schema_gold}.nurse_staffing_summary
        WHERE summary_date = (SELECT MAX(summary_date) FROM {catalog}.{schema_gold}.nurse_staffing_summary)
    """).toPandas()
    print(f"Loaded current staffing for {len(current_staffing)} units")
except Exception as exc:
    # Fallback if the table is missing or unreadable: downstream cells then use default mixes.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate,
    # and report why the load failed instead of discarding the error.
    current_staffing = pd.DataFrame()
    print("No current staffing data, using defaults")
    reason = str(exc).strip().splitlines()
    print(f"  reason: {type(exc).__name__}: {reason[0] if reason else ''}")

In [None]:
def _headcount(value):
    """Return ``value`` as a non-negative headcount; None/NaN (missing data) becomes 0."""
    if value is None or pd.isna(value):
        return 0
    return max(0, value)


def optimize_staffing(nurses_required, current_internal=0, current_contract=0, max_agency_pct=0.20,
                      hourly_rates=None, shift_hours=None):
    """
    Optimize staffing mix to minimize cost while meeting demand.

    Strategy (greedy, cheapest first):
    1. Use all available internal staff first
    2. Fill gaps with available contract staff
    3. Use agency for the remainder, up to ``max_agency_pct`` of ``nurses_required``
    4. Any shortfall beyond the agency cap is covered by additional contract staff

    Args:
        nurses_required: total nurses needed.
        current_internal: internal nurses available; None/NaN/negative is treated as 0.
        current_contract: contract nurses available; None/NaN/negative is treated as 0.
        max_agency_pct: cap on agency nurses as a fraction of ``nurses_required``
            (truncated to a whole number of nurses).
        hourly_rates: mapping with "INTERNAL"/"CONTRACT"/"AGENCY" hourly rates;
            defaults to the notebook-level ``HOURLY_RATES``.
        shift_hours: hours per shift; defaults to the notebook-level ``SHIFT_HOURS``.

    Returns:
        dict with ``opt_internal``, ``opt_contract``, ``opt_agency``, ``opt_total``,
        ``opt_daily_cost`` (one shift) and the internal/outsourced percentage split
        (0 when the total is 0).
    """
    rates = HOURLY_RATES if hourly_rates is None else hourly_rates
    hours = SHIFT_HOURS if shift_hours is None else shift_hours
    # Snapshot counts can be missing (NaN/None); without this they poison every output.
    current_internal = _headcount(current_internal)
    current_contract = _headcount(current_contract)

    # Start with internal (cheapest)
    internal = min(current_internal, nurses_required)
    remaining = nurses_required - internal

    # Then contract
    contract = min(current_contract, remaining)
    remaining -= contract

    # Agency for the rest (capped at max %)
    max_agency = int(nurses_required * max_agency_pct)
    agency = min(remaining, max_agency)
    remaining -= agency

    # If still short, add more contract
    if remaining > 0:
        contract += remaining

    total = internal + contract + agency
    cost = (internal * rates["INTERNAL"] +
            contract * rates["CONTRACT"] +
            agency * rates["AGENCY"]) * hours

    return {
        "opt_internal": internal,
        "opt_contract": contract,
        "opt_agency": agency,
        "opt_total": total,
        "opt_daily_cost": cost,
        "internal_pct": round(internal / total * 100, 1) if total > 0 else 0,
        "outsourced_pct": round((contract + agency) / total * 100, 1) if total > 0 else 0
    }

# Generate optimization recommendations
def _count_or_zero(value):
    """Coerce a staffing-snapshot count that may be None/NaN into a non-negative number."""
    return 0 if value is None or pd.isna(value) else max(0, value)


# Index the latest staffing snapshot by unit once (first row per unit wins, as with .iloc[0])
if len(current_staffing) > 0:
    staffing_by_unit = current_staffing.drop_duplicates(subset="unit_id").set_index("unit_id")
else:
    staffing_by_unit = None

run_timestamp = datetime.now()  # one timestamp for every recommendation in this run
recommendations = []
for _, forecast in forecast_df.iterrows():
    unit_id = forecast["unit_id"]
    nurses_required = forecast["nurses_required"]

    # A usable snapshot is a row for this unit with a non-missing assigned count.
    # A count of 0 is real data (a fully unstaffed unit), not "no data".
    curr = None
    if staffing_by_unit is not None and unit_id in staffing_by_unit.index:
        curr = staffing_by_unit.loc[unit_id]
    has_snapshot = curr is not None and not pd.isna(curr.get("nurses_assigned"))

    if has_snapshot:
        current_internal = _count_or_zero(curr.get("nurses_internal"))
        current_contract = _count_or_zero(curr.get("nurses_contract"))
        current_agency = _count_or_zero(curr.get("nurses_agency"))
        current_total = _count_or_zero(curr.get("nurses_assigned"))
    else:
        # Default assumption: 60% internal / 20% contract capacity available
        current_internal = int(nurses_required * 0.6)
        current_contract = int(nurses_required * 0.2)
        current_agency = 0
        current_total = 0

    # Run optimizer
    opt = optimize_staffing(
        nurses_required=nurses_required,
        current_internal=current_internal,
        current_contract=current_contract
    )

    if has_snapshot:
        # Cost of the mix actually on the snapshot (0 when the unit has no nurses assigned)
        current_cost = (
            current_internal * HOURLY_RATES["INTERNAL"] +
            current_contract * HOURLY_RATES["CONTRACT"] +
            current_agency * HOURLY_RATES["AGENCY"]
        ) * SHIFT_HOURS
        delta = nurses_required - current_total
    else:
        # No baseline to compare against: report zero delta and zero savings
        current_cost = opt["opt_daily_cost"]
        delta = 0

    # Determine recommendation action
    if not has_snapshot:
        action = "NO_DATA: No current staffing snapshot (default mix assumed)"
        priority = "LOW"
    elif delta > 0:
        action = f"STAFF_UP: Add {delta} nurse(s)"
        priority = "HIGH" if delta >= 2 else "MEDIUM"
    elif delta < 0:
        action = f"OPTIMIZE: Reduce {-delta} nurse(s)"
        priority = "LOW"
    else:
        action = "OPTIMAL: Staffing matches demand"
        priority = "LOW"

    recommendations.append({
        "forecast_date": forecast["forecast_date"],
        "unit_id": unit_id,
        "facility_id": forecast["facility_id"],
        "facility_name": forecast["facility_name"],
        "unit_name": forecast["unit_name"],
        "unit_type": forecast["unit_type"],
        "predicted_census": forecast["predicted_census"],
        "nurses_required": nurses_required,
        "current_staffed": current_total,
        "staffing_delta": delta,
        **opt,
        "current_daily_cost": current_cost,
        "cost_savings": current_cost - opt["opt_daily_cost"],
        "action": action,
        "priority": priority,
        "confidence_pct": forecast["confidence_pct"],
        "generated_at": run_timestamp
    })

recommendations_df = pd.DataFrame(recommendations)
print(f"Generated {len(recommendations_df)} staffing recommendations")
display(recommendations_df)

In [None]:
# Persist the staffing recommendations to the gold schema, replacing any previous run's table
recommendations_sdf = spark.createDataFrame(recommendations_df)
recommendations_sdf.write.mode("overwrite").saveAsTable(f"{catalog}.{schema_gold}.staffing_optimization")
print(f"Wrote recommendations to {catalog}.{schema_gold}.staffing_optimization")

## 5. Summary: 7-Day Outlook

Aggregate view of predicted demand and optimization opportunities.

In [None]:
# Daily summary: totals across all units per forecast date (named aggregation)
daily_summary = (
    recommendations_df
    .groupby("forecast_date")
    .agg(
        total_nurses_needed=("nurses_required", "sum"),
        optimized_nurses=("opt_total", "sum"),
        optimized_cost=("opt_daily_cost", "sum"),
        potential_savings=("cost_savings", "sum"),
    )
    .reset_index()
    .rename(columns={"forecast_date": "date"})
)

print("7-Day Staffing Outlook:")
display(daily_summary)

total_savings = daily_summary["potential_savings"].sum()
print(f"\nTotal 7-day potential savings: ${total_savings:,.0f}")

In [None]:
# Units needing attention (high priority recommendations) on the first forecast day
next_day = forecast_dates[0]
high_priority = recommendations_df[
    (recommendations_df["priority"] == "HIGH") &
    (recommendations_df["forecast_date"] == next_day)
][["unit_name", "facility_name", "predicted_census", "nurses_required", "current_staffed", "staffing_delta", "action"]]

print(f"\nHigh-priority units for tomorrow ({next_day}):")
if len(high_priority) > 0:
    display(high_priority)
else:
    # Only HIGH priority is filtered here. MEDIUM staff-ups (short by one nurse) or units
    # without a staffing snapshot may still exist, so don't claim every unit is optimal.
    print("No high-priority staffing gaps (see MEDIUM/LOW recommendations for smaller adjustments).")

In [None]:
print("\n" + "="*60)
print("CENSUS FORECASTING + AUTO-OPTIMIZER COMPLETE")
print("="*60)
print(f"\nTables created/updated:")
print(f"  - {catalog}.{schema_gold}.census_history")
print(f"  - {catalog}.{schema_gold}.census_forecast")
print(f"  - {catalog}.{schema_gold}.staffing_optimization")
print(f"\nModel metrics:")
print(f"  - MAE: {mae:.2f} patients")
print(f"  - R²: {r2:.3f}")
print(f"\n7-day outlook:")
print(f"  - Total nurses needed: {daily_summary['total_nurses_needed'].sum()}")
print(f"  - Potential cost savings: ${total_savings:,.0f}")