### Imports

In [1]:
import warnings
import pandas as pd
import numpy as np
import ast
import re
from datetime import date
from scipy.stats import linregress
from google.colab import drive
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm, colors as mcolors
from sklearn.covariance import EllipticEnvelope
import ipywidgets as widgets
from ipywidgets import interactive_output, HBox, VBox

# Mount Google Drive so the CSV exports under "My Drive" can be read (Colab only).
drive.mount('/content/drive')
# NOTE(review): this silences *every* warning for the rest of the session,
# including pandas SettingWithCopyWarning / FutureWarning that later cells may
# trigger. Consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
# Display every row and column of a DataFrame.
# NOTE(review): with max_rows=None any un-sliced display prints the full frame.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

Mounted at /content/drive


### Preprocessing

In [None]:
# Load the raw self-reported (subjective) export and immediately discard the
# columns this analysis does not use. A missing column raises KeyError.
subjective = pd.read_csv(
    '/content/drive/My Drive/coreway_ml/Thesis - Mika/Data/raw_subjective_data_2025-11-06.csv'
).drop(
    columns=[
        'hasOtherHealthProblems',
        'hrvMeasurement',
        'stressType',
        'otherSymptoms_x',
        'otherHealthProblems',
        'mentalStressLevel',
        'medication',
        'otherMedication_x',
        'otherSymptoms_y',
        'alcoholPortions',
        'notes',
        'diseaseRelapsesEvaluation',
        'generalCondition',
        'complications',
        'stoolsPerDay',
        'stoolsPerNight',
        'urgencyOfDefecation',
        'bloodInStool',
        'stomachPain',
        'unformedStoolsPerDay',
        'abdominalResistance',
        'terraUserId',
        'additionalIllnesses',
        'otherKnownCatalysts',
        'flaresPerYear',
        'hrvMeasurementMethod',
        'backendSymptoms',
        'hrvMeasurementMethodName',
        'hasAskedToConnectWearable',
        'hasConnectedWearable',
        'knownCatalysts',
        'backendKnownCatalysts',
        'connectedWearableName',
        'hasStoma',
        'otherAdditionalIllnesses',
    ]
)

# Helper Functions
def year_to_age(year_of_birth):
    """Convert a birth year to an age in whole years relative to today's year.

    Parameters
    ----------
    year_of_birth : Any
        Anything ``float()`` accepts (int, float, numeric string).

    Returns
    -------
    int or None
        ``current_year - year`` truncated to int, or ``None`` when the input
        cannot be converted, is NaN/inf, is <= 0, or lies in the future.
    """
    this_year = date.today().year
    try:
        birth_year = float(year_of_birth)
    except (ValueError, TypeError):
        return None
    # The chained comparison is False for NaN, so missing values fall out here too.
    if not 0 < birth_year <= this_year:
        return None
    return int(this_year - birth_year)

# Mappings
# Ordinal encoding of the self-reported physical-activity duration buckets.
activity_mapping = dict(
    zero=0,
    below30min=1,
    below1h=2,
    below2h=3,
    below4h=4,
    below8h=5,
    above8h=6,
)

# Raw diagnosis labels -> short codes; labels not listed here become NaN.
diagnosis_mapping = {
    "colitisUlcerosa": "UC",
    "crohnsDisease": "CD",
    "Crohn's disease": "CD",
}

gender_mapping = {"female": "F", "male": "M"}

# Ordinal alcohol consumption; "Unsure" is treated as missing.
alcohol_mapping = {"No": 0, "A little": 1, "Yes": 2, "Unsure": np.nan}

# NOTE(review): "Unsure" maps to 0 (same as "No") here but to NaN in
# alcohol_mapping, and the literal string "False" maps to NaN. Confirm both
# choices are intended.
period_mapping = {"Yes": 1, "No": 0, "Unsure": 0, "False": np.nan}

# Normalise types and encodings. Rows without a mapped gender or diagnosis are
# dropped. Columns are then renamed to snake_case and restricted to a fixed order.
subjective = (
    subjective
    .assign(
        date=lambda d: pd.to_datetime(d['date'], format='mixed').dt.normalize(),
        age=lambda d: d['yearOfBirth'].apply(year_to_age),
        hasConsumedAlcoholInLast24Hours=lambda d: d['hasConsumedAlcoholInLast24Hours'].map(alcohol_mapping),
        activity_dur=lambda d: d['physicalEffort'].map(activity_mapping),
        diagnosis=lambda d: d['diagnosis'].map(diagnosis_mapping),
        isOnPeriod=lambda d: d['isOnPeriod'].map(period_mapping),
        gender=lambda d: d['gender'].map(gender_mapping),
    )
    .dropna(subset=["gender", "diagnosis"])
    .rename(columns={
        'userId': 'user_id',
        'sleepQualityDegree': 'sleep',
        'stressLevelDegree': 'stress',
        'physicalActivityExertionDegree': 'activity_deg',
        'symptomDegree': 'symptom_deg',
        'hasConsumedAlcoholInLast24Hours': 'alcohol_last_24h',
        'isOnPeriod': 'on_period',
        'rateAsFlare': 'rate_as_flare',
    })
    .loc[:, [
        'user_id',
        'date',
        'gender',
        'age',
        'diagnosis',
        'symptoms',
        'alcohol_last_24h',
        'on_period',
        'sleep',
        'stress',
        'activity_dur',
        'activity_deg',
        'symptom_deg',
        'rate_as_flare',
    ]]
)

subjective.head()

In [None]:
def fit_full_cosinor(hrv_values, sampling_interval_minutes=5.0):
    """Fit a single-component cosinor whose period equals the recording length.

    Model: ``y(t) = mesor + amplitude * cos(omega * t + acrophase)`` with
    ``omega = 2*pi / T`` held fixed. ``T`` is the duration spanned by all
    sample slots (``len(hrv_values) * sampling_interval_minutes`` minutes).

    Parameters
    ----------
    hrv_values : array-like of float
        Equally spaced samples. NaN marks a missing sample. A missing sample
        keeps its slot on the time axis, so gaps do not shift later samples
        in time.
    sampling_interval_minutes : float, default 5.0
        Spacing between consecutive samples, in minutes.

    Returns
    -------
    tuple of float
        ``(mesor, acrophase, amplitude, peak_time)``:
        - ``amplitude`` is >= 0;
        - ``acrophase`` is in radians, wrapped into ``[0, 2*pi)``;
        - ``peak_time`` is in hours, within ``[0, T)``.
        All four are NaN when fewer than 4 valid samples exist or the fit fails.
    """
    y_full = np.asarray(hrv_values, dtype=float).ravel()

    # Time axis in hours, built BEFORE dropping NaNs so gaps stay in place.
    dt = sampling_interval_minutes / 60.0
    t_full = np.arange(y_full.size, dtype=float) * dt
    valid = ~np.isnan(y_full)
    y = y_full[valid]
    t = t_full[valid]

    # 3 free parameters -> need at least 4 points.
    if y.size < 4:
        return np.nan, np.nan, np.nan, np.nan

    # Period = full recording length ("night length"), including missing slots.
    T = y_full.size * dt
    omega_fixed = 2.0 * np.pi / T

    def cosinor_model_fixed(t, mesor, amplitude, acrophase):
        return mesor + amplitude * np.cos(omega_fixed * t + acrophase)

    p0 = (y.mean(), (y.max() - y.min()) / 2.0, 0.0)

    try:
        params, _ = curve_fit(cosinor_model_fixed, t, y, p0=p0, maxfev=5000)
        mesor, amplitude, acrophase = params
    except Exception:
        # Deliberate best-effort: a non-converging night yields NaN features.
        return np.nan, np.nan, np.nan, np.nan

    # -A*cos(x) == A*cos(x + pi): enforce amplitude >= 0.
    if amplitude < 0:
        amplitude = -amplitude
        acrophase = acrophase + np.pi
    # Wrap in both branches so acrophase is comparable across nights.
    acrophase = acrophase % (2.0 * np.pi)

    # Peak of the cosine: omega*t + acrophase == 0 (mod 2*pi) -> t = -acrophase/omega.
    peak_time = (-acrophase / omega_fixed) % T

    return mesor, acrophase, amplitude, peak_time

# -------------------------------------------------------------------
# Load data
# -------------------------------------------------------------------
# Load the per-night summary export and the flattened time-series export.
raw_summary_wearable_data = pd.read_csv(
    '/content/drive/My Drive/coreway_ml/Thesis - Mika/Data/raw_summary_wearable_data_2025-09-26.csv'
)
raw_flattened_wearable_data = pd.read_csv(
    '/content/drive/My Drive/coreway_ml/Thesis - Mika/Data/raw_flattened_wearable_data_2025-09-26.csv'
)

# hrv_rmssd is stored as a Python-literal list string. literal_eval only
# evaluates literals, so it is safe on untrusted text. Non-strings (e.g. NaN)
# are passed through unchanged.
raw_flattened_wearable_data['hrv_rmssd'] = raw_flattened_wearable_data['hrv_rmssd'].apply(
    lambda cell: ast.literal_eval(cell) if isinstance(cell, str) else cell
)

# Keep "YYYY-MM-DDTHH:MM:SS" (first 19 characters) so both exports share an
# identical merge key; unparseable values become NaT.
for frame in (raw_summary_wearable_data, raw_flattened_wearable_data):
    frame['date'] = pd.to_datetime(frame['date'].str.slice(0, 19), errors='coerce')

objective = raw_summary_wearable_data.merge(
    raw_flattened_wearable_data,
    how='inner',
    on=['userId', 'date'],
)

# -------------------------------------------------------------------
# Use end_time's calendar date as sleep date, extract HH:MM from start/end
# -------------------------------------------------------------------
# The sleep date is the calendar date of end_time (its first 10 characters).
objective['date'] = pd.to_datetime(objective['end_time'].str.slice(0, 10), errors='coerce')

# Reduce start/end timestamps to their "HH:MM" part (first "THH:MM" match).
for time_col in ('start_time', 'end_time'):
    objective[time_col] = (
        objective[time_col]
        .astype('string')
        .str.extract(r'T(\d{2}:\d{2})', expand=False)
    )

# Drop columns not used downstream, then switch to short analysis names.
objective = objective.drop(
    columns=[
        'terra_user_id',
        'provider_y',
        'sleep_score',
        'delta_temperature',
        'user_max_hr_bpm',
        'on_demand_reading',
        'breaths_start_time',
        'breaths_end_time',
        'max_breaths_per_min',
        'min_breaths_per_min',
        'duration_in_bed_seconds',
        'num_REM_events',
        'duration_long_interruption_seconds',
        'duration_short_interruption_seconds',
        'num_out_of_bed_events',
        'num_wakeup_events',
        'sleep_latency_seconds',
        'wake_up_latency_seconds',
        'max_hr_bpm',
        'min_hr_bpm',
        'bpm_array_length',
        'timestamp_intervals_seconds_hrv_rmssd',
        'hrv_rmssd_array_length',
        'level',
        'timestamp_intervals_seconds_level',
        'level_array_length',
        'percentage_array_length',
        'breaths_per_min_array_length',
        'timestamp_intervals_seconds_hrv_sdnn',
        'timestamp_intervals_seconds_bpm',
        'timestamp_intervals_seconds_percentage',
        'timestamp_intervals_seconds_breaths_per_min',
        'hrv_sdnn_array_length',
    ]
).rename(
    columns={
        'userId': 'user_id',
        'provider_x': 'provider',
        'start_time': 'start',
        'end_time': 'end',
        'avg_hr_bpm': 'avg_bpm',
        'resting_hr_bpm': 'rhr',
        'sleep_efficiency': 'sleep_eff',
        'breaths_per_min': 'breaths',
        'avg_breaths_per_min': 'avg_breaths',
        'avg_saturation_percentage': 'avg_SpO2',
        'duration_REM_sleep_state_seconds': 'dur_REM',
        'duration_asleep_state_seconds': 'dur_asleep',
        'duration_deep_sleep_state_seconds': 'dur_deep',
        'duration_light_sleep_state_seconds': 'dur_light',
        'duration_awake_state_seconds': 'dur_awake',
        'percentage': 'SpO2',
    }
)

# -------------------------------------------------------------------
# Infer sleep length (HH:MM) and keep numeric duration in hours
# -------------------------------------------------------------------
# Sleep length is inferred from HH:MM clock times only. An end time earlier
# than the start time is therefore read as the sleep crossing midnight.
start_dt = pd.to_datetime(objective['start'], format='%H:%M', errors='coerce')
end_dt   = pd.to_datetime(objective['end'],   format='%H:%M', errors='coerce')

# 1 where the episode crosses midnight (comparisons involving NaT are False).
mask = (end_dt < start_dt).astype(int)
end_dt_corrected = end_dt + pd.to_timedelta(mask, unit="D")

diff = end_dt_corrected - start_dt
length_hours = diff.dt.total_seconds() / 3600.0

# "HH:MM" label. Minutes are held in the nullable Int64 dtype. Otherwise a
# single unparseable start/end (NaT) turns the whole column into floats and
# every label looks like "7.0:30.0". Such rows get a missing label instead.
total_minutes = (diff.dt.seconds // 60).astype('Int64')
objective['length'] = (
    (total_minutes // 60).astype('string').str.zfill(2) + ":" +
    (total_minutes % 60).astype('string').str.zfill(2)
)
objective['length_hours'] = length_hours

# -------------------------------------------------------------------
# Remove duplicates & aggregate; keep longest sleep per day
# -------------------------------------------------------------------
# Several rows can describe the same episode (same user/date/start/end).
# Collapse them, taking each column's first non-missing value (NaN if none).
# The pandas groupby default (dropna=True) also discards rows whose key
# columns are missing.
objective = objective.groupby(
    ['user_id', 'date', 'start', 'end'], as_index=False
).agg(
    lambda col: col[col.notna()].iloc[0] if col.notna().any() else np.nan
)

# Per user and date, keep only the longest remaining episode.
idx_longest = (
    objective
    .groupby(['user_id', 'date'])['length_hours']
    .idxmax()
)
objective = objective.loc[idx_longest].reset_index(drop=True)

# -------------------------------------------------------------------
# Convert durations from seconds to hours & compute sleep stage percentages
# -------------------------------------------------------------------
# Seconds -> hours for all stage durations.
objective[['dur_asleep', 'dur_REM', 'dur_deep', 'dur_light', 'dur_awake']] /= 3600.0

# Stage shares as a percentage of total time asleep (awake time is not in the
# denominator). A dur_asleep of 0 gives inf/NaN here; the range filters below
# drop those rows.
objective[['REM_pct', 'deep_pct', 'light_pct']] = (
    objective[['dur_REM', 'dur_deep', 'dur_light']]
    .div(objective['dur_asleep'], axis=0) * 100.0
)

# NOTE(review): assumes the raw sleep_eff is a 0-1 fraction (it is range-checked
# against 25..100 after scaling below). Confirm against the export.
objective['sleep_eff'] = objective['sleep_eff'] * 100.0

# -------------------------------------------------------------------
# Drop rows with implausible sleep metrics
# -------------------------------------------------------------------
# Nights with < 2 h asleep (or missing dur_asleep) are removed.
objective = objective[objective['dur_asleep'] >= 2]

# Each stage share must fall in its plausible range (inclusive). NaN fails.
objective = objective[
    objective['REM_pct'].between(0, 40) &
    objective['light_pct'].between(0, 90) &
    objective['deep_pct'].between(0, 40)
]

# The three shares should add up to ~100 %; allow +/-10 percentage points.
total_pct = objective['REM_pct'] + objective['light_pct'] + objective['deep_pct']
objective = objective[total_pct.between(90, 110)]

# Out-of-range efficiency is blanked rather than dropping the whole night.
# NOTE(review): this assigns into a boolean-filtered frame, which pandas may
# flag with SettingWithCopyWarning (hidden by the global warnings filter).
# A `.copy()` after filtering would make the intent explicit.
objective['sleep_eff'] = objective['sleep_eff'].where(
    objective['sleep_eff'].between(25, 100), np.nan
)

# -------------------------------------------------------------------
# Replace non-physiological values with NaNs
# -------------------------------------------------------------------
# Plausible value range per column (inclusive). Values outside become NaN;
# the row itself is kept.
for _col, (_low, _high) in {
    'avg_hrv_sdnn': (5, 300),
    'avg_hrv_rmssd': (5, 300),
    'avg_SpO2': (85, 100),
    'avg_bpm': (30, 150),
    'rhr': (30, 150),
}.items():
    objective[_col] = objective[_col].where(
        objective[_col].between(_low, _high), np.nan
    )

# -------------------------------------------------------------------
# Derive HRV features + cosinor features from hrv_rmssd (single pass)
# -------------------------------------------------------------------
def _rmssd_night_features(night):
    """Summary statistics and cosinor parameters for one night's RMSSD samples.

    ``night`` is one cell of ``objective['hrv_rmssd']``. Anything that is not
    a non-empty list / ndarray / Series (e.g. NaN for a missing series)
    yields nine NaNs.

    Returns a 9-tuple in this order:
    (std, cv, min, max, slope, mesor, acrophase, amplitude, peak_time).
    - ``cv`` is std / mean, or NaN when the mean is 0.
    - ``slope`` is the least-squares trend per sample index.
    - The cosinor parameters come from ``fit_full_cosinor`` at 5-minute spacing.
    """
    if not isinstance(night, (list, np.ndarray, pd.Series)) or len(night) == 0:
        return (np.nan,) * 9

    samples = np.asarray(night, dtype='float64')
    sd = samples.std()
    avg = samples.mean()
    trend = linregress(np.arange(samples.size), samples).slope if samples.size > 1 else np.nan
    mesor, acrophase, amplitude, peak_time = fit_full_cosinor(
        samples, sampling_interval_minutes=5.0
    )
    return (
        sd,
        np.nan if avg == 0 else sd / avg,
        samples.min(),
        samples.max(),
        trend,
        mesor,
        acrophase,
        amplitude,
        peak_time,
    )


# One value per night, in row order, for each feature.
std_list, cv_list, min_list, max_list, slope_list = [], [], [], [], []
mesor_list, acrophase_list, amplitude_list, peak_time_list = [], [], [], []
for night in objective['hrv_rmssd']:
    for bucket, value in zip(
        (std_list, cv_list, min_list, max_list, slope_list,
         mesor_list, acrophase_list, amplitude_list, peak_time_list),
        _rmssd_night_features(night),
    ):
        bucket.append(value)

objective = objective.assign(
    std_rmssd=std_list,
    cv_rmssd=cv_list,
    min_rmssd=min_list,
    max_rmssd=max_list,
    range_rmssd=lambda d: d['max_rmssd'] - d['min_rmssd'],
    slope_rmssd=slope_list,
    mesor_rmssd=mesor_list,
    acrophase_rmssd=acrophase_list,
    amplitude_rmssd=amplitude_list,
    peak_time_rmssd=peak_time_list,
)

# -------------------------------------------------------------------
# Reordering columns (now including cosinor features)
# -------------------------------------------------------------------
# Final column order: identifiers and timing, sleep architecture, HRV
# (including derived and cosinor features), then cardio-respiratory averages.
# Rows are sorted chronologically per user.
objective = (
    objective
    .loc[:, [
        'user_id',
        'date',
        'provider',
        'start',
        'end',
        'length',
        'length_hours',
        'sleep_eff',
        'dur_asleep',
        'dur_REM',
        'REM_pct',
        'dur_deep',
        'deep_pct',
        'dur_light',
        'light_pct',
        'dur_awake',
        'hrv_rmssd',
        'avg_hrv_rmssd',
        'std_rmssd',
        'cv_rmssd',
        'min_rmssd',
        'max_rmssd',
        'range_rmssd',
        'slope_rmssd',
        'mesor_rmssd',
        'acrophase_rmssd',
        'amplitude_rmssd',
        'peak_time_rmssd',
        'avg_hrv_sdnn',
        'avg_bpm',
        'rhr',
        'avg_SpO2',
        'avg_breaths',
    ]]
    .sort_values(by=['user_id', 'date', 'start', 'end'])
    .reset_index(drop=True)
)

objective.head()

In [None]:
# Full outer join on user and day. The `_merge` indicator column records
# whether a row came from both sources ('both'), only subjective
# ('left_only') or only objective ('right_only').
# NOTE(review): both 'date' keys must share a dtype. objective's comes from a
# date string; confirm subjective's is also timezone-naive.
merged = subjective.merge(
    objective,
    how='outer',
    on=['user_id', 'date'],
    indicator=True,
)

merged.head()

### Interactive Plot

In [None]:
# Ensure date is datetime
merged['date'] = pd.to_datetime(merged['date'])

# Map friendly names to _merge values
DATA_TYPE_MAP = {
    'multimodal': 'both',
    'objective': 'right_only',
    'subjective': 'left_only'
}

data_type_filter = widgets.ToggleButtons(
    options=[
        ('Multimodal', 'multimodal'),
        ('Objective', 'objective'),
        ('Subjective', 'subjective')
    ],
    value='multimodal',
    description='Data type:',
    disabled=False
)

min_days_filter = widgets.BoundedIntText(
    value=100,
    min=1,
    max=10000,
    step=10,
    description='Min days:',
    layout=widgets.Layout(width='160px')
)

def _get_filtered_users(data_type_value, min_days):
    """Return the sorted user_ids having at least ``min_days`` rows of one data type.

    ``data_type_value`` is a key of ``DATA_TYPE_MAP`` ('multimodal',
    'objective' or 'subjective'). It selects rows of ``merged`` by their
    ``_merge`` indicator value. Rows with a missing user_id are not counted.
    """
    selected_ids = merged.loc[merged['_merge'] == DATA_TYPE_MAP[data_type_value], 'user_id']
    rows_per_user = selected_ids.value_counts()
    return sorted(rows_per_user[rows_per_user >= min_days].index.tolist())

# ----------------------------
# Widgets
# ----------------------------

# Users eligible under the filters' initial values; kept in sync later by
# _update_user_options.
initial_users = _get_filtered_users(data_type_filter.value, min_days_filter.value)
# One toggle button per eligible user_id.
user_selector = widgets.ToggleButtons(
    options=initial_users,
    description='user_id:',
    disabled=False,
)

def _update_user_options(change=None):
    """Refresh the user selector after a filter widget changes.

    Recomputes the eligible users for the current data-type / min-days filters.
    The previous selection is kept when it is still eligible; otherwise the
    first eligible user is selected. With no eligible users, the options are
    cleared.

    Parameters
    ----------
    change : dict, optional
        Traitlets change notification passed by ``observe``; unused.
    """
    eligible = _get_filtered_users(data_type_filter.value, min_days_filter.value)
    if not eligible:
        user_selector.options = []
        return

    previous = user_selector.value
    user_selector.options = eligible
    # Assigning new options to an ipywidgets selection resets it to the first
    # option, so restore the previous choice explicitly when it is still valid.
    user_selector.value = previous if previous in eligible else eligible[0]

# Attach observers
data_type_filter.observe(_update_user_options, names='value')
min_days_filter.observe(_update_user_options, names='value')

feature_options = [
    'sleep', 'stress', 'activity_dur', 'activity_deg',
    'sleep_eff', 'dur_asleep', 'dur_REM', 'REM_pct', 'dur_deep',
    'deep_pct', 'dur_light', 'light_pct', 'dur_awake', 'avg_hrv_rmssd',
    'std_rmssd', 'cv_rmssd', 'min_rmssd', 'max_rmssd', 'range_rmssd',
    'slope_rmssd', 'mesor_rmssd', 'acrophase_rmssd', 'amplitude_rmssd',
    'peak_time_rmssd', 'avg_hrv_sdnn', 'avg_bpm', 'rhr', 'avg_SpO2',
    'avg_breaths'
]

feature_selector = widgets.SelectMultiple(
    options=feature_options,
    value=('stress', 'REM_pct', 'deep_pct', 'light_pct', 'rhr'),
    description='Features',
    layout=widgets.Layout(width='250px', height='220px')
)

smooth_checkbox = widgets.Checkbox(
    value=True,
    description='Apply smoothing',
    indent=False
)

smooth_window = widgets.BoundedIntText(
    value=30,
    min=1,
    max=60,
    step=1,
    description='Window:',
    layout=widgets.Layout(width='150px')
)

layout_selector = widgets.ToggleButtons(
    options=[('Overlay', 'overlay'), ('Stacked', 'stacked')],
    value='stacked',
    description='Layout:',
    disabled=False
)

# Plot style selector
plot_style_selector = widgets.ToggleButtons(
    options=[
        ('Line + Scatter', 'line_scatter'),
        ('Line only', 'line'),
        ('Scatter only', 'scatter')
    ],
    value='line_scatter',
    description='Style:',
    disabled=False
)

# Variability band controls
variability_checkbox = widgets.Checkbox(
    value=True,
    description='Show variability band',
    indent=False
)

variability_std_text = widgets.FloatText(
    value=1.0,
    description='Std band:',
    layout=widgets.Layout(width='150px')
)

# -------------------------
# Anomaly detection widgets
# -------------------------
# Which detector to run; the values match the visibility switch in
# update_param_visibility ('elliptic', 'std', 'cusum_two').
anomaly_method_selector = widgets.ToggleButtons(
    options=[
        ('None', 'none'),
        ('EllipticEnvelope', 'elliptic'),
        ('StdDev-based', 'std'),
        ('CuSum (two-sided)', 'cusum_two')
    ],
    value='none',
    description='Anomaly:',
    disabled=False
)

# Shared window settings (always visible).
# NOTE(review): presumably passed as obs_window_days / det_window_days /
# min_obs to the anomaly_detection_* functions, which slice by row position,
# not calendar days. The call site is not visible here.
obs_window_widget = widgets.BoundedIntText(
    value=60,
    min=5,
    max=365,
    step=1,
    description='Obs win:',
    layout=widgets.Layout(width='160px')
)

det_window_widget = widgets.BoundedIntText(
    value=7,
    min=1,
    max=60,
    step=1,
    description='Det win:',
    layout=widgets.Layout(width='160px')
)

min_obs_widget = widgets.BoundedIntText(
    value=15,
    min=1,
    max=200,
    step=1,
    description='Min obs:',
    layout=widgets.Layout(width='160px')
)

# EllipticEnvelope params: contamination.
# NOTE(review): FloatText is unbounded, but sklearn's EllipticEnvelope only
# accepts contamination in (0, 0.5].
contamination_text = widgets.FloatText(
    value=0.1,
    description='Contam:',
    layout=widgets.Layout(width='200px')
)

# StdDev-based params: std_factor (|x - mean| > std_factor * std is anomalous).
std_factor_text = widgets.FloatText(
    value=2.0,
    description='Std factor:',
    layout=widgets.Layout(width='200px')
)

# CuSum two-sided params. alpha sets the (1 - alpha) quantile used for the thresholds.
alpha_widget = widgets.FloatText(
    value=0.01,
    description='alpha:',
    layout=widgets.Layout(width='160px')
)

# Short -> short_term=True (0.9 baseline quantile); Long -> 0.99 quantile.
short_term_widget = widgets.ToggleButtons(
    options=[('Short', True), ('Long', False)],
    value=True,
    description='Mode:',
    disabled=False
)

# Alarm share of the detection window needed for the 'fraction' logic.
alarm_fraction_widget = widgets.FloatText(
    value=0.3,
    description='Alarm frac:',
    layout=widgets.Layout(width='200px')
)

detection_logic_widget = widgets.ToggleButtons(
    options=[('Fraction', 'fraction'), ('Consecutive', 'consecutive')],
    value='fraction',
    description='Logic:',
    disabled=False
)

# Group anomaly-specific widgets into boxes for conditional visibility
elliptic_params_box = VBox([widgets.HTML("<b>EllipticEnvelope params</b>"), contamination_text])
std_params_box = VBox([widgets.HTML("<b>StdDev params</b>"), std_factor_text])
cusum_params_box = VBox([
    widgets.HTML("<b>CuSum params</b>"),
    alpha_widget,
    short_term_widget,
    alarm_fraction_widget,
    detection_logic_widget
])

def _set_visible(box, visible: bool):
    box.layout.display = '' if visible else 'none'

def update_param_visibility(change=None):
    """Show only the parameter box that belongs to the selected anomaly method.

    ``change`` is the traitlets notification from ``observe``; unused.
    """
    selected = anomaly_method_selector.value
    for method_key, box in (
        ('elliptic', elliptic_params_box),
        ('std', std_params_box),
        ('cusum_two', cusum_params_box),
    ):
        _set_visible(box, selected == method_key)

# Re-evaluate whenever the method changes, and once now for the initial state.
anomaly_method_selector.observe(update_param_visibility, names='value')
update_param_visibility()

# Canonical symptom order. Parsed symptom names not in this list are ignored
# when plotting.
SYMPTOM_ORDER = [
    'flatulence', 'vomiting', 'fever', 'constipation', 'lackOfAppetite',
    'bloodInStool', 'nausea', 'jointPain', 'diarrhea', 'stomachPain',
    'tiredness'
]

# Severity colour scale on 0..5 (reversed red-yellow-green: low = green, high = red).
# NOTE(review): presumably applied to symptom_deg; the use is inside plot_user,
# which is not fully visible.
norm_symptom = mcolors.Normalize(vmin=0, vmax=5)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap`
# is the supported lookup. Copy before mutating so the registered map is untouched.
cmap_symptom = plt.get_cmap('RdYlGn_r').copy()
cmap_symptom.set_bad(color=(0, 0, 0, 0))   # missing values fully transparent

# Presence/absence scale on 0..1: light grey -> mid grey.
norm_symptom_presence = mcolors.Normalize(vmin=0, vmax=1)
cmap_symptom_presence = LinearSegmentedColormap.from_list(
    "symptom_presence_grey",
    ["#eeeeee", "#7f7f7f"],
    N=256
)
cmap_symptom_presence.set_bad(color=(0, 0, 0, 0))

def parse_symptoms(val):
    """Parse one ``symptoms`` cell into a list of symptom names.

    Accepted inputs:
    - a list / tuple / set: its elements, stringified and stripped, with
      empty ones dropped;
    - a missing scalar (NaN / None / NA): ``None``;
    - a string: one trailing ``"|"`` is removed. The string is then parsed as
      a Python literal list/tuple/set if possible; otherwise it is split on
      commas, and spaces, brackets and quotes are stripped from each token.

    Returns
    -------
    list[str] or None
        ``None`` only for missing values; otherwise a (possibly empty) list.

    NOTE(review): text such as "nausea|fever" is not split on the inner "|".
    Confirm only a trailing "|" occurs in the data.
    """
    # Containers must be checked before pd.isna(). On a list, pd.isna returns
    # an element-wise array, and `if array:` raises "truth value is ambiguous".
    if isinstance(val, (list, tuple, set)):
        return [str(x).strip() for x in val if str(x).strip()]

    if pd.isna(val):
        return None

    text = str(val).strip()
    if text.endswith("|"):
        text = text[:-1]

    # literal_eval only evaluates Python literals (safe on untrusted text).
    # Any parse failure deliberately falls back to comma splitting.
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        parsed = None
    if isinstance(parsed, (list, tuple, set)):
        return [str(x).strip() for x in parsed if str(x).strip()]

    cleaned = (tok.strip(" []'\"") for tok in text.split(","))
    return [tok for tok in cleaned if tok]

def anomaly_detection_rolling(df, value_col, contamination=0.1,
                              obs_window_days=30, det_window_days=7, min_obs=10):
    """Rolling EllipticEnvelope outlier labelling of one column.

    Windows are defined by row position. For every start position, the next
    ``obs_window_days`` rows form the training window and the following
    ``det_window_days`` rows form the detection window. When the training
    window holds at least ``min_obs`` non-NaN values, an EllipticEnvelope is
    fitted on them. Its ``predict`` output (sklearn: 1 inlier, -1 outlier)
    is written to the non-NaN detection rows. Later windows overwrite labels
    from earlier overlapping ones.

    Returns a copy of ``df`` with an added ``anomaly`` column (NaN where never
    labelled).
    """
    out = df.copy()
    out['anomaly'] = np.nan

    n_windows = len(out) - obs_window_days - det_window_days + 1
    for start in range(n_windows):
        split = start + obs_window_days
        train_labels = out.index[start:split]
        train_ok = out.loc[train_labels, value_col].notna()
        if train_ok.sum() < min_obs:
            continue

        detector = EllipticEnvelope(
            contamination=contamination,
            random_state=42,
            support_fraction=1
        )
        detector.fit(out.loc[train_labels[train_ok], [value_col]])

        det_labels = out.index[split:split + det_window_days]
        scored = det_labels[out.loc[det_labels, value_col].notna()]
        if len(scored) > 0:
            out.loc[scored, 'anomaly'] = detector.predict(out.loc[scored, [value_col]])

    return out

def anomaly_detection_std(df, value_col, obs_window_days=30,
                          det_window_days=7, min_obs=10, std_factor=2.0):
    """Rolling mean +/- k*std outlier labelling of one column.

    Windows are defined by row position. For each start position, the
    baseline is the non-NaN values of the next ``obs_window_days`` rows; it is
    skipped when it has fewer than ``min_obs`` values or its (population)
    std is 0 / non-finite. Each non-NaN value in the following
    ``det_window_days`` rows is labelled:
    - -1 if ``|x - mean| > std_factor * std``;
    - 1 otherwise.
    Later windows overwrite earlier overlapping labels.

    Returns a copy of ``df`` with an added ``anomaly`` column (NaN where never
    labelled).
    """
    out = df.copy()
    out['anomaly'] = np.nan
    series = out[value_col].values

    last_start = len(series) - obs_window_days - det_window_days
    for start in range(last_start + 1):
        split = start + obs_window_days

        baseline = series[start:split]
        baseline = baseline[~np.isnan(baseline)]
        if len(baseline) < min_obs:
            continue

        centre = np.nanmean(baseline)
        spread = np.nanstd(baseline)
        if spread == 0 or not np.isfinite(spread):
            continue

        window = series[split:split + det_window_days]
        labels = out.index[split:split + det_window_days]
        present = ~np.isnan(window)
        outlying = np.abs(window - centre) > std_factor * spread

        out.loc[labels[present], 'anomaly'] = 1
        out.loc[labels[present & outlying], 'anomaly'] = -1

    return out

def anomaly_detection_cusum_two_sided(df, value_col, obs_window_days=60, det_window_days=7,
                                      short_term=True, alpha=0.01, min_obs=15,
                                      alarm_fraction=0.3, detection_logic="fraction"):
    """Rolling two-sided CuSum change detection on one column.

    Windows are defined by row position. For each start position, a baseline
    of ``obs_window_days`` rows is followed by a detection window of
    ``det_window_days`` rows. From the non-NaN baseline values (skipped when
    fewer than ``min_obs``):

    - reference ``k = 0.5 * q``, where ``q`` is the 0.9 quantile
      (``short_term``) or the 0.99 quantile;
    - the CuSum recursions below are run over the baseline, and their
      ``1 - alpha`` quantiles become the alarm thresholds.

    The same recursions run over the detection window, starting from 0 before
    its first sample:

        S+ = max(0, S+ + (x - k)),   S- = max(0, S- + (-x - k))

    A NaN sample carries the previous S forward. The paths are stored in
    ``cusum_pos`` / ``cusum_neg``. The whole detection window is labelled -1
    (anomalous) or 1 (normal) according to ``detection_logic``:

    - ``"fraction"``: (#pos alarms + #neg alarms) / det_window_days >= alarm_fraction;
    - ``"consecutive"``: at least two *adjacent* alarm positions on the same side.

    Any other ``detection_logic`` leaves ``anomaly`` unset. Later windows
    overwrite labels from earlier overlapping ones.

    NOTE(review): for non-negative data, (-x - k) <= 0 whenever k >= 0, so S-
    stays at 0 and the negative side never alarms. Consider centring on a
    baseline mean if downward shifts should be detected.

    Returns a copy of ``df`` sorted by index, with added ``cusum_pos``,
    ``cusum_neg`` and ``anomaly`` columns (NaN where never evaluated).
    """
    df = df.copy()
    df['cusum_pos'] = np.nan
    df['cusum_neg'] = np.nan
    df['anomaly'] = np.nan
    df = df.sort_index()

    values = df[value_col].values
    n = len(values)

    def _cusum_paths(xs, k):
        # Upper/lower CuSum paths; the first sample is included (starting from 0).
        s_pos = np.zeros(len(xs))
        s_neg = np.zeros(len(xs))
        prev_pos = prev_neg = 0.0
        for i, x in enumerate(xs):
            if not np.isnan(x):
                prev_pos = max(0.0, prev_pos + (x - k))
                prev_neg = max(0.0, prev_neg + (-x - k))
            s_pos[i] = prev_pos
            s_neg[i] = prev_neg
        return s_pos, s_neg

    def _has_adjacent(alarm_idx):
        # True when two alarm positions are next to each other.
        return bool(np.any(np.diff(alarm_idx) == 1))

    for start in range(n - obs_window_days - det_window_days + 1):
        obs_slice = slice(start, start + obs_window_days)
        det_slice = slice(start + obs_window_days,
                          start + obs_window_days + det_window_days)

        obs_window = values[obs_slice]
        obs_window = obs_window[~np.isnan(obs_window)]
        if len(obs_window) < min_obs:
            continue

        q = np.nanquantile(obs_window, 0.9 if short_term else 0.99)
        k = 0.5 * q

        # Null (baseline) CuSum distribution -> alarm thresholds.
        null_pos, null_neg = _cusum_paths(obs_window, k)
        threshold_pos = np.nanquantile(null_pos, 1 - alpha)
        threshold_neg = np.nanquantile(null_neg, 1 - alpha)

        S_pos, S_neg = _cusum_paths(values[det_slice], k)

        det_index = df.index[det_slice]
        df.loc[det_index, 'cusum_pos'] = S_pos
        df.loc[det_index, 'cusum_neg'] = S_neg

        alarm_pos = np.flatnonzero(S_pos > threshold_pos)
        alarm_neg = np.flatnonzero(S_neg > threshold_neg)

        if detection_logic == "consecutive":
            is_anomalous = _has_adjacent(alarm_pos) or _has_adjacent(alarm_neg)
        elif detection_logic == "fraction":
            alarm_ratio = (len(alarm_pos) + len(alarm_neg)) / det_window_days
            is_anomalous = alarm_ratio >= alarm_fraction
        else:
            continue

        df.loc[det_index, 'anomaly'] = -1 if is_anomalous else 1

    return df


# ----------------------------
# Plotting helpers
# ----------------------------
def _first_non_null(df, col):
    """Return the first non-missing value of ``df[col]``, or "NA" if the column is absent or all missing."""
    if col in df.columns and df[col].notna().any():
        return df[col].dropna().iloc[0]
    return "NA"


def _day_grid(dates_num):
    """Place each row on a one-column-per-day grid.

    ``imshow`` spreads its columns evenly over ``extent``. Putting every row
    at its own day offset (missing days stay NaN) keeps heatmap cells under
    their real dates. They then line up with the feature plots that share
    the x-axis, even when the user skipped days.

    Parameters
    ----------
    dates_num : array-like of float
        Matplotlib date numbers (in days) for each row, as returned by
        ``mdates.date2num``.

    Returns
    -------
    positions : ndarray of int
        Column index of each row on the daily grid.
    n_days : int
        Width of the grid (number of days spanned, inclusive).
    """
    dates_num = np.asarray(dates_num, dtype=float)
    positions = np.rint(dates_num - dates_num.min()).astype(int)
    return positions, int(positions.max()) + 1


def _to_day_grid(values, positions, n_days):
    """Scatter per-row values onto a NaN-filled daily grid.

    ``values`` is either 1-D (one value per row) or 2-D (one column per row).
    The result has shape ``(n_series, n_days)``. If two rows fall on the same
    day, only one of them is kept.
    """
    values = np.atleast_2d(np.asarray(values, dtype=float))
    grid = np.full((values.shape[0], n_days), np.nan)
    grid[:, positions] = values
    return grid


def _symptom_presence_matrix(parsed_symptoms_per_row, all_symptoms, smooth_window):
    """Build the (n_symptoms, n_rows) presence matrix.

    Cells are 1.0 when the symptom was listed and 0.0 when it was not. A
    whole column is NaN where ``parse_symptoms`` returned None. When
    ``smooth_window`` is > 1, a row-based rolling mean of that width is
    applied.
    """
    sym_idx = {s: i for i, s in enumerate(all_symptoms)}
    presence = np.full((len(all_symptoms), len(parsed_symptoms_per_row)), np.nan)

    for col, items in enumerate(parsed_symptoms_per_row):
        if items is None:
            continue
        presence[:, col] = 0.0
        for s in items:
            if s in sym_idx:
                presence[sym_idx[s], col] = 1.0

    if smooth_window > 1:
        smoothed = pd.DataFrame(presence.T, columns=all_symptoms)
        presence = smoothed.rolling(window=smooth_window, min_periods=1).mean().to_numpy().T
    return presence


def _draw_symptom_heat(ax, presence, all_symptoms, positions, n_days, x_extent):
    """Draw the symptom-presence heatmap, or a placeholder when no symptoms are present."""
    n_sym = len(all_symptoms)
    if n_sym == 0:
        ax.text(0.5, 0.5, "No selected symptoms",
                ha='center', va='center', transform=ax.transAxes)
        ax.set_yticks([])
        return

    ax.imshow(
        _to_day_grid(presence, positions, n_days),
        aspect="auto",
        interpolation="none",
        origin="lower",
        extent=[x_extent[0], x_extent[1], -0.5, n_sym - .5],
        cmap=cmap_symptom_presence,
        norm=norm_symptom_presence
    )
    ax.set_yticks(range(n_sym))
    ax.set_yticklabels(all_symptoms, fontsize=11)
    ax.set_title("Symptoms (True / False)")
    ax.tick_params(axis='x', labelbottom=False)


def _draw_strip(ax, values, positions, n_days, x_extent, ylabel, title,
                labelbottom, title_fontsize=None):
    """Draw a single-row heatmap strip using the ``cmap_symptom``/``norm_symptom`` colour scale."""
    ax.imshow(
        _to_day_grid(values, positions, n_days),
        aspect="auto",
        interpolation="none",
        origin="lower",
        extent=[x_extent[0], x_extent[1], -0.5, 0.5],
        cmap=cmap_symptom,
        norm=norm_symptom
    )
    ax.set_yticks([0])
    ax.set_yticklabels([ylabel], fontsize=11)
    # Only pass fontsize when given, so the rcParams title size applies otherwise.
    if title_fontsize is None:
        ax.set_title(title)
    else:
        ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='x', labelbottom=labelbottom)


def _prepare_feature_data(df, features, smooth_window):
    """Return ``(data, data_std)`` for the selected feature columns.

    ``data`` holds ``date`` plus each feature coerced to numeric. When
    ``smooth_window`` is > 1, the feature columns are replaced by their
    row-based rolling mean, and ``data_std`` holds the matching rolling std.
    Otherwise ``data_std`` is None. Both are None when ``features`` is empty.
    """
    if not features:
        return None, None

    data = df[["date"] + features].copy()
    for f in features:
        data[f] = pd.to_numeric(data[f], errors="coerce")

    data_std = None
    if smooth_window > 1:
        rolling_obj = data[features].rolling(window=smooth_window, min_periods=1)
        data_std = rolling_obj.std()
        data[features] = rolling_obj.mean()
    return data, data_std


def _compute_anomalies(data, features, anomaly_method, contamination, std_factor,
                       obs_window_days, det_window_days, min_obs, alpha,
                       short_term, alarm_fraction, detection_logic):
    """Run the selected anomaly detector on each feature.

    Returns a dict ``feature -> (dates, values)`` holding the rows whose
    ``anomaly`` column equals -1. Features with no flagged rows are omitted.
    The dict is empty for ``"none"``, for no features, or for an unrecognised
    method.
    """
    anomalies = {}
    if anomaly_method == "none" or data is None or not features:
        return anomalies

    for feat in features:
        df_feat = (
            data[["date", feat]].copy()
            .set_index("date")
            .rename(columns={feat: "value"})
        )

        if anomaly_method == "elliptic":
            res = anomaly_detection_rolling(
                df_feat, value_col="value",
                contamination=contamination,
                obs_window_days=obs_window_days,
                det_window_days=det_window_days,
                min_obs=min_obs
            )
        elif anomaly_method == "std":
            res = anomaly_detection_std(
                df_feat, value_col="value",
                obs_window_days=obs_window_days,
                det_window_days=det_window_days,
                min_obs=min_obs,
                std_factor=std_factor
            )
        elif anomaly_method == "cusum_two":
            res = anomaly_detection_cusum_two_sided(
                df_feat, value_col="value",
                obs_window_days=obs_window_days,
                det_window_days=det_window_days,
                short_term=short_term,
                alpha=alpha,
                min_obs=min_obs,
                alarm_fraction=alarm_fraction,
                detection_logic=detection_logic
            )
        else:
            continue

        mask = res["anomaly"] == -1
        if mask.any():
            anomalies[feat] = (res.index[mask], res.loc[mask, "value"])
    return anomalies


def _draw_features_overlay(ax_feat, data, features, anomalies, plot_style,
                           band_std, variability_std, user_title):
    """Draw every feature on one panel in its own colour.

    Each feature after the first gets its own twin y-axis. Band and anomaly
    legend entries are added only once. ``band_std`` is the rolling-std
    frame, or None to skip the variability band.
    """
    color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    x = data["date"]
    axes_list = [ax_feat]
    band_labelled = False
    anomaly_labelled = False

    for i, feat in enumerate(features):
        if i == 0:
            ax_i = ax_feat
        else:
            ax_i = ax_feat.twinx()
            axes_list.append(ax_i)
            # Push each additional y-axis further right so the spines don't overlap.
            ax_i.spines["right"].set_position(("outward", 40 * (i - 1)))

        color = color_cycle[i % len(color_cycle)]
        y = data[feat]

        if band_std is not None and feat in band_std.columns:
            y_std = band_std[feat]
            band_label = None if band_labelled else f"±{variability_std} std"
            band_labelled = True
            ax_i.fill_between(
                x,
                y - variability_std * y_std,
                y + variability_std * y_std,
                alpha=0.15,
                color=color,
                label=band_label
            )

        if plot_style in ("line", "line_scatter"):
            ax_i.plot(x, y, color=color, label=feat)
        if plot_style in ("scatter", "line_scatter"):
            ax_i.scatter(x, y, color=color, s=20)

        # anomalies (red dots)
        if feat in anomalies:
            anom_x, anom_y = anomalies[feat]
            ax_i.scatter(anom_x, anom_y, color="red", s=50, zorder=5,
                         label=None if anomaly_labelled else "Anomaly")
            anomaly_labelled = True

        ax_i.set_ylabel(feat, color=color, fontsize=12, fontweight="bold")
        ax_i.tick_params(axis="y", labelcolor=color)

    ax_feat.set_title(user_title, fontweight='bold')
    ax_feat.grid(True)
    ax_feat.tick_params(axis='x', labelbottom=True)

    # Twin axes keep separate legend entries; merge them into a single legend.
    handles, labels = [], []
    for a in axes_list:
        h, lb = a.get_legend_handles_labels()
        handles += h
        labels += lb
    if handles:
        ax_feat.legend(handles, labels, loc='best')


def _draw_features_stacked(feature_axes, data, features, anomalies, plot_style,
                           band_std, variability_std, user_title):
    """Draw each feature on its own panel in a single colour.

    The user title goes on the top panel. ``band_std`` is the rolling-std
    frame, or None to skip the variability band.
    """
    base_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
    x = data["date"]
    last = len(feature_axes) - 1

    for i, (ax_f, feat) in enumerate(zip(feature_axes, features)):
        y = data[feat]

        if band_std is not None and feat in band_std.columns:
            y_std = band_std[feat]
            ax_f.fill_between(
                x,
                y - variability_std * y_std,
                y + variability_std * y_std,
                alpha=0.15,
                color=base_color
            )

        if plot_style in ("line", "line_scatter"):
            ax_f.plot(x, y, color=base_color)
        if plot_style in ("scatter", "line_scatter"):
            ax_f.scatter(x, y, color=base_color, s=20)

        # anomalies (red dots)
        if feat in anomalies:
            anom_x, anom_y = anomalies[feat]
            ax_f.scatter(anom_x, anom_y, color="red", s=50, zorder=5)

        ax_f.set_ylabel(feat, fontsize=12, fontweight="bold")
        ax_f.grid(True)

        if i == 0:
            ax_f.set_title(user_title, fontweight='bold')
        if i < last:
            ax_f.tick_params(axis='x', labelbottom=False)


# ----------------------------
# Main plotting function
# ----------------------------
def plot_user(user_id, features, smooth, window, layout_mode, plot_style,
              anomaly_method,
              contamination, std_factor,
              obs_window_days, det_window_days, min_obs,
              alpha, short_term, alarm_fraction, detection_logic,
              variability_band, variability_std):
    """Plot one user's symptoms, selected features and flare ratings on a shared date axis.

    Rows come from the global ``merged`` frame, filtered on ``user_id`` and
    sorted by ``date``. From top to bottom the figure shows a symptom-presence
    heatmap, a ``symptom_deg`` strip, the feature panel(s), and a
    ``rate_as_flare`` strip. All heatmap strips are laid out one column per
    calendar day, so they stay aligned with the feature plots when dates
    have gaps.

    Parameters
    ----------
    user_id : value compared against ``merged["user_id"]``.
    features : iterable of column names in ``merged`` to plot.
    smooth, window : when ``smooth`` is true and ``int(window) > 1``, a
        row-based rolling mean of that width is applied to the features,
        ``symptom_deg``, flare ratings and symptom presence.
    layout_mode : ``"overlay"`` draws every feature on one panel (one twin
        y-axis per extra feature); any other value draws one panel per
        feature. With no features the overlay layout is always used.
    plot_style : ``"line"``, ``"scatter"`` or ``"line_scatter"``.
    anomaly_method : ``"none"``, ``"elliptic"``, ``"std"`` or
        ``"cusum_two"``; flagged points are drawn as red dots.
    contamination, std_factor, obs_window_days, det_window_days, min_obs,
    alpha, short_term, alarm_fraction, detection_logic :
        forwarded to the selected anomaly detector.
    variability_band, variability_std : when smoothing is active and
        ``variability_band`` is true, shade ``mean ± variability_std * rolling std``.
    """
    features = list(features)
    df = merged[merged["user_id"] == user_id].sort_values("date").reset_index(drop=True)

    if df.empty:
        fig, ax = plt.subplots(figsize=(16, 8))
        ax.text(0.5, 0.5, f"No data for user {user_id}",
                ha='center', va='center', transform=ax.transAxes)
        ax.axis("off")
        plt.show()
        return

    n_dates = len(df)
    w = int(window)
    # Effective smoothing width: 1 means "no smoothing".
    smooth_window = w if (smooth and w > 1) else 1

    uid_str = str(user_id)
    uid_short = uid_str[:7] + "..." if len(uid_str) > 7 else uid_str
    user_title = (
        f"User: {uid_short}, Age: {_first_non_null(df, 'age')}, "
        f"Gender: {_first_non_null(df, 'gender')}, "
        f"Diagnosis: {_first_non_null(df, 'diagnosis')}"
    )

    # Parse symptoms; keep only known symptoms, in SYMPTOM_ORDER order.
    parsed_symptoms_per_row = [parse_symptoms(val) for val in df["symptoms"]]
    present_symptoms = {
        s
        for items in parsed_symptoms_per_row if items is not None
        for s in items if s in SYMPTOM_ORDER
    }
    all_symptoms = [s for s in SYMPTOM_ORDER if s in present_symptoms]
    n_sym = len(all_symptoms)

    # Date geometry shared by all heatmap strips.
    dates_num = mdates.date2num(df["date"])
    x_extent = [dates_num.min() - .5, dates_num.max() + .5]
    positions, n_days = _day_grid(dates_num)

    # symptom_deg
    if "symptom_deg" in df.columns:
        deg_series = pd.to_numeric(df["symptom_deg"], errors="coerce")
    else:
        deg_series = pd.Series([np.nan] * n_dates, index=df.index)

    # flare: map the categorical answer onto the symptom colour scale.
    flare_map = {"No": 1.0, "Unsure": 3.0, "Yes": 4.0}
    if "rate_as_flare" in df.columns:
        flare_numeric = np.array(
            [np.nan if pd.isna(v) else flare_map.get(v, np.nan) for v in df["rate_as_flare"]],
            dtype=float,
        )
    else:
        flare_numeric = np.full(n_dates, np.nan)

    if smooth_window > 1:
        deg_series = deg_series.rolling(window=smooth_window, min_periods=1).mean()
        flare_numeric = (
            pd.Series(flare_numeric, index=df.index)
            .rolling(window=smooth_window, min_periods=1)
            .mean()
            .to_numpy()
        )
    deg_values = deg_series.to_numpy(dtype=float, na_value=np.nan)

    # Feature data and anomalies
    data, data_std = _prepare_feature_data(df, features, smooth_window)
    band_std = data_std if variability_band else None
    anomalies = _compute_anomalies(
        data, features, anomaly_method, contamination, std_factor,
        obs_window_days, det_window_days, min_obs, alpha,
        short_term, alarm_fraction, detection_logic
    )

    presence = (
        _symptom_presence_matrix(parsed_symptoms_per_row, all_symptoms, smooth_window)
        if n_sym > 0 else None
    )

    deg_height = 0.3
    flare_height = 0.3
    heat_height = 0.3 * n_sym if n_sym > 0 else 1.5

    overlay = layout_mode == "overlay" or not features
    if overlay:
        series_height = 7.0
        height_ratios = [heat_height, deg_height, series_height, flare_height]
        fig, (ax_heat, ax_deg, ax_feat, ax_flare) = plt.subplots(
            4, 1,
            sharex=True,
            figsize=(16, sum(height_ratios)),
            gridspec_kw={"height_ratios": height_ratios}
        )
    else:
        feature_height = 2.5
        height_ratios = ([heat_height, deg_height]
                         + [feature_height] * len(features)
                         + [flare_height])
        fig, axes = plt.subplots(
            len(height_ratios), 1,
            sharex=True,
            figsize=(16, sum(height_ratios)),
            gridspec_kw={"height_ratios": height_ratios}
        )
        ax_heat, ax_deg, ax_flare = axes[0], axes[1], axes[-1]
        feature_axes = axes[2:-1]

    _draw_symptom_heat(ax_heat, presence, all_symptoms, positions, n_days, x_extent)
    _draw_strip(ax_deg, deg_values, positions, n_days, x_extent,
                ylabel="symptom_deg", title="Symptom Degree (0-5)",
                labelbottom=False)

    if overlay:
        if not features:
            ax_feat.text(0.5, 0.5, "No feature selected",
                         ha='center', va='center', transform=ax_feat.transAxes)
        else:
            _draw_features_overlay(ax_feat, data, features, anomalies, plot_style,
                                   band_std, variability_std, user_title)
        # The feature panel carries the date labels in overlay mode; when it is
        # empty, show them on the bottom (flare) strip so the x-axis is labelled.
        flare_labelbottom = not features
    else:
        _draw_features_stacked(feature_axes, data, features, anomalies, plot_style,
                               band_std, variability_std, user_title)
        flare_labelbottom = True

    _draw_strip(ax_flare, flare_numeric, positions, n_days, x_extent,
                ylabel="rate_as_flare",
                title='Subjective Flare Assessment ("Yes" / "Unsure" / "No")',
                labelbottom=flare_labelbottom, title_fontsize=12)

    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------
# Bind every control widget to the matching plot_user argument,
# then render the control panel above the live plot output.
# ------------------------------------------------------------
out = interactive_output(plot_user, {
    "user_id": user_selector,
    "features": feature_selector,
    "smooth": smooth_checkbox,
    "window": smooth_window,
    "layout_mode": layout_selector,
    "plot_style": plot_style_selector,
    "anomaly_method": anomaly_method_selector,
    "contamination": contamination_text,
    "std_factor": std_factor_text,
    "obs_window_days": obs_window_widget,
    "det_window_days": det_window_widget,
    "min_obs": min_obs_widget,
    "alpha": alpha_widget,
    "short_term": short_term_widget,
    "alarm_fraction": alarm_fraction_widget,
    "detection_logic": detection_logic_widget,
    "variability_band": variability_checkbox,
    "variability_std": variability_std_text,
})

# Layout: four control columns side by side (user filters, feature picker,
# smoothing/display options, anomaly settings) with the plot output below.
display(VBox([
    HBox([
        VBox([user_selector, data_type_filter, min_days_filter]),
        feature_selector,
        VBox([smooth_checkbox, smooth_window,
              variability_checkbox, variability_std_text,
              layout_selector, plot_style_selector]),
        VBox([anomaly_method_selector,
              obs_window_widget, det_window_widget, min_obs_widget,
              elliptic_params_box, std_params_box, cusum_params_box]),
    ]),
    out,
]))