In [1]:
# %% [markdown]
# # Spatio-Temporal Analysis with STARMA
# 
# This notebook preprocesses the data and applies a STARMA model to analyze spatiotemporal patterns in disease outbreaks.

# %%
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import cKDTree
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

# %%
# Load datasets
def load_data():
    train_df = pd.read_csv("../data/raw/Train.csv")
    waste_df = pd.read_csv("../data/raw/waste_management.csv")
    toilets_df = pd.read_csv("../data/raw/toilets.csv")
    water_df = pd.read_csv("../data/raw/water_sources.csv")
    return train_df, waste_df, toilets_df, water_df

train_df, waste_df, toilets_df, water_df = load_data()

# %%
# Inspect datasets
print("Train Dataset:")
print(train_df.head())
print("\nWaste Management Dataset:")
print(waste_df.head())
print("\nToilets Dataset:")
print(toilets_df.head())
print("\nWater Sources Dataset:")
print(water_df.head())

# %%
# Check for missing values
print("Missing Values in Train Dataset:")
print(train_df.isnull().sum())
print("\nMissing Values in Waste Management Dataset:")
print(waste_df.isnull().sum())
print("\nMissing Values in Toilets Dataset:")
print(toilets_df.isnull().sum())
print("\nMissing Values in Water Sources Dataset:")
print(water_df.isnull().sum())

# %%
# Rename columns for clarity
def rename_columns(df, prefix):
    for col in df.columns:
        if col not in ['Month_Year_lat_lon', 'lat_lon']:
            df.rename(columns={col: f"{prefix}_{col}"}, inplace=True)

rename_columns(toilets_df, "toilet")
rename_columns(waste_df, "waste")
rename_columns(water_df, "water")

# %%
# Drop rows with missing latitude and longitude in supplementary datasets
for df, prefix in [(toilets_df, 'toilet'), (waste_df, 'waste'), (water_df, 'water')]:
    df.dropna(subset=[f"{prefix}_Transformed_Latitude", f"{prefix}_Transformed_Longitude"], inplace=True)

# %%
# Function to find nearest locations
def find_nearest(hospital_df, location_df, lat_col, lon_col, id_col):
    """
    Find the nearest location in `location_df` for each hospital in `hospital_df`.
    
    Parameters:
        hospital_df (pd.DataFrame): Hospital data with latitude and longitude.
        location_df (pd.DataFrame): Location data (e.g., waste, toilets, water sources).
        lat_col (str): Latitude column in `location_df`.
        lon_col (str): Longitude column in `location_df`.
        id_col (str): Unique identifier column in `location_df`.
    
    Returns:
        nearest (dict): A dictionary mapping hospital IDs to the nearest location ID.
    """
    # Create a cKDTree for efficient nearest neighbour search
    tree = cKDTree(location_df[[lat_col, lon_col]].values)
    nearest = {}
    # Loop through each hospital and find the nearest site in location_df
    for _, row in hospital_df.iterrows():
        _, idx = tree.query([row['Transformed_Latitude'], row['Transformed_Longitude']])
        nearest[row['ID']] = location_df.iloc[idx][id_col]
    return nearest

# %%
# Ensure unique identifier columns exist in all supplementary datasets
for df, prefix in [(toilets_df, 'toilet'), (waste_df, 'waste'), (water_df, 'water')]:
    df[f"{prefix}_Month_Year_lat_lon"] = (
        df[f"{prefix}_Month_Year"] + '_' +
        df[f"{prefix}_Transformed_Latitude"].astype(str) + '_' +
        df[f"{prefix}_Transformed_Longitude"].astype(str)
    )

# %%
# Merge datasets with nearest locations
merged_data = train_df.copy()
datasets = [
    (toilets_df, 'toilet', 'toilet_Month_Year_lat_lon'),
    (waste_df, 'waste', 'waste_Month_Year_lat_lon'),
    (water_df, 'water', 'water_Month_Year_lat_lon'),
]

for df, prefix, id_col in datasets:
    nearest = find_nearest(merged_data, df, f"{prefix}_Transformed_Latitude", f"{prefix}_Transformed_Longitude", id_col)
    nearest_df = pd.DataFrame(list(nearest.items()), columns=['ID', id_col])
    merged_data = merged_data.merge(nearest_df, on="ID").merge(df, on=id_col)

# %%
# Select relevant columns for modeling
model_data = merged_data[['ID', 'Year', 'Month', 'Total', 'Transformed_Latitude', 'Transformed_Longitude'] + 
              ['waste_2t', 'waste_tp', 'waste_swvl1', 'waste_swvl2', 'waste_swvl3', 'waste_swvl4', 'waste_10u', 'waste_10v']]  # Example climate variables

# %%
# Aggregate data by location and time
agg_data = model_data.groupby(['Transformed_Latitude', 'Transformed_Longitude', 'Year', 'Month']).agg({
    'Total': 'sum',  # Sum of outbreaks
    'waste_2t': 'mean',    # Average temperature
    'waste_tp': 'mean',    # Average precipitation
    'waste_swvl1': 'mean', # Average soil water content
}).reset_index()

# %%
# Create a time series for each location
time_series_data = agg_data.pivot_table(
    index=['Transformed_Latitude', 'Transformed_Longitude'],
    columns=['Year', 'Month'],
    values='Total'
).fillna(0)

# %%
# Prepare data for STARMA model
def prepare_starma_data(time_series_data):
    """
    Prepare data for STARMA modeling.
    
    Parameters:
        time_series_data (pd.DataFrame): Pivoted time series data.
    
    Returns:
        X (np.array): Feature matrix (time series for each location).
        y (np.array): Target variable (outbreaks).
    """
    X = time_series_data.values.T  # Transpose to get time series for each location
    y = time_series_data.sum(axis=1).values  # Total outbreaks per location
    return X, y

X, y = prepare_starma_data(time_series_data)

# %%
# Ensure X and y have the same number of samples
if X.shape[0] != y.shape[0]:
    raise ValueError(f"X and y have inconsistent numbers of samples: {X.shape[0]} vs {y.shape[0]}")

# %%
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# %%
# Fit STARMA model
def fit_starma_model(X_train, y_train, order=(1, 0, 1)):
    """
    Fit a STARMA model to the data.
    
    Parameters:
        X_train (np.array): Training feature matrix.
        y_train (np.array): Training target variable.
        order (tuple): (p, d, q) order of the STARMA model.
    
    Returns:
        model: Fitted STARMA model.
    """
    model = SARIMAX(y_train, exog=X_train, order=order)
    fitted_model = model.fit(disp=False)
    return fitted_model

starma_model = fit_starma_model(X_train, y_train, order=(1, 0, 1))

# %%
# Make predictions
def predict_starma(model, X_test):
    """
    Make predictions using the STARMA model.
    
    Parameters:
        model: Fitted STARMA model.
        X_test (np.array): Test feature matrix.
    
    Returns:
        predictions (np.array): Predicted values.
    """
    predictions = model.forecast(steps=len(X_test), exog=X_test)
    return predictions

predictions = predict_starma(starma_model, X_test)

# %%
# Evaluate model
mae = mean_absolute_error(y_test, predictions)
print(f"Mean Absolute Error (MAE): {mae}")

# %%
# Visualize predictions vs actual values
plt.figure(figsize=(10, 6))
plt.plot(y_test, label='Actual')
plt.plot(predictions, label='Predicted')
plt.title('STARMA Model: Actual vs Predicted Outbreaks')
plt.xlabel('Location Index')
plt.ylabel('Number of Outbreaks')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Train Dataset:
                                                  ID  Total  \
0  ID_3a11929e-3317-476d-99f7-1bd9fb58f018_12_202...    0.0   
1  ID_3a11929e-3317-476d-99f7-1bd9fb58f018_12_202...    0.0   
2  ID_3a11929e-3317-476d-99f7-1bd9fb58f018_12_202...    0.0   
3  ID_3a11929e-3317-476d-99f7-1bd9fb58f018_12_202...    0.0   
4  ID_3a11929e-3317-476d-99f7-1bd9fb58f018_12_202...    0.0   

                                  Location  \
0  ID_3a11929e-3317-476d-99f7-1bd9fb58f018   
1  ID_3a11929e-3317-476d-99f7-1bd9fb58f018   
2  ID_3a11929e-3317-476d-99f7-1bd9fb58f018   
3  ID_3a11929e-3317-476d-99f7-1bd9fb58f018   
4  ID_3a11929e-3317-476d-99f7-1bd9fb58f018   

          Category_Health_Facility_UUID    Disease  Month  Year  \
0  a9280aca-c872-46f5-ada7-4a7cc31cf6ec  Dysentery     12  2022   
1  a9280aca-c872-46f5-ada7-4a7cc31cf6ec    Typhoid     12  2022   
2  a9280aca-c872-46f5-ada7-4a7cc31cf6ec   Diarrhea     12  2022   
3  a9280aca-c872-46f5-ada7-4a7cc31cf6ec   Diarrhea     12  20

ValueError: X and y have inconsistent numbers of samples: 48 vs 57