### Imports

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np
import ast
from ast import literal_eval
from datetime import date
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Show every column and every row when displaying DataFrames in the notebook.
pd.options.display.max_columns = None
pd.options.display.max_rows = None

Mounted at /content/drive


### Pipeline

In [None]:
class FlarePipeline:
    """Annotate flares in daily self-reported data and summarise each episode.

    The pipeline works on two long-format tables keyed by ``user_id`` and
    ``date``:

    * ``subjective`` must contain ``symptom_deg`` (numeric) and
      ``rate_as_flare`` (categorical; the value ``"Yes"`` marks a self-rated
      flare day).
    * ``objective`` must contain the columns in ``OBJECTIVE_COLS``. They are
      averaged over every detected flare episode.

    Steps (see :meth:`run`):

    1. expand each user to a continuous daily calendar, fill short gaps and
       wipe short isolated streaks of data;
    2. smooth ``symptom_deg`` (rolling mean) and ``rate_as_flare`` (rolling
       mode);
    3. derive flare flags, bridge nearby flare days and drop short flares;
    4. group consecutive flare days into episodes and average the objective
       metrics over each episode.

    Args:
        flare_threshold: smoothed ``symptom_deg`` at or above this value marks
            a symptom-flare day.
        gap_fill_max: runs of at most this many missing days with data on
            both sides are filled (first half from the value before the gap,
            second half from the value after it).
        island_min_len: streaks of non-missing data shorter than this many
            days are wiped.
        smoothing_window: trailing window length (rows, i.e. days after
            calendar expansion) used by both smoothers.
        flare_merge_gap: consecutive flare days fewer than this many days
            apart are bridged by marking every day between them as flare.
        min_flare_length: flare runs shorter than this many days are removed.
    """

    # Objective metrics averaged over every flare episode in _process_flares.
    OBJECTIVE_COLS = (
        'sleep_eff', 'dur_asleep', 'dur_REM', 'REM_pct', 'dur_deep', 'deep_pct', 'dur_light',
        'light_pct', 'dur_awake', 'hrv_rmssd', 'avg_hrv_rmssd', 'std_rmssd', 'cv_rmssd',
        'min_rmssd', 'max_rmssd', 'range_rmssd', 'slope_rmssd', 'hrv_sdnn', 'avg_hrv_sdnn',
        'avg_bpm', 'rhr', 'avg_SpO2', 'avg_breaths',
    )

    def __init__(
        self,
        flare_threshold: float = 2.5,
        gap_fill_max: int = 4,
        island_min_len: int = 14,
        smoothing_window: int = 7,
        flare_merge_gap: int = 7,
        min_flare_length: int = 3
    ):
        self.flare_threshold = flare_threshold
        self.gap_fill_max = gap_fill_max
        self.island_min_len = island_min_len
        self.smoothing_window = smoothing_window
        self.flare_merge_gap = flare_merge_gap
        self.min_flare_length = min_flare_length

    # ------------------------------------------------------------------
    # --- Step 1: Data continuity + gap filling ---
    # ------------------------------------------------------------------
    @staticmethod
    def _fill_user_data(group: pd.DataFrame) -> pd.DataFrame:
        """Expand one user's rows to a continuous daily calendar.

        Days without a record are added as rows where only ``user_id`` and
        ``date`` are set. A user with at most one distinct date is returned
        unchanged, except that ``date`` is parsed to datetime.
        """
        user_id = group['user_id'].iloc[0]
        group = group.copy()
        group['date'] = pd.to_datetime(group['date'])

        if group['date'].nunique() <= 1:
            return group

        full_dates = pd.date_range(group['date'].min(), group['date'].max(), freq='D')
        full_df = pd.DataFrame({'user_id': user_id, 'date': full_dates})
        # The left join keeps every calendar day. Days without a record get
        # NaN in all other columns.
        return pd.merge(full_df, group, on=['user_id', 'date'], how='left')

    def _fill_all_users(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply :meth:`_fill_user_data` to every user and stack the results.

        This uses an explicit loop rather than ``groupby(...).apply`` because
        ``_fill_user_data`` reads the ``user_id`` column. pandas >= 2.2
        deprecates passing the grouping column to ``apply``.
        """
        parts = [self._fill_user_data(group) for _, group in df.groupby('user_id')]
        if not parts:
            return df.copy()
        return pd.concat(parts)

    def _fill_small_gaps(self, df: pd.DataFrame, col: str) -> pd.DataFrame:
        """Reindex each user to a daily calendar and fill short gaps in ``col``.

        A run of at most ``gap_fill_max`` missing values with data on both
        sides is filled. The first ``ceil(gap / 2)`` days take the value
        before the gap and the remaining days take the value after it. Runs
        touching either end of a user's series, and longer runs, stay
        missing. Only ``col`` is filled; other columns keep their NaNs.

        Duplicate ``(user_id, date)`` rows are dropped (the first is kept).
        The result has a fresh RangeIndex and ``date`` as its first column.
        """
        df = df.copy()
        df['date'] = pd.to_datetime(df['date'])
        df = df.drop_duplicates(subset=['user_id', 'date'])

        filled = []
        for user, group in df.groupby('user_id'):
            group = group.sort_values('date').set_index('date')
            group = group[~group.index.duplicated(keep='first')]

            # Naming the new index 'date' makes reset_index() restore the column.
            full_range = pd.date_range(
                group.index.min(), group.index.max(), freq='D', name='date'
            )
            group = group.reindex(full_range)
            group['user_id'] = user

            col_idx = group.columns.get_loc(col)
            is_missing = group[col].isna().to_numpy()
            n_rows = len(is_missing)
            i = 0
            while i < n_rows:
                if not is_missing[i]:
                    i += 1
                    continue
                start = i
                while i < n_rows and is_missing[i]:
                    i += 1
                end = i  # exclusive end of the missing run
                gap_len = end - start

                # Only interior gaps (data on both sides) that are short enough.
                if start == 0 or end == n_rows or gap_len > self.gap_fill_max:
                    continue

                # An odd-length gap gives its extra day to the preceding value,
                # e.g. a 3-day gap becomes (prev, prev, next).
                split = start + (gap_len + 1) // 2
                group.iloc[start:split, col_idx] = group.iloc[start - 1, col_idx]
                if split < end:
                    group.iloc[split:end, col_idx] = group.iloc[end, col_idx]

            filled.append(group)

        if not filled:
            return df.reset_index(drop=True)
        return pd.concat(filled).reset_index()

    def _remove_small_islands(self, df: pd.DataFrame, col: str) -> pd.DataFrame:
        """Wipe short streaks of non-missing ``col`` values.

        For each user (rows sorted by date), consecutive rows where ``col`` is
        present form a streak. Every column except ``user_id`` and ``date`` is
        set to missing on streaks shorter than ``island_min_len`` rows.
        """
        wipe_cols = [c for c in df.columns if c not in ('user_id', 'date')]
        result = []
        for _, group in df.groupby('user_id'):
            group = group.sort_values('date').reset_index(drop=True)
            present = group[col].notna()
            as_int = present.astype(int)
            streak_id = as_int.ne(as_int.shift()).cumsum()
            streak_len = streak_id.groupby(streak_id).transform('size')
            short_island = present & (streak_len < self.island_min_len)
            group.loc[short_island, wipe_cols] = pd.NA
            result.append(group)
        if not result:
            return df.reset_index(drop=True)
        return pd.concat(result).reset_index(drop=True)

    # ------------------------------------------------------------------
    # --- Step 2: Smoothing ---
    # ------------------------------------------------------------------
    def _smooth_rate_as_flare(self, df: pd.DataFrame) -> pd.DataFrame:
        """Replace ``rate_as_flare`` with its trailing rolling mode per user.

        Each row takes the most common non-missing value among itself and the
        previous ``smoothing_window - 1`` rows of the same user. Ties go to
        the value seen first in the window. A row stays missing only when its
        whole window is missing, so missing rows just after data are filled.
        """
        df = df.copy()
        df['date'] = pd.to_datetime(df['date'])
        df = df.sort_values(['user_id', 'date'])
        window = self.smoothing_window

        def rolling_mode(series: pd.Series) -> pd.Series:
            values = []
            for i in range(len(series)):
                # .iloc: a group's index labels are not 0-based positions, so
                # the trailing window must be sliced positionally.
                window_vals = series.iloc[max(0, i - window + 1):i + 1].dropna()
                if len(window_vals) > 0:
                    values.append(Counter(window_vals).most_common(1)[0][0])
                else:
                    values.append(pd.NA)
            return pd.Series(values, index=series.index)

        df['rate_as_flare'] = (
            df.groupby('user_id')['rate_as_flare']
              .transform(rolling_mode)
        )
        return df

    # ------------------------------------------------------------------
    # --- Step 3: Flare Annotation ---
    # ------------------------------------------------------------------
    def _connect_flares(self, df: pd.DataFrame, flare_col: str) -> pd.DataFrame:
        """Bridge short breaks between flare days.

        For each user, when two consecutive True values of ``flare_col`` (in
        date order) are fewer than ``flare_merge_gap`` days apart, every row
        strictly between them, including missing ones, is set to True.
        """
        df = df.copy()
        df['date'] = pd.to_datetime(df['date'])

        filled = []
        for _, group in df.groupby("user_id"):
            group = group.sort_values("date").reset_index(drop=True)
            # Use '== True' rather than truthiness: the column mixes
            # True/False with pd.NA.
            flare_dates = group.loc[group[flare_col] == True, 'date']
            for start, end in zip(flare_dates.iloc[:-1], flare_dates.iloc[1:]):
                if (end - start).days < self.flare_merge_gap:
                    between = (group['date'] > start) & (group['date'] < end)
                    group.loc[between, flare_col] = True
            filled.append(group)
        return pd.concat(filled).reset_index(drop=True)

    def _remove_short_flares(self, df: pd.DataFrame, flare_col: str) -> pd.DataFrame:
        """Set flare runs shorter than ``min_flare_length`` rows to missing.

        A run is a maximal block of consecutive rows (by date, per user) with
        the same ``flare_col`` value. Only runs of True values are affected.
        """
        result = []
        for _, group in df.groupby("user_id"):
            group = group.sort_values("date").copy()
            run_id = (group[flare_col] != group[flare_col].shift()).cumsum()
            run_len = run_id.groupby(run_id).transform('size')
            # Strict '<' keeps a run of exactly min_flare_length days. This
            # matches the parameter name and the '<' in _remove_small_islands.
            too_short = (group[flare_col] == True) & (run_len < self.min_flare_length)
            group.loc[too_short, flare_col] = pd.NA
            result.append(group)
        return pd.concat(result).reset_index(drop=True)

    # ------------------------------------------------------------------
    # --- Step 4: Flare Grouping & Summary ---
    # ------------------------------------------------------------------
    @staticmethod
    def _assign_flare_groups(df: pd.DataFrame, flare_col: str) -> pd.DataFrame:
        """Label runs of consecutive flare days within one user's rows.

        Adds a float column ``flare_group``. Each run of rows where
        ``flare_col`` is True on consecutive calendar days shares one id
        (1, 2, ...). Non-flare rows get NaN. A new run starts after a
        non-flare row or after a jump of more than one day in ``date``.
        """
        df = df.sort_values("date").copy()
        is_flare = df[flare_col] == True
        prev_is_flare = df[flare_col].shift(fill_value=False) == True
        date_jump = df["date"].diff().dt.days > 1
        run_start = is_flare & (~prev_is_flare | date_jump)
        # .where leaves NaN on non-flare rows instead of writing pd.NA into
        # an integer column and relying on an implicit upcast.
        df["flare_group"] = run_start.cumsum().where(is_flare)
        return df

    def _process_flares(self, df: pd.DataFrame, objective: pd.DataFrame, flare_col: str) -> pd.DataFrame:
        """Summarise each flare episode and average objective metrics over it.

        Returns one row per episode with ``user_id``, ``date_flare_onset``,
        ``date_flare_end``, ``flare_length`` (days, inclusive) and the mean of
        each numeric ``OBJECTIVE_COLS`` column over the episode's dates for
        that user. Returns an empty DataFrame when there are no flares.
        """
        grouped_parts = [
            self._assign_flare_groups(group, flare_col=flare_col)
            for _, group in df.groupby("user_id")
        ]
        if not grouped_parts:
            return pd.DataFrame()
        df = pd.concat(grouped_parts)

        flares = (
            df.dropna(subset=["flare_group"])
              .groupby(["user_id", "flare_group"])
              .agg(
                  date_flare_onset=("date", "min"),
                  date_flare_end=("date", "max"),
                  flare_length=("date", lambda x: (x.max() - x.min()).days + 1),
              )
              .reset_index()
        )

        objective_cols = list(self.OBJECTIVE_COLS)

        # Split the objective table once per user instead of scanning the
        # whole table for every flare.
        objective_by_user = {uid: grp for uid, grp in objective.groupby("user_id")}
        no_rows = objective.iloc[0:0]

        flare_records = []
        for _, row in flares.iterrows():
            user_obj = objective_by_user.get(row["user_id"], no_rows)
            in_flare = user_obj["date"].between(row["date_flare_onset"], row["date_flare_end"])
            avg_values = user_obj.loc[in_flare, objective_cols].mean(numeric_only=True)

            flare_records.append({
                "user_id": row["user_id"],
                "date_flare_onset": row["date_flare_onset"],
                "date_flare_end": row["date_flare_end"],
                "flare_length": row["flare_length"],
                **avg_values.to_dict()
            })

        return pd.DataFrame(flare_records)

    # ------------------------------------------------------------------
    # --- Main runner ---
    # ------------------------------------------------------------------
    def run(self, subjective: pd.DataFrame, objective: pd.DataFrame):
        """Run the full pipeline.

        Args:
            subjective: daily self-reports with ``user_id``, ``date``,
                ``symptom_deg`` and ``rate_as_flare``.
            objective: daily metrics with ``user_id``, ``date`` and the
                ``OBJECTIVE_COLS`` columns.

        Returns:
            Tuple ``(subjective, objective, summary_symptom_flares,
            summary_regular_flares)``: the processed subjective table (with
            True/False/NA columns ``symptom_flare`` and ``flare``), the
            calendar-expanded objective table, and one summary row per
            symptom-based and per self-rated flare episode respectively.
        """
        # Step 1: continuous calendars, short-gap filling, island removal.
        subjective = self._fill_all_users(subjective)
        objective = self._fill_all_users(objective)

        subjective = self._fill_small_gaps(subjective, "symptom_deg")
        subjective = self._fill_small_gaps(subjective, "rate_as_flare")

        subjective = self._remove_small_islands(subjective, col="symptom_deg")
        subjective = self._remove_small_islands(subjective, col="rate_as_flare")

        # Step 2: smoothing.
        # NOTE(review): with min_periods=1 the rolling mean also gives a value
        # to missing days that have data within the trailing window. Confirm
        # that this spill-over is intended.
        subjective['symptom_deg'] = (
            subjective.groupby('user_id')['symptom_deg']
            .transform(lambda x: x.rolling(window=self.smoothing_window, min_periods=1).mean())
        )
        subjective = self._smooth_rate_as_flare(subjective)

        # Step 3: flare flags (pd.NA where the source value is missing).
        subjective['symptom_flare'] = np.where(
            subjective['symptom_deg'].notna(),
            subjective['symptom_deg'] >= self.flare_threshold,
            pd.NA
        )
        subjective['flare'] = np.where(
            subjective['rate_as_flare'].notna(),
            subjective['rate_as_flare'] == "Yes",
            pd.NA
        )

        subjective = self._connect_flares(subjective, flare_col="symptom_flare")
        subjective = self._connect_flares(subjective, flare_col="flare")

        subjective = self._remove_short_flares(subjective, flare_col="flare")
        subjective = self._remove_short_flares(subjective, flare_col="symptom_flare")

        subjective = subjective.sort_values(["user_id", "date"])
        objective = objective.sort_values(["user_id", "date"])

        # Step 4: per-episode summaries.
        summary_symptom_flares = self._process_flares(subjective, objective, flare_col="symptom_flare")
        summary_regular_flares = self._process_flares(subjective, objective, flare_col="flare")

        return subjective, objective, summary_symptom_flares, summary_regular_flares

In [None]:
# 1. Read the subjective and objective tables from Google Drive.
subjective = pd.read_csv(
    '/content/drive/My Drive/coreway_ml/Thesis - Mika/subjective.csv'
)
objective = pd.read_csv(
    '/content/drive/My Drive/coreway_ml/Thesis - Mika/objective.csv'
)

# Parse the date columns up front so the pipeline's date arithmetic works.
objective = objective.assign(date=pd.to_datetime(objective['date']))
subjective = subjective.assign(date=pd.to_datetime(subjective['date']))

# 2. Configure the pipeline (these values equal the constructor defaults).
pipeline = FlarePipeline(
    flare_threshold=2.5,   # smoothed symptom_deg at/above this marks a flare day
    gap_fill_max=4,        # longest run of missing days that gets filled
    island_min_len=14,     # shorter streaks of data are wiped
    smoothing_window=7,    # trailing smoothing window, in days
    flare_merge_gap=7,     # flare days closer than this are bridged
    min_flare_length=3,    # flare runs shorter than this are dropped
)

# 3. Execute and unpack the four outputs.
(
    subjective_proc,
    objective_proc,
    summary_symptom_flares,
    summary_regular_flares,
) = pipeline.run(subjective, objective)

### Save Data

In [None]:
# Persist the two flare summaries and the annotated daily table (no index column).
summary_symptom_flares.to_csv(
    path_or_buf='/content/drive/My Drive/coreway_ml/Thesis - Mika/summary_symptom_flares.csv',
    index=False,
)
summary_regular_flares.to_csv(
    path_or_buf='/content/drive/My Drive/coreway_ml/Thesis - Mika/summary_regular_flares.csv',
    index=False,
)
subjective_proc.to_csv(
    path_or_buf='/content/drive/My Drive/coreway_ml/Thesis - Mika/subjective_flare_annotated.csv',
    index=False,
)

### Plots

In [None]:
def plot_user_data(
    df: pd.DataFrame,
    value_col: str,
    max_users: int = 50,
    spacing_between_users: float = 3,
    cmap=None,
    norm=None,
    categorical_map: dict = None,
    markersize: int = 2
):
    """Plot one horizontal row of dots per user, coloured by ``value_col``.

    The ``max_users`` users with the most rows in ``df`` are drawn, from the
    fewest rows (bottom) to the most (top). A dot's x position is the row's
    position within that user's rows in ``df``. It equals the axis label
    "Days since first record" only when each user's rows are consecutive,
    date-sorted days, as in the pipeline's processed output.

    Colour of each dot (first matching rule wins):

    * missing value: fully transparent (the dot still counts for axis limits);
    * ``categorical_map`` given: ``categorical_map[value]``, transparent if
      the value is not a key;
    * ``cmap`` and ``norm`` given: ``cmap(norm(value))``;
    * otherwise: black.

    Args:
        df: long table with a ``user_id`` column and ``value_col``.
        value_col: column whose values determine the colours.
        max_users: number of users (those with the most rows) to show.
        spacing_between_users: vertical distance between user rows.
        cmap: colormap for continuous values (used together with ``norm``).
        norm: normaliser mapping values into ``cmap``'s 0-1 range.
        categorical_map: mapping from value to any Matplotlib colour spec.
        markersize: dot diameter in points.
    """
    # Sort users by number of records; keep the largest max_users.
    user_counts = df['user_id'].value_counts(ascending=True)
    sorted_users = user_counts.index.tolist()[-max_users:]

    transparent = (0, 0, 0, 0)

    def color_for(val):
        if pd.isna(val):
            return transparent
        if categorical_map is not None:
            return categorical_map.get(val, transparent)
        if cmap is not None and norm is not None:
            return cmap(norm(val))
        return (0, 0, 0, 1)  # fallback black

    xs, ys, colors = [], [], []
    for i, user in enumerate(sorted_users):
        values = df.loc[df['user_id'] == user, value_col].tolist()
        xs.extend(range(len(values)))
        ys.extend([i * spacing_between_users] * len(values))
        colors.extend(mcolors.to_rgba(color_for(v)) for v in values)

    # Floor the height so very small user counts still give a usable figure.
    height = max(len(sorted_users) * 0.18, 1.0)
    _, ax = plt.subplots(figsize=(14, height))

    if xs:
        # One scatter call replaces one ax.plot() artist per dot, which was
        # very slow for hundreds of users x hundreds of days.
        ax.scatter(
            xs,
            ys,
            c=np.asarray(colors),
            s=markersize ** 2,  # scatter sizes are areas in points^2
            marker='o',
            edgecolors='face',
            linewidths=plt.rcParams['lines.markeredgewidth'],
        )

    # Y-axis labels = user IDs
    ytick_positions = [i * spacing_between_users for i in range(len(sorted_users))]
    ax.set_yticks(ytick_positions)
    ax.set_yticklabels(sorted_users, fontsize=5)

    ax.set_ylim(-2, (len(sorted_users) - 1) * spacing_between_users + 2)
    ax.set_xlabel("Days since first record")
    ax.set_ylabel("User ID")

    plt.show()

#### Plots

In [None]:
# Continuous colour scale for symptom_deg on 0-5. The reversed red-yellow-green
# map draws higher values redder.
norm_symptom = mcolors.Normalize(vmin=0, vmax=5)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap_symptom = plt.get_cmap('RdYlGn_r')

plot_user_data(subjective_proc, value_col='symptom_deg', cmap=cmap_symptom, norm=norm_symptom)

In [None]:
# Flare days use the colour the 0-5 symptom scale assigns to 4; other days are light grey.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
flare_color = plt.get_cmap('RdYlGn_r')(mcolors.Normalize(vmin=0, vmax=5)(4))

plot_user_data(subjective_proc, value_col='symptom_flare', categorical_map={True: flare_color, False: "lightgray"})

In [None]:
norm_symptom = mcolors.Normalize(vmin=0, vmax=5)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap_symptom = plt.get_cmap('RdYlGn_r')

# Colour each answer by a point on the symptom scale ("Yes" reddest, "No" greenest).
categorical_map = {
    "Yes": cmap_symptom(norm_symptom(4)),
    "No": cmap_symptom(norm_symptom(1)),
    "Unsure": cmap_symptom(norm_symptom(3)),
    np.nan: (0, 0, 0, 0),
}

plot_user_data(subjective_proc, value_col='rate_as_flare', categorical_map=categorical_map)

In [None]:
# Self-rated flare days in the symptom-scale colour for 4; other days light grey.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
flare_color = plt.get_cmap('RdYlGn_r')(mcolors.Normalize(vmin=0, vmax=5)(4))

plot_user_data(subjective_proc, value_col='flare', categorical_map={True: flare_color, False: "lightgray"})

#### Comparison to Hirten: symptom_deg Flare

In [None]:
# Up to 300 users with tighter spacing and larger dots, for the Hirten comparison.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
flare_color = plt.get_cmap('RdYlGn_r')(mcolors.Normalize(vmin=0, vmax=5)(4))

plot_user_data(
    subjective_proc,
    value_col='symptom_flare',
    max_users=300,
    spacing_between_users=2,
    markersize=5,
    categorical_map={True: flare_color, False: "lightgray"},
)

#### Comparison to Hirten: rate_as_flare Flare

In [None]:
# Up to 300 users with tighter spacing and larger dots, for the Hirten comparison.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
flare_color = plt.get_cmap('RdYlGn_r')(mcolors.Normalize(vmin=0, vmax=5)(4))

plot_user_data(
    subjective_proc,
    value_col='flare',
    max_users=300,
    spacing_between_users=2,
    markersize=5,
    categorical_map={True: flare_color, False: "lightgray"},
)