# This notebook
1. Selects the patient, trial, and signal variable to plot.
2. Loads the corresponding trial DataFrame.
3. Offers preprocessing/plotting options: Raw, Normalized (Z-score), Downsampled, Segmented, and Segmented (normalized).


In [2]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plotly
import plotly.graph_objs as go
import plotly.io as pio
from matplotlib.widgets import Slider
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from Data_loader import base_folders  
from pipeline import segment_downsamp 
from segment_utils import segment_cycles, segment_cycles_norm
from downsample import downsample_df

## Functions

In [3]:
# Helpers and function to plot one trial


def _trial_line_figure(traces, title, xaxis_title, yaxis_title):
    """Return a Plotly figure holding *traces* with the shared range-slider layout."""
    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)
    fig.update_layout(
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        xaxis=dict(rangeslider=dict(visible=True), type="linear"),
    )
    return fig


def _emit_trial_figure(fig, results_folder, fname):
    """Write *fig* as HTML to ``results_folder/fname`` (if a folder is given), then show it."""
    if results_folder is not None:
        pio.write_html(fig, file=os.path.join(results_folder, fname), auto_open=False)
    fig.show()


def _trial_cycles_figure(cycles, signal_col, title):
    """Return a figure overlaying ``signal_col`` of every cycle DataFrame in *cycles*."""
    traces = [
        go.Scatter(y=cycle[signal_col].values, mode='lines', name=f'Cycle {number}', opacity=0.5)
        for number, cycle in enumerate(cycles, start=1)
    ]
    return _trial_line_figure(traces, title, 'Sample (resampled)', 'Angle (deg)')


def plot_trial(
    patient_id,
    group_code,
    day,
    block,
    trial,
    # Literal defaults: these previously read notebook globals at definition
    # time, which raised NameError on a fresh top-to-bottom run because the
    # globals are only assigned in the "Usage" cells further down.
    signal_col="Ankle Dorsiflexion RT (deg)",
    min_length=20,
    downsample_factor=4,
    show_raw=False,
    show_normalized=False,
    show_downsampled=False,
    show_segmented=False,
    show_segmented_norm=False,
    save_plots=False,
    results_dir="Results visualization",
):
    """Load one trial CSV and show interactive Plotly views of one signal column.

    The CSV is read from
    ``<base_folders[group_code]>/<patient_id>/<patient_id>_<group_code>_<day>_<block>_<trial>.csv``.
    Each enabled ``show_*`` flag produces one figure. When ``save_plots`` is
    true, each figure is also written as an HTML file under
    ``<current working directory>/<results_dir>``.

    Args:
        patient_id, group_code, day, block, trial: Identifiers used to build
            the CSV path (e.g. "S013", "G03", "D01", "B01", "T02").
        signal_col: CSV column to plot.
        min_length: Not used by this function; kept for call-site compatibility.
        downsample_factor: Passed as ``rate`` to ``downsample_df``.
        show_raw: Plot the interpolated signal.
        show_normalized: Plot the Z-score normalized signal.
        show_downsampled: Plot the output of ``downsample_df``.
        show_segmented: Overlay the cycles returned by ``segment_cycles``.
        show_segmented_norm: Overlay the cycles returned by ``segment_cycles_norm``.
        save_plots: Also save every figure as HTML.
        results_dir: Output folder, relative to the current working directory.

    Returns:
        None. If the CSV does not exist, an error is printed and nothing is plotted.
    """
    results_folder = None
    if save_plots:
        results_folder = os.path.join(os.getcwd(), results_dir)
        os.makedirs(results_folder, exist_ok=True)

    base_folder = os.path.abspath(base_folders[group_code])
    trial_id = f"{patient_id}_{group_code}_{day}_{block}_{trial}"
    path = os.path.join(base_folder, patient_id, f"{trial_id}.csv")
    print(f"[INFO] File Path: {path}")
    if not os.path.exists(path):
        print(f"[ERROR] File not found: {path}")
        return

    df = pd.read_csv(path)
    # Linear interpolation fills interior gaps only; leading NaNs remain.
    signal = df[signal_col].interpolate().values

    if show_raw:
        fig = _trial_line_figure(
            [go.Scatter(y=signal, mode='lines', name='Raw', line=dict(color='gray'))],
            f"{trial_id} | {signal_col} (Raw)",
            'Sample',
            'Angle (deg)',
        )
        _emit_trial_figure(fig, results_folder, f"{trial_id}_{signal_col}_raw.html")

    if show_normalized:
        # NaN-aware statistics: leading NaNs left by interpolate() would
        # otherwise make mean/std NaN and wipe out the whole normalized trace.
        mean = np.nanmean(signal)
        std = np.nanstd(signal)
        if std == 0:
            print("[WARN] Standard deviation is zero; normalization skipped.")
            signal_norm = signal - mean  # just mean centering
        else:
            signal_norm = (signal - mean) / std
        fig = _trial_line_figure(
            [go.Scatter(y=signal_norm, mode='lines', name='Normalized (Z-score)')],
            f"{trial_id} | {signal_col} (Normalized Z-score)",
            'Sample',
            'Z-score normalized Angle',
        )
        _emit_trial_figure(fig, results_folder, f"{trial_id}_{signal_col}_normalized_zscore.html")

    if show_downsampled:
        try:
            signal_ds_df = downsample_df(df, rate=downsample_factor, cols=[signal_col], zero_phase=False)
            signal_ds = signal_ds_df[signal_col].values
        except Exception as e:  # best effort: warn and keep producing the other views
            print(f"[WARN] Downsampling failed: {e}")
        else:
            fig = _trial_line_figure(
                [go.Scatter(y=signal_ds, mode='lines', name='Downsampled', line=dict(color='orange'))],
                f"{trial_id} | {signal_col} (Downsampled)",
                'Sample',
                'Angle (deg)',
            )
            _emit_trial_figure(fig, results_folder, f"{trial_id}_{signal_col}_downsampled.html")

    # (enabled flag, segmentation function, title label, file-name suffix)
    segment_views = (
        (show_segmented, segment_cycles, "Segmented", "segmented"),
        (show_segmented_norm, segment_cycles_norm, "Segmented_Norm", "segmented_norm"),
    )
    for enabled, segmenter, label, suffix in segment_views:
        if not enabled:
            continue
        cycles = segmenter(df)
        if not cycles:
            print("Not enough cycles found for segmentation.")
            continue
        fig = _trial_cycles_figure(cycles, signal_col, f"{trial_id} | {signal_col} ({label})")
        _emit_trial_figure(fig, results_folder, f"{trial_id}_{signal_col}_{suffix}.html")


In [29]:
# Function to plot multiple trials for comparison
def plot_multiple_trials(
    patient_id,
    group_code,
    trials_info,
    # Literal defaults (matching plot_trial): these previously read notebook
    # globals at definition time, which raised NameError on a fresh run.
    signal_col="Ankle Dorsiflexion RT (deg)",
    downsample_factor=4,
    show_raw=False,
    show_downsampled=False,
    show_segmented=False,
    save_plots=False,
    results_dir="Results visualization",
):
    """Overlay several trials of one patient on shared figures for comparison.

    One figure is built per enabled view (raw, downsampled, segmented). Each
    trial found on disk gets its own color, cycling through a fixed palette.

    Args:
        patient_id: Patient folder / file prefix (e.g. "S027").
        group_code: Key into ``base_folders`` (e.g. "G01").
        trials_info: Iterable of dicts with "day", "block" and "trial" keys.
            Trials whose CSV is missing are reported and skipped, and they do
            not consume a palette color.
        signal_col: CSV column to plot.
        downsample_factor: Passed as ``rate`` to ``downsample_df``.
        show_raw, show_downsampled, show_segmented: Which comparison figures to build.
        save_plots: Also write each figure as HTML under ``<cwd>/<results_dir>``.
        results_dir: Output folder, relative to the current working directory.
    """
    results_folder = None
    if save_plots:
        results_folder = os.path.join(os.getcwd(), results_dir)
        os.makedirs(results_folder, exist_ok=True)

    # One figure per enabled view; insertion order fixes layout/save/show order.
    figs = {}
    if show_raw:
        figs["raw"] = go.Figure()
    if show_downsampled:
        figs["downsampled"] = go.Figure()
    if show_segmented:
        figs["segmented"] = go.Figure()

    colors = ['gray', 'orange', 'blue', 'green', 'red', 'purple', 'cyan', 'magenta', 'yellow']
    # Loop-invariant: every trial of this patient lives in the same folder.
    patient_folder = os.path.join(os.path.abspath(base_folders[group_code]), patient_id)
    plotted = 0  # trials actually loaded; drives the color choice

    for trial_info in trials_info:
        day, block, trial = trial_info['day'], trial_info['block'], trial_info['trial']
        path = os.path.join(patient_folder, f"{patient_id}_{group_code}_{day}_{block}_{trial}.csv")
        if not os.path.exists(path):
            print(f"[ERROR] File not found: {path}")
            continue

        df = pd.read_csv(path)
        signal = df[signal_col].interpolate().values
        color = colors[plotted % len(colors)]
        plotted += 1

        if show_raw:
            figs["raw"].add_trace(go.Scatter(
                y=signal, mode='lines', name=f'Raw Day {day}, Block {block}, Trial {trial}',
                line=dict(color=color, dash='solid'),
            ))

        if show_downsampled:
            try:
                signal_ds_df = downsample_df(df, rate=downsample_factor, cols=[signal_col], zero_phase=False)
                signal_ds = signal_ds_df[signal_col].values
            except Exception as e:  # best effort: warn and keep plotting other trials
                print(f"[WARN] Downsampling failed for trial {trial}: {e}")
            else:
                figs["downsampled"].add_trace(go.Scatter(
                    y=signal_ds, mode='lines', name=f'DownSp Day {day}, Block {block}, Trial {trial}',
                    line=dict(color=color, dash='dot'),
                ))

        if show_segmented:
            cycles = segment_cycles(df)
            if cycles:
                seg_label = f'Seg Day {day}, Block {block}, Trial {trial}'
                for cycle_idx, cycle in enumerate(cycles):
                    figs["segmented"].add_trace(go.Scatter(
                        y=cycle[signal_col].values,
                        mode='lines',
                        name=seg_label,
                        line=dict(color=color),
                        # opacity is a trace property; scatter.line has no 'opacity'
                        # (putting it there makes Plotly raise ValueError).
                        opacity=0.5,
                        # Group all cycles of a trial under one legend entry.
                        legendgroup=seg_label,
                        showlegend=(cycle_idx == 0),
                    ))

    # view key -> (title heading, x-axis title)
    view_specs = {
        "raw": ("Raw Signal", 'Sample'),
        "downsampled": ("Downsampled Signal", 'Sample'),
        "segmented": ("Segmented Cycles", 'Sample (resampled)'),
    }
    for view, fig in figs.items():
        heading, xaxis_title = view_specs[view]
        fig.update_layout(
            title=f"{heading} Comparison for Patient {patient_id} | {signal_col}",
            xaxis_title=xaxis_title,
            yaxis_title='Angle (deg)',
            xaxis=dict(rangeslider=dict(visible=True), type='linear'),
            # Horizontal legend centered below the plot area.
            legend=dict(orientation='h', yanchor='bottom', y=-0.2, xanchor='center', x=0.5),
        )

    if results_folder is not None:
        for view, fig in figs.items():
            fname = f"{patient_id}_{group_code}_comparison_{signal_col}_{view}.html"
            pio.write_html(fig, file=os.path.join(results_folder, fname), auto_open=False)

    for fig in figs.values():
        fig.show()

## Usage

In [10]:
# Individual Trial visualization
signal_col = "Ankle Dorsiflexion RT (deg)"
group_code = "G03"
patient_id = "S013"
day = "D01"
block = "B01"
trial = "T02"
min_length = 20
downsample_factor = 4
show_raw = True
show_normalized = True
show_downsampled = False
show_segmented = True
show_segmented_norm = True
save_plots = False
results_dir = "Results visualization"  # output folder (relative to cwd) when save_plots is True


plot_trial(
    patient_id=patient_id,
    group_code=group_code,
    day=day,
    block=block,
    trial=trial,
    signal_col=signal_col,
    min_length=min_length,
    downsample_factor=downsample_factor,
    show_raw=show_raw,
    show_normalized=show_normalized,
    show_downsampled=show_downsampled,
    show_segmented=show_segmented,
    show_segmented_norm=show_segmented_norm,
    save_plots=save_plots,
    # Pass explicitly; otherwise editing results_dir above has no effect.
    results_dir=results_dir,
)

[INFO] File Path: /home/dmartinez/Gait-Stability/old adults (56+ years old)/S013/S013_G03_D01_B01_T02.csv


In [41]:
# To plot multiple trials for comparison
# batches: group code -> patient id -> list of trials to overlay.
batches = {
    "G01": {
        "S027": [
            {"day": "D01", "block": "B02", "trial": "T01"},
            {"day": "D02", "block": "B03", "trial": "T03"},
        ]
    },
}

signal_col = "Ankle Dorsiflexion RT (deg)"
downsample_factor = 4
show_raw = True
show_downsampled = False
show_segmented = False
save_plots = False
results_dir = "Results visualization"  # output folder (relative to cwd) when save_plots is True

for group_code, patients in batches.items():
    for patient_id, trials in patients.items():
        plot_multiple_trials(
            patient_id=patient_id,
            group_code=group_code,
            trials_info=trials,
            signal_col=signal_col,
            downsample_factor=downsample_factor,
            show_raw=show_raw,
            show_downsampled=show_downsampled,
            show_segmented=show_segmented,
            save_plots=save_plots,
            # Pass explicitly; otherwise editing results_dir above has no effect.
            results_dir=results_dir,
        )

In [None]:
# To compare results in separate plots: one plot_trial() call per listed trial.
signal_col        = "Ankle Dorsiflexion RT (deg)"
downsample_factor = 4
show_raw          = True
show_downsampled  = True
show_segmented    = True
save_plots        = False

# group code -> patient id -> trials to plot individually.
# Template for another group (uncomment to enable):
#     "G03": {"S139": [{"trial": "T01", "day": "D01", "block": "B01"}]},
batches = {
    "G01": {
        "S002": [
            {"trial": "T01", "day": "D01", "block": "B01"},
            {"trial": "T03", "day": "D01", "block": "B03"},
        ]
    },
}


for group_code, patients in batches.items():
    for patient_id, trials in patients.items():
        for trial_info in trials:
            plot_trial(
                patient_id=patient_id,
                group_code=group_code,
                # Forward the day/block/trial identifiers of this entry.
                **{key: trial_info[key] for key in ("day", "block", "trial")},
                signal_col=signal_col,
                downsample_factor=downsample_factor,
                show_raw=show_raw,
                show_downsampled=show_downsampled,
                show_segmented=show_segmented,
                save_plots=save_plots,
            )



[INFO] File Path: /home/dmartinez/Gait-Stability/young adults (19–35 years old)/S002/S002_G01_D01_B01_T01.csv


[INFO] File Path: /home/dmartinez/Gait-Stability/young adults (19–35 years old)/S002/S002_G01_D01_B03_T03.csv
