# Taper Transition Analysis

In [None]:
import os, math
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt

In [None]:
# Directory that is scanned for the per-patient CSV tables read by this notebook.
TABLES_DIRECTORY = "../../Data/giant_tables"
# NOTE(review): not referenced by any code visible in this section; presumably a
# medication-taper cutoff used elsewhere — confirm before relying on it.
TAPER_THRESHOLD = 0.5

In [None]:
# Collect the numeric patient id of every table in TABLES_DIRECTORY.
# Only non-hidden .csv files are considered; the id is the text between the
# first "_" and the following "." of the file name.
patient_hup_ids = sorted(
    int(table_name.split("_")[1].split(".")[0])
    for table_name in os.listdir(TABLES_DIRECTORY)
    if table_name.endswith(".csv") and not table_name.startswith(".")
)
len(patient_hup_ids)

## Plotting function

In [None]:
def plot_stuff(hourly_patient_features_df, before_taper_period, after_taper_period):
    """Show a 13-panel, shared-x overview of one patient's hourly features.

    Panels, top to bottom:
      0     every ``med_*`` column that is not all zero (except the
            ``med_sum_no_lorazepam_raw`` total), each divided by its own maximum,
            with the before/after taper windows drawn as bars at y = 1
      1     ``med_sum_no_lorazepam_raw``
      2     ``spikes_sum_all``
      3-7   ``teager_energy_<band>`` for delta, theta, alpha, beta, gamma
      8-12  ``kuramoto_<band>`` for the same bands

    Hours with ``num_seizures >= 1`` are marked with dotted red vertical lines
    on every panel except panel 1.

    Args:
        hourly_patient_features_df: Hourly feature table; must contain
            ``emu_hour``, ``num_seizures`` and the columns listed above.
        before_taper_period: ``(start, end)`` of the before-taper window, in
            ``emu_hour`` units.
        after_taper_period: ``(start, end)`` of the after-taper window, in
            ``emu_hour`` units.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    df = hourly_patient_features_df
    hours = df["emu_hour"]

    # Medication columns that carry at least one non-zero value.
    active_med_cols = [
        name
        for name in df.columns
        if name.startswith("med_") and not df[name].eq(0).all()
    ]

    _fig, axes = plt.subplots(13, 1, figsize=(10, 20), sharex=True)
    meds_ax, total_ax, spikes_ax = axes[0], axes[1], axes[2]

    seizure_hours = df.loc[df["num_seizures"] >= 1, "emu_hour"].values

    def mark_seizures(panel):
        # One dotted red line per hour with at least one seizure.
        for hour in seizure_hours:
            panel.axvline(x=hour, color="red", linestyle="dotted")

    # --- Panel 0: individual medications, each scaled by its own maximum ---
    for name in active_med_cols:
        if name == "med_sum_no_lorazepam_raw":
            continue  # the total gets its own panel below
        load = df[name]
        meds_ax.plot(
            hours,
            load / load.max(),
            label=name.split("_raw")[0].replace("med_", ""),
        )

    taper_windows = (
        (before_taper_period, "green", "before taper"),
        (after_taper_period, "red", "after taper"),
    )
    # Horizontal bars first, then their captions, so artists are added in the
    # same order for both windows.
    for window, colour, _caption in taper_windows:
        meds_ax.hlines(1, window[0], window[1], color=colour, linewidth=2)
    for window, _colour, caption in taper_windows:
        meds_ax.text(
            (window[0] + window[1]) / 2, 1, caption, ha="center", va="bottom"
        )

    mark_seizures(meds_ax)
    meds_ax.set_ylabel("Normalized Load")
    meds_ax.set_ylim([0, 1.3])  # headroom above 1 for the bars and captions
    meds_ax.legend(loc="upper right")
    meds_ax.set_title("Individual AEDs")

    # --- Panel 1: summed medication load (no seizure markers here) ---
    total_ax.plot(hours, df["med_sum_no_lorazepam_raw"])
    total_ax.set_ylabel("Total AED")
    total_ax.set_title("Total AED")

    # --- Panel 2: spike counts ---
    spikes_ax.plot(hours, df["spikes_sum_all"])
    mark_seizures(spikes_ax)
    spikes_ax.set_ylabel("Total Spikes")
    spikes_ax.set_title("All Spikes")

    bands = ["delta", "theta", "alpha", "beta", "gamma"]

    # --- Panels 3-7: Teager energy per frequency band ---
    for panel, band in zip(axes[3:8], bands):
        panel.plot(hours, df["teager_energy_" + band])
        mark_seizures(panel)
        panel.set_ylabel("Teager Energy")
        panel.set_title(f"Teager Energy ({band.capitalize()} Band)")

    # --- Panels 8-12: Kuramoto synchrony per frequency band ---
    for panel, band in zip(axes[8:13], bands):
        panel.plot(hours, df["kuramoto_" + band])
        mark_seizures(panel)
        panel.set_ylabel("R")
        panel.set_title(f"Synchrony ({band.capitalize()} Band)")

    axes[12].set_xlabel("Time (hours)")

    plt.tight_layout()
    plt.show()

## Find taper period function

In [None]:
def find_taper_periods(hourly_patient_features_df, patient_hup_id):
    """Pick "before" and "after" taper windows for one patient.

    Both windows are expressed in ``emu_hour`` units. The "before" window starts
    at the first hour where ``teager_energy_delta`` and ``kuramoto_delta`` are
    both present; the "after" window has the same length and follows a gap.

    Window/gap combinations are tried from most to least preferred: 2-day
    windows with a 3-, 2-, then 1-day gap, followed by 1-day windows with a
    2-, then 1-day gap. A combination is abandoned for the next one when its
    "after" window contains an hour with ``num_seizures > 0``. Whatever
    combination is kept is then discarded altogether if any seizure hour lies
    between the start of the "before" window and the end of the "after" window.
    All window bounds are inclusive.

    Args:
        hourly_patient_features_df: Hourly table with at least the columns
            ``emu_hour``, ``num_seizures``, ``teager_energy_delta`` and
            ``kuramoto_delta``.
        patient_hup_id: Patient identifier. Kept for interface compatibility;
            the selection itself does not use it.

    Returns:
        A 4-tuple ``(before_taper_period, after_taper_period,
        period_length_days, gap_length_days)`` where each period is a
        ``(start_hour, end_hour)`` pair. When the patient is discarded the
        result is ``((nan, nan), (nan, nan), None, None)``. When no hour has
        both features present, the start hour is NaN, so every bound is NaN
        while the day counts are those of the first combination (2, 3).
    """
    # First hour at which both features are available.
    features_present = (
        hourly_patient_features_df["teager_energy_delta"].notna()
        & hourly_patient_features_df["kuramoto_delta"].notna()
    )
    start_point = hourly_patient_features_df.loc[features_present, "emu_hour"].min()

    # Hours with at least one seizure; every window test reuses this series.
    seizure_hours = hourly_patient_features_df.loc[
        hourly_patient_features_df["num_seizures"] > 0, "emu_hour"
    ]

    def contains_seizure(first_hour, last_hour):
        # True when a seizure hour lies in [first_hour, last_hour].
        in_window = (seizure_hours >= first_hour) & (seizure_hours <= last_hour)
        return bool(in_window.any())

    def build_periods(period_days, gap_days):
        # Translate a (window length, gap length) pair in days into hour ranges.
        period_hours = 24 * period_days
        after_start = start_point + period_hours + 24 * gap_days
        return (
            (start_point, start_point + period_hours),
            (after_start, after_start + period_hours),
        )

    # (period_length_days, gap_length_days), most preferred first.
    candidates = [(2, 3), (2, 2), (2, 1), (1, 2), (1, 1)]

    period_length_days, gap_length_days = candidates[0]
    before_taper_period, after_taper_period = build_periods(
        period_length_days, gap_length_days
    )
    for fallback in candidates[1:]:
        if not contains_seizure(*after_taper_period):
            break
        print(
            f"{period_length_days} day periods with {gap_length_days} day gap"
            " contains seizures"
        )
        # The reported day counts always describe the windows actually built.
        period_length_days, gap_length_days = fallback
        before_taper_period, after_taper_period = build_periods(*fallback)

    # Final screen over the whole span (before window, gap and after window).
    # This is also the only test applied to the last candidate's after window.
    if contains_seizure(before_taper_period[0], after_taper_period[1]):
        # NOTE(review): this message is printed for whichever combination was
        # kept, not only for the 1-day/1-day one; wording left unchanged.
        print("1 day periods with 1 day gap contains seizures")
        print("Last straw, both periods contain seizures, discard!")
        return (np.nan, np.nan), (np.nan, np.nan), None, None

    return before_taper_period, after_taper_period, period_length_days, gap_length_days

## Plot all time series

In [None]:
# Screen every patient: plot and remember those with usable taper windows.
good_hup_ids = []

for patient_hup_id in patient_hup_ids:
    # Load this patient's hourly feature table.
    table_path = os.path.join(TABLES_DIRECTORY, f"HUP_{str(patient_hup_id)}.csv")
    hourly_patient_features_df = pd.read_csv(table_path)

    taper_result = find_taper_periods(hourly_patient_features_df, patient_hup_id)
    (
        before_taper_period,
        after_taper_period,
        period_length_days,
        gap_length_days,
    ) = taper_result

    # A NaN bound in either window means no usable periods were found.
    window_bounds = (
        before_taper_period[0],
        before_taper_period[1],
        after_taper_period[0],
        after_taper_period[1],
    )
    if any(np.isnan(bound) for bound in window_bounds):
        print(
            f"Patient {patient_hup_id} has nan values in before_taper_period or after_taper_period"
        )
    else:
        print(
            f"Patient {patient_hup_id} before taper period: {before_taper_period} after taper period: {after_taper_period}"
        )
        plot_stuff(hourly_patient_features_df, before_taper_period, after_taper_period)
        good_hup_ids.append(patient_hup_id)

In [None]:
# Number of patients whose taper windows had no NaN bounds and were plotted.
len(good_hup_ids)