In [None]:
# %% [markdown]
# # RCSQ 6-Axis Sleep Study: Synthetic Datasets and Analysis
#
# This notebook:
# - Uses a 6-axis Richards–Campbell Sleep Questionnaire (RCSQ)
#   - Depth, Latency, Awakenings, Return, Quality, Noise
# - Generates 4 synthetic datasets (each: 50 control + 50 intervention)
#   1. randomized_data: fully random scores (0–100)
#   2. no_effect_data: realistic sleep scores, no intervention effect
#   3. small_effect_data: realistic sleep scores, small intervention effect
#   4. large_effect_data: realistic sleep scores, large intervention effect
# - Scores RCSQ total as the mean of the 6 axes (0–100)
# - Computes Cronbach's alpha, t-tests, Mann–Whitney, and Cohen's d
# - Displays a Markdown summary table of the results


# %% 
# 1. Imports and configuration

import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Markdown, display

# For reproducibility: one Generator, seeded once here, is shared by the
# data-generation functions below. Re-running a generation cell without
# re-running this cell continues the random stream (different data), so use
# Restart & Run All to reproduce the results.
RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)

# Display options
pd.set_option("display.max_columns", 50)
pd.set_option("display.width", 120)

# Plotting style: reset matplotlib rcParams to defaults, then apply seaborn's
# theme. `sns.set_theme` is seaborn's documented interface (`sns.set` is only
# a backward-compatible alias for it).
plt.style.use("default")
sns.set_theme(context="notebook")


# %% [markdown]
# ## 2. Define RCSQ structure and helper functions
#
# We treat RCSQ as 6 axes:
# - Depth
# - Latency
# - Awakenings
# - Return (return to sleep)
# - Quality
# - Noise (environmental noise)


# %% 
# Item columns of the 6-axis instrument, in the column order used when the
# synthetic DataFrames are built.
RCSQ_COLS = [
    f"rcsq_{axis}"
    for axis in ("depth", "latency", "awakenings", "return", "quality", "noise")
]

# Display label for each item column, e.g. "rcsq_depth" -> "Depth".
RCSQ_LABELS = {col: col[len("rcsq_"):].capitalize() for col in RCSQ_COLS}


def score_rcsq_total(df: pd.DataFrame, item_cols=RCSQ_COLS, out_col="rcsq_total"):
    """
    Write the RCSQ total score into ``df[out_col]`` (in place) and return ``df``.

    The total is the row-wise mean of ``item_cols`` with missing items skipped,
    so a partially answered row is scored on the items it has, and a row with
    no answered items gets NaN.
    """
    item_frame = df[item_cols]
    df[out_col] = item_frame.mean(axis=1, skipna=True)
    return df


def categorize_sleep(score: float, cutoff: float = 50.0) -> "str | float":
    """
    Map an RCSQ total score to a sleep-category label.

    Parameters
    ----------
    score : RCSQ total (0–100 scale); may be missing (NaN/None).
    cutoff : threshold separating poor from good sleep (default 50.0).

    Returns
    -------
    "poor_sleep" if score < cutoff, "good_sleep" if score >= cutoff (the
    cutoff itself counts as good sleep), or NaN -- a float, not a string --
    when the score is missing, so pandas treats it as a missing category.
    """
    if pd.isna(score):
        return np.nan
    return "poor_sleep" if score < cutoff else "good_sleep"


def cronbach_alpha(df_items: pd.DataFrame) -> float:
    """
    Cronbach's alpha for the item columns of ``df_items``.

    Only rows with every item present are used (listwise deletion). Returns
    NaN when there are fewer than two items or when the variance of the
    row sums is zero; otherwise

        alpha = k / (k - 1) * (1 - sum(item variances) / variance(row sums))

    with sample variances (ddof=1).
    """
    complete_rows = df_items.dropna()
    n_items = complete_rows.shape[1]
    if n_items < 2:
        return np.nan

    item_var_sum = complete_rows.var(ddof=1).sum()
    total_var = complete_rows.sum(axis=1).var(ddof=1)
    if total_var == 0:
        return np.nan

    return float((n_items / (n_items - 1)) * (1 - item_var_sum / total_var))


def cohen_d(x: np.ndarray, y: np.ndarray) -> float:
    """
    Cohen's d for two independent samples: (mean(x) - mean(y)) / pooled SD.

    The pooled SD weights each sample's variance (ddof=1) by its degrees of
    freedom. Returns NaN if either sample has fewer than two values or the
    pooled SD is zero. The sign follows the argument order (x minus y).
    """
    a, b = np.asarray(x), np.asarray(y)
    n_a, n_b = len(a), len(b)
    if min(n_a, n_b) < 2:
        return np.nan

    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    pooled_sd = np.sqrt(pooled_var)
    return np.nan if pooled_sd == 0 else float((a.mean() - b.mean()) / pooled_sd)


def clip_0_100(array: np.ndarray) -> np.ndarray:
    """
    Bound every value to the closed VAS range [0, 100].

    Values already inside the range are returned unchanged; values below 0
    become 0 and values above 100 become 100.
    """
    lower_bound, upper_bound = 0, 100
    return np.clip(array, lower_bound, upper_bound)


# %% [markdown]
# ## 3. Synthetic data generation
#
# We create 4 datasets, each with 100 subjects:
# - 50 control
# - 50 intervention
#
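# We vary the *true* mean difference between groups:
# - randomized_data: every axis is drawn from uniform(0, 100) for all subjects; group labels are shuffled at random and are unrelated to the scores
# - no_effect_data: realistic means, identical distribution in both groups
# - small_effect_data: intervention mean slightly higher (small effect size)
# - large_effect_data: intervention mean much higher (large effect size)
#
# Note: the six axes are drawn independently of one another. The SD of the RCSQ total (the mean of the six axes) is therefore roughly the per-axis SD divided by √6 ≈ 2.45. A given per-axis shift thus yields a standardized effect (Cohen's d) on the total about 2.45× larger than on any single axis. For the same reason, Cronbach's alpha is expected to be close to zero in these synthetic datasets.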


# %% 
def generate_randomized_data(n_control=50, n_intervention=50) -> pd.DataFrame:
    """
    Build a 'wild' RCSQ dataset with no designed structure.

    Group labels (n_control "control" + n_intervention "intervention") are
    shuffled with the module-level ``rng``. Every axis score is then drawn from
    uniform(0, 100), independently of the subject's group. Patient IDs run
    1..n. The frame also gets ``rcsq_total`` and ``sleep_category`` columns.
    """
    n_total = n_control + n_intervention

    # Shuffle the labels before drawing scores: this order determines which
    # random numbers each step consumes from the shared generator.
    labels = np.array(["control"] * n_control + ["intervention"] * n_intervention)
    rng.shuffle(labels)

    raw_scores = rng.uniform(0, 100, size=(n_total, len(RCSQ_COLS)))
    # uniform(0, 100) already lies in range; the clip is kept for symmetry
    # with the other generator.
    frame = pd.DataFrame(clip_0_100(raw_scores), columns=RCSQ_COLS)

    frame.insert(0, "patient_id", np.arange(1, n_total + 1))
    frame.insert(1, "group", labels)

    score_rcsq_total(frame)
    frame["sleep_category"] = frame["rcsq_total"].apply(categorize_sleep)
    return frame


def generate_effect_data(
    n_control=50,
    n_intervention=50,
    control_means=None,
    intervention_means=None,
    sds=None,
) -> pd.DataFrame:
    """
    Simulate RCSQ item scores for a control arm and an intervention arm.

    Each arm's scores are drawn with the module-level ``rng`` from a normal
    distribution with per-axis means (``control_means`` /
    ``intervention_means``) and per-axis SDs (``sds``), which are length-6
    vectors matching RCSQ_COLS. The scores are then clipped to [0, 100]. Every
    cell is an independent draw, so the axes share no common factor.

    Defaults: both arms use means [60, 55, 50, 55, 60, 50], and all SDs are 15.
    Patient IDs are 1..n_control for control and continue for intervention.
    The frame also gets ``rcsq_total`` and ``sleep_category`` columns.
    """
    default_means = [60, 55, 50, 55, 60, 50]
    control_means = np.array(default_means, dtype=float) if control_means is None else control_means
    intervention_means = (
        np.array(default_means, dtype=float) if intervention_means is None else intervention_means
    )
    sds = np.array([15] * 6, dtype=float) if sds is None else sds

    n_axes = len(RCSQ_COLS)

    def draw_arm(n_rows, means, first_id, label):
        # One (n_rows x n_axes) normal draw, clipped, labelled and numbered.
        scores = clip_0_100(rng.normal(loc=means, scale=sds, size=(n_rows, n_axes)))
        arm = pd.DataFrame(scores, columns=RCSQ_COLS)
        arm.insert(0, "patient_id", np.arange(first_id, first_id + n_rows))
        arm.insert(1, "group", label)
        return arm

    # Control is drawn before intervention; this fixes the RNG consumption order.
    control_arm = draw_arm(n_control, control_means, 1, "control")
    intervention_arm = draw_arm(n_intervention, intervention_means, n_control + 1, "intervention")

    combined = pd.concat([control_arm, intervention_arm], ignore_index=True)
    score_rcsq_total(combined)
    combined["sleep_category"] = combined["rcsq_total"].apply(categorize_sleep)
    return combined


# Generate each dataset.
#
# generate_effect_data draws the six axes independently, so the SD of
# rcsq_total (the mean of the six axes) is about per_axis_sd / sqrt(6).
# Here that is ≈ 12 / 2.45 ≈ 4.9 points, ignoring the small effect of
# clipping. A per-axis shift of `delta` points therefore gives Cohen's d on
# the total of roughly delta / 4.9, about 2.45x the per-axis d. With 50
# subjects per arm, the realized d will still scatter around the target.

# 1) Randomized data: fully random scores
randomized_data = generate_randomized_data()

# 2) No-effect data: same distribution in both groups (realistic sleep)
base_means = np.array([60, 55, 50, 55, 60, 50], dtype=float)
base_sds = np.array([12, 12, 12, 12, 12, 12], dtype=float)
no_effect_data = generate_effect_data(
    control_means=base_means,
    intervention_means=base_means,
    sds=base_sds,
)

# 3) Small effect: intervention slightly better.
# The per-axis shift is derived from the target d on rcsq_total:
#   shift = d * sd_total = 0.35 * 12 / sqrt(6) ≈ 1.7 points per axis.
# (A +6-point shift per axis would already give d ≈ 1.2 on the total,
# which is a large effect.)
SMALL_EFFECT_TARGET_D = 0.35
small_effect_control_means = np.array([60, 55, 50, 55, 60, 50], dtype=float)
small_effect_sds = np.array([12, 12, 12, 12, 12, 12], dtype=float)
small_effect_shift = SMALL_EFFECT_TARGET_D * small_effect_sds / np.sqrt(len(RCSQ_COLS))
small_effect_intervention_means = small_effect_control_means + small_effect_shift
small_effect_data = generate_effect_data(
    control_means=small_effect_control_means,
    intervention_means=small_effect_intervention_means,
    sds=small_effect_sds,
)

# 4) Large effect: intervention much better. +25 points per axis gives
# d ≈ 25 / 4.9 ≈ 5 on rcsq_total, which is deliberately extreme.
large_effect_control_means = np.array([55, 50, 45, 50, 55, 45], dtype=float)
large_effect_intervention_means = np.array([80, 75, 70, 75, 80, 70], dtype=float)
large_effect_sds = np.array([12, 12, 12, 12, 12, 12], dtype=float)
large_effect_data = generate_effect_data(
    control_means=large_effect_control_means,
    intervention_means=large_effect_intervention_means,
    sds=large_effect_sds,
)

datasets = {
    "randomized_data": randomized_data,
    "no_effect_data": no_effect_data,
    "small_effect_data": small_effect_data,
    "large_effect_data": large_effect_data,
}

for name, d in datasets.items():
    print(f"{name}: shape = {d.shape}, mean RCSQ total = {d['rcsq_total'].mean():.2f}")


# %% [markdown]
# ## 4. Analysis helpers: group comparisons
#
# We'll define a function to:
# - compute Cronbach's alpha for the 6 axes
# - summarize means by group
# - run Welch t-test and Mann–Whitney U test
# - compute Cohen's d


# %% 
def analyze_dataset(df: pd.DataFrame, name: str):
    """
    Analyze a single RCSQ dataset and return a dict of key statistics.

    Computes:
    - Cronbach's alpha over the RCSQ item columns (complete rows only)
    - mean ``rcsq_total`` per group
    - Welch t-test (unequal variances, two-sided)
    - Mann–Whitney U test (two-sided)
    - Cohen's d (pooled SD)

    Group order: labels are sorted alphabetically. With this notebook's labels,
    that gives group_1 = "control" and group_2 = "intervention".

    Sign conventions: ``mean_diff_g2_minus_g1``, ``cohen_d_g2_minus_g1`` and
    ``t_stat`` are all oriented as group_2 − group_1, so positive values mean
    group_2 scored higher. ``u_stat`` is the U statistic that
    ``scipy.stats.mannwhitneyu`` reports for its first sample (group_1).

    Parameters
    ----------
    df : DataFrame with a "group" column, an "rcsq_total" column and the
        RCSQ item columns.
    name : label stored under the "dataset" key of the result.

    Raises
    ------
    ValueError
        If the dataset does not have exactly two groups, or if either group
        has no non-missing ``rcsq_total`` values.
    """
    # Reliability
    alpha = cronbach_alpha(df[RCSQ_COLS])

    # Groups
    groups = df["group"].unique()
    if len(groups) != 2:
        raise ValueError(f"Dataset {name} does not have exactly 2 groups.")

    g1, g2 = sorted(groups)  # alphabetic order: control, intervention
    scores_g1 = df.loc[df["group"] == g1, "rcsq_total"].dropna().to_numpy()
    scores_g2 = df.loc[df["group"] == g2, "rcsq_total"].dropna().to_numpy()
    if scores_g1.size == 0 or scores_g2.size == 0:
        raise ValueError(
            f"Dataset {name}: each group needs at least one non-missing rcsq_total."
        )

    mean_g1 = float(scores_g1.mean())
    mean_g2 = float(scores_g2.mean())
    diff = mean_g2 - mean_g1

    # Welch t-test; ttest_ind(a, b) is signed as mean(a) - mean(b), so pass
    # G2 first to make the sign of t agree with `diff`.
    t_stat, p_t = stats.ttest_ind(scores_g2, scores_g1, equal_var=False)

    # Mann–Whitney U (two-sided p does not depend on argument order)
    u_stat, p_u = stats.mannwhitneyu(scores_g1, scores_g2, alternative="two-sided")

    # cohen_d(x, y) standardizes mean(x) - mean(y): pass G2 first so the value
    # really is the G2 − G1 effect stored under "cohen_d_g2_minus_g1".
    d = cohen_d(scores_g2, scores_g1)

    summary = {
        "dataset": name,
        "group_1": g1,
        "group_2": g2,
        "alpha": alpha,
        "mean_group_1": mean_g1,
        "mean_group_2": mean_g2,
        "mean_diff_g2_minus_g1": diff,
        "t_stat": float(t_stat),
        "p_t": float(p_t),
        "u_stat": float(u_stat),
        "p_u": float(p_u),
        "cohen_d_g2_minus_g1": float(d),
        "n_group_1": len(scores_g1),
        "n_group_2": len(scores_g2),
    }

    return summary


# Run the analysis on every dataset: one summary row per dataset, in the
# insertion order of `datasets`.
summaries = [
    analyze_dataset(dataset_df, dataset_name)
    for dataset_name, dataset_df in datasets.items()
]

summary_df = pd.DataFrame(summaries)
summary_df


# %% [markdown]
# ## 5. Quick visualization examples (optional)
#
# Some quick plots on the large-effect dataset to visually inspect the difference.


# %%
example_df = large_effect_data.copy()

# Figure 1: one filled KDE per group, overlaid on the same axes.
kde_fig, kde_ax = plt.subplots(figsize=(8, 5))
for label, subset in example_df.groupby("group"):
    sns.kdeplot(subset["rcsq_total"], ax=kde_ax, label=label, fill=True, alpha=0.3)
kde_ax.set_title("RCSQ Total Distribution by Group (large_effect_data)")
kde_ax.set_xlabel("RCSQ Total (0–100)")
kde_ax.legend()
kde_fig.tight_layout()

# Figure 2: box plot of rcsq_total for each group.
box_fig, box_ax = plt.subplots(figsize=(6, 5))
sns.boxplot(data=example_df, x="group", y="rcsq_total", ax=box_ax)
box_ax.set_title("RCSQ Total by Group (large_effect_data)")
box_ax.set_xlabel("Group")
box_ax.set_ylabel("RCSQ Total (0–100)")
box_fig.tight_layout()


# %% [markdown]
# ## 6. Save synthetic datasets (optional)
#
# Write each dataset to `<dataset name>.csv` in the current working directory, so the files can be inspected directly or reused in another notebook or program.


# %%
# One CSV per dataset, written relative to the current working directory.
for dataset_name, dataset_df in datasets.items():
    csv_path = f"{dataset_name}.csv"
    dataset_df.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")


# %% [markdown]
# ## 7. Markdown summary of key statistics
#
# This cell builds a markdown table summarizing:
# - Cronbach's alpha
# - group means
# - mean difference
# - Cohen's d
# - t-test p-value
#
# The output is rendered as Markdown in the notebook.


# %%
def make_markdown_summary_table(summary_df: pd.DataFrame) -> str:
    """
    Build a markdown table summarizing key results from summary_df.
    """
    header = (
        "| Dataset | Group 1 | Group 2 | α (Cronbach) | Mean G1 | Mean G2 | Diff (G2−G1) | Cohen d (G2−G1) | p (Welch t) |\n"
        "|---------|---------|---------|--------------|---------|---------|--------------|-----------------|------------|\n"
    )
    rows = []
    for _, row in summary_df.iterrows():
        rows.append(
            f"| {row['dataset']} "
            f"| {row['group_1']} "
            f"| {row['group_2']} "
            f"| {row['alpha']:.3f} "
            f"| {row['mean_group_1']:.2f} "
            f"| {row['mean_group_2']:.2f} "
            f"| {row['mean_diff_g2_minus_g1']:.2f} "
            f"| {row['cohen_d_g2_minus_g1']:.2f} "
            f"| {row['p_t']:.4f} |"
        )
    return header + "\n".join(rows)


# Render the results table as Markdown beneath a section heading.
md_table = make_markdown_summary_table(summary_df)
display(Markdown(f"### Summary of 6-Axis RCSQ Synthetic Datasets\n\n{md_table}"))
