In [1]:
import pandas as pd
import numpy as np

from stratified_ioi_subgrouping import (
    IOIsplitter,
    calculate_between_person_variability,
    calculate_within_person_variability,
    calculate_ioi,
)

In [2]:
# Raw simulated data: one row per sample. Later cells expect a `subject_id` column,
# the covariates `bmi`, `age`, `sex`, and one spectral column per wavenumber,
# named "0", "1", ..., str(N_WAVENUMBERS - 1).
DATA_PATH = "data/simulated_data.csv"
N_WAVENUMBERS = 519  # number of spectral feature columns expected in the CSV

data = pd.read_csv(DATA_PATH)
wavenumbers = [str(i) for i in range(N_WAVENUMBERS)]

# Fail fast with a readable message instead of a KeyError deep inside the helpers.
missing_features = [col for col in wavenumbers if col not in data.columns]
assert not missing_features, (
    f"{len(missing_features)} wavenumber column(s) missing from {DATA_PATH}, "
    f"e.g. {missing_features[:5]}"
)

In [3]:
# Baseline IOI on the full, unsplit dataset: the ratio of within-person to
# between-person variability, collapsed to a single number with `.mean().mean()`
# (presumably the ratio is a DataFrame, so this is the mean of column means — confirm).
# NOTE(review): the per-group IOIs below come from `calculate_ioi`; confirm it computes
# the same ratio, otherwise the baseline and grouped values are not comparable.
wpv = calculate_within_person_variability(data, wavenumbers)
bpv = calculate_between_person_variability(data, wavenumbers)
ioi_ratio = wpv / bpv
original_ioi = ioi_ratio.mean().mean()

### 🧹 Data Preprocessing for Grouping

To ensure consistent grouping of subjects, all samples from the same subject must be assigned to the same group.  
Therefore, instead of using sample-level covariates (which can vary across visits), we compute the **subject-level mean** of each continuous covariate.  
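These subject-level means (`mean_bmi`, `mean_age`) are then used by the splitter to determine group membership. Categorical covariates such as `sex` are used as-is, on the assumption that they do not change between a subject's visits.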

In [4]:
# Covariates that may vary between a subject's visits; each is replaced by its
# per-subject mean so that all of a subject's samples share one value.
continuous_covariates = [
    "bmi",
    "age",
]

# `transform("mean")` broadcasts each subject's mean back onto every one of its rows,
# so the result has the same index as `data`.
mean_values = (
    data.groupby("subject_id")[continuous_covariates]
    .transform("mean")
    .add_prefix("mean_")
)

# `assign` overwrites existing `mean_*` columns instead of appending duplicates, so
# re-running this cell is idempotent. (The previous `pd.concat(..., axis=1)` added a
# second copy of every `mean_*` column on each re-run, which broke column selection.)
data = data.assign(**mean_values.to_dict(orient="series"))

covariate_combination = [f"mean_{covariate}" for covariate in continuous_covariates] + [
    "sex"
]

# The splitter must never separate a subject's samples, so every grouping covariate
# (including categorical ones used as-is, such as "sex") must be constant per subject.
values_per_subject = data.groupby("subject_id")[covariate_combination].nunique()
inconsistent_covariates = values_per_subject.columns[values_per_subject.gt(1).any()].tolist()
assert not inconsistent_covariates, (
    f"covariates vary within a subject: {inconsistent_covariates}"
)

In [5]:
# Fit the covariate-based splitter on the subject-level covariates.
# `min_subjects_per_leaf=10` presumably bounds the smallest leaf size — see IOIsplitter.
splitter = IOIsplitter(
    data,
    covariates=covariate_combination,
    min_subjects_per_leaf=10,
    features=wavenumbers,
)
result = splitter.fit()
summaries, _ = splitter.summarize_splits(result)
summary_df = pd.DataFrame(summaries)

# Label every sample with the leaf the splitter routes it to (row-wise).
data["group"] = data.apply(lambda row: splitter.assign_to_leaf(summary_df, row), axis=1)

# Grouping is only meaningful if each subject lands in exactly one group.
groups_per_subject = data.groupby("subject_id")["group"].nunique()
assert groups_per_subject.le(1).all(), "some subjects were split across groups"

# Per-group IOI. `include_groups=False` keeps the "group" label out of the frames
# handed to `calculate_ioi`. This silences the pandas >= 2.2 DeprecationWarning that
# this call emitted and matches the future pandas default.
ioi_group = (
    data.groupby("group", observed=False)
    .apply(calculate_ioi, features=wavenumbers, include_groups=False)
    .reset_index(drop=True)
)

# NOTE(review): assuming `calculate_ioi` returns one row per group, this is an
# unweighted mean, so every group counts equally regardless of its size.
average_ioi_within_groups = ioi_group[wavenumbers].mean().mean()
improvement = average_ioi_within_groups - original_ioi

  data.groupby("group", observed=False)


In [6]:
# Baseline vs. grouped IOI. Labels are pre-padded so that the numbers line up in a column.
report_rows = [
    ("📊 Original IOI (no splitting):     ", original_ioi),
    ("🌿 Mean IOI after group splitting:  ", average_ioi_within_groups),
    ("✅ Increase in IOI:                 ", improvement),
]
for label, value in report_rows:
    print(f"{label}{value:.4f}")

📊 Original IOI (no splitting):     0.9003
🌿 Mean IOI after group splitting:  0.9718
✅ Increase in IOI:                 0.0715
