# Gait trial visualization

This notebook:
1. Selects the patient, trial, and variable.
2. Loads the corresponding DataFrame (or pre-computed cycle tensor) and plots the signal.



In [1]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plotly
import plotly.graph_objs as go
import plotly.io as pio
from matplotlib.widgets import Slider
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from Data_loader import base_folders  
from pipeline import downsample_df, segment_downsamp 
from segment_utils import segment_cycles, segment_cycles_norm, segment_cycles_simple

## From Dataframes 

Preprocessing options:

- `version`:
    - `raw`: the original recording (48000 data points).
    - `downsampled`: the recording downsampled by a factor of 4.
    - `trimmed`: the downsampled recording with the first and last 10 s removed.
- `normalized`: applies Z-score normalization to the chosen version (raw, downsampled, or trimmed).
- `segmented`: splits the chosen version into gait cycles. It uses `segment_cycles_norm` when `normalized` is set and `segment_cycles_simple` otherwise.

### Functions

In [2]:
# Load one trial's signal in the requested preprocessing version.
def load_signal(patient_id, group_code, day, block, trial, signal_col, version, ds_rate=4):
    """Load a single trial CSV and extract one signal column.

    The CSV is looked up as
    ``<abspath(base_folders[group_code])>/<patient_id>/[trimmed/]<patient_id>_<group_code>_<day>_<block>_<trial>.csv``.

    Parameters
    ----------
    patient_id, group_code, day, block, trial : str
        Trial identifiers, e.g. ``"S019"``, ``"G03"``, ``"D02"``, ``"B01"``, ``"T01"``.
    signal_col : str
        Column to extract from the CSV.
    version : {'raw', 'downsampled', 'trimmed'}
        - ``'raw'``: the CSV in the patient folder, loaded as-is.
        - ``'downsampled'``: the raw CSV passed through ``downsample_df(df, rate=ds_rate)``.
        - ``'trimmed'``: the CSV in the patient's ``trimmed`` sub-folder, loaded as-is.
    ds_rate : int, default 4
        Downsampling factor; only used when ``version='downsampled'``.

    Returns
    -------
    tuple of (numpy.ndarray, pandas.DataFrame), or None
        ``(signal, df)`` where ``signal`` is ``df[signal_col].values``.
        ``None`` (after printing an ``[ERROR]`` message) when the group code, the
        version, the file or the column is unknown/missing.
    """
    if group_code not in base_folders:
        print(f"[ERROR] Unknown group code: {group_code}")
        return None
    patient_folder = os.path.join(os.path.abspath(base_folders[group_code]), patient_id)
    fname = f"{patient_id}_{group_code}_{day}_{block}_{trial}.csv"

    # Resolve the CSV path and the log messages for the requested version.
    if version == 'raw':
        path = os.path.join(patient_folder, fname)
        missing_msg = "File not found"
        loaded_msg = "Loaded raw signal from"
    elif version == 'downsampled':
        # Downsample the *raw* recording. Per this notebook's preprocessing notes the
        # trimmed CSVs are already downsampled, so downsampling them again would
        # compound the factor (and return trimmed data).
        path = os.path.join(patient_folder, fname)
        missing_msg = "Raw file for downsampling not found"
        loaded_msg = "Loaded and downsampled signal from"
    elif version == 'trimmed':
        path = os.path.join(patient_folder, 'trimmed', fname)
        missing_msg = "Trimmed file not found"
        loaded_msg = "Loaded trimmed signal from"
    else:
        print(f"[ERROR] Unknown version: {version}")
        return None

    if not os.path.exists(path):
        print(f"[ERROR] {missing_msg}: {path}")
        return None

    df = pd.read_csv(path)
    if version == 'downsampled':
        df = downsample_df(df, rate=ds_rate)

    # Report a missing column the same way as a missing file instead of raising KeyError.
    if signal_col not in df.columns:
        print(f"[ERROR] Column '{signal_col}' not found in: {path}")
        return None

    print(f"[INFO] {loaded_msg}: {path}")
    return df[signal_col].values, df

def _safe_filename_part(text):
    """Return ``text`` with characters that are invalid in file names replaced by ``'_'``.

    Signal column names may contain e.g. ``'/'`` (``"Velocity (m/s)"``), which would
    otherwise be treated as a directory separator when building the HTML output path.
    """
    return "".join("_" if ch in '\\/:*?"<>|' else ch for ch in str(text))


def _save_figure_html(fig, results_dir, stem):
    """Write ``fig`` to ``<cwd>/<results_dir>/<stem>.html`` (creating the folder) and return the path."""
    results_folder = os.path.join(os.getcwd(), results_dir)
    os.makedirs(results_folder, exist_ok=True)
    path = os.path.join(results_folder, f"{_safe_filename_part(stem)}.html")
    pio.write_html(fig, file=path, auto_open=False)
    return path


def plot_trial(
    patient_id,
    group_code,
    day,
    block,
    trial,
    signal_col="Ankle Dorsiflexion RT (deg)",
    version='raw',
    normalized=False,
    segmented=False,
    save_plots=False,
    results_dir="Results visualization"
):
    """Plot one trial's signal with plotly, either continuous or split into cycles.

    Parameters
    ----------
    patient_id, group_code, day, block, trial : str
        Trial identifiers forwarded to ``load_signal``.
    signal_col : str
        Column to plot.
    version : {'raw', 'downsampled', 'trimmed'}
        Preprocessing version forwarded to ``load_signal``.
    normalized : bool
        Continuous plot: Z-score the signal before plotting.
        Segmented plot: segment with ``segment_cycles_norm`` instead of ``segment_cycles_simple``.
    segmented : bool
        Overlay one trace per cycle instead of plotting the continuous signal.
    save_plots : bool
        Also write the figure as an HTML file under ``<cwd>/<results_dir>``.
    results_dir : str
        Output folder, relative to the current working directory, used when ``save_plots``.

    Returns
    -------
    None
        Returns early (after a printed message) if the signal cannot be loaded or
        no cycles are found.
    """
    loaded = load_signal(patient_id, group_code, day, block, trial, signal_col, version)
    if loaded is None:
        # load_signal already printed why; unpacking None would raise a TypeError.
        return
    signal, df = loaded
    trial_tag = f"{patient_id}_{group_code}_{day}_{block}_{trial}"

    if normalized:
        mean = np.mean(signal)
        std = np.std(signal)
        if std == 0:
            # Avoid division by zero: only centre the signal.
            print("[WARN] Standard deviation is zero; normalization skipped.")
            signal_to_plot = signal - mean
        else:
            signal_to_plot = (signal - mean) / std
        print(f"[INFO] Normalized signal with mean={mean:.2f}, std={std:.2f}")
    else:
        signal_to_plot = signal

    if segmented:
        # The segment_cycles_* helpers are given the loaded DataFrame, not signal_to_plot.
        cycles = segment_cycles_norm(df) if normalized else segment_cycles_simple(df)
        if not cycles:
            print("[WARN] No cycles found for segmentation.")
            return

        # Plot the segmented cycles, one semi-transparent trace per cycle.
        fig = go.Figure()
        for idx, cycle in enumerate(cycles, start=1):
            fig.add_trace(go.Scatter(
                y=cycle[signal_col].values,
                mode='lines',
                name=f'Cycle {idx}',
                opacity=0.5
            ))
        fig.update_layout(
            title=f"{trial_tag} | {signal_col} (Segmented)",
            xaxis_title='Sample (resampled)',
            yaxis_title='Angle (deg)',
            xaxis=dict(rangeslider=dict(visible=True), type='linear')
        )
        stem = f"{trial_tag}_{signal_col}_segmented"
    else:
        label = "Normalized" if normalized else version.capitalize()
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=signal_to_plot, mode='lines', name=label))
        fig.update_layout(
            title=f"{trial_tag} | {signal_col} ({label})",
            xaxis_title='Sample',
            yaxis_title='Z-score normalized Angle' if normalized else 'Angle (deg)',
            xaxis=dict(rangeslider=dict(visible=True), type='linear')
        )
        stem = f"{trial_tag}_{signal_col}_{label.lower()}"

    if save_plots:
        _save_figure_html(fig, results_dir, stem)
    fig.show()



### Usage

In [3]:
# Older-adult example (group G03): plot the segmented gait cycles first,
# then the continuous trimmed signal of the same trial.
signal_col = "Ankle Dorsiflexion RT (deg)"
group_code, patient_id = "G03", "S019"
day, block, trial = "D02", "B01", "T01"

plot_trial(
    patient_id=patient_id, group_code=group_code, day=day, block=block, trial=trial,
    signal_col=signal_col,
    version='trimmed', normalized=False, segmented=True,
    save_plots=False, results_dir='Results visualization',
)

plot_trial(
    patient_id=patient_id, group_code=group_code, day=day, block=block, trial=trial,
    signal_col=signal_col,
    version='trimmed', normalized=False, segmented=False,
    save_plots=False, results_dir='Results visualization',
)




[INFO] Loaded trimmed signal from: /mnt/storage/dmartinez/old adults (56+ years old)/S019/trimmed/S019_G03_D02_B01_T01.csv


[INFO] Loaded trimmed signal from: /mnt/storage/dmartinez/old adults (56+ years old)/S019/trimmed/S019_G03_D02_B01_T01.csv


In [4]:
# Young-adult example (group G01): continuous trimmed signal, no normalization or segmentation.
signal_col = "Ankle Dorsiflexion RT (deg)"
group_code, patient_id = "G01", "S003"
day, block, trial = "D01", "B01", "T01"

plot_trial(
    patient_id=patient_id, group_code=group_code, day=day, block=block, trial=trial,
    signal_col=signal_col,
    version='trimmed', normalized=False, segmented=False,
    save_plots=False, results_dir='Results visualization',
)


[INFO] Loaded trimmed signal from: /home/dmartinez/Gait-Stability/young adults (19–35 years old)/S003/trimmed/S003_G01_D01_B01_T01.csv


## From Tensors 

In [None]:
# Visualize the tensor organization: overlay every cycle of one variable from one trial.

# 1) SELECT THE TRIAL AND BUILD THE PATHS
# Paths are derived from the same `base_folders` mapping that `load_signal` uses
# instead of hard-coded absolute, machine-specific paths.
tensor_patient, tensor_group = "S039", "G01"
tensor_day, tensor_block, tensor_trial = "D01", "B03", "T02"
tensor_patient_folder = os.path.join(os.path.abspath(base_folders[tensor_group]), tensor_patient)
# Tensor file names omit the group code; trimmed CSV names include it.
tensor_path = os.path.join(
    tensor_patient_folder, "tensors",
    f"{tensor_patient}_{tensor_day}_{tensor_block}_{tensor_trial}_tensor.npy",
)
csv_path = os.path.join(
    tensor_patient_folder, "trimmed",
    f"{tensor_patient}_{tensor_group}_{tensor_day}_{tensor_block}_{tensor_trial}.csv",
)

# 2) LOAD THE TENSOR AND CSV
tensor_3d = np.load(tensor_path)  # expected shape: (n_cycles, target_length, n_variables)
df_trial = pd.read_csv(csv_path)

# The tensor's last axis is indexed with CSV column positions, so the two must line up;
# otherwise a different variable would be plotted silently.
if tensor_3d.ndim != 3 or tensor_3d.shape[2] != df_trial.shape[1]:
    raise ValueError(
        f"Tensor shape {tensor_3d.shape} does not match the {df_trial.shape[1]} CSV columns; "
        "cannot map the variable name to a tensor index."
    )

# 3) GET THE INDEX OF THE VARIABLE
variable_name = "Ankle Dorsiflexion RT (deg)"
if variable_name not in df_trial.columns:
    raise ValueError(f"The variable '{variable_name}' is not in the CSV columns.")
col_idx = df_trial.columns.get_loc(variable_name)

# 4) EXTRACT DATA FOR THAT VARIABLE -> shape (n_cycles, target_length)
data_variable = tensor_3d[:, :, col_idx]

# 5) CREATE A NORMALIZED TIME AXIS (0–100 % of the cycle)
n_cycles, target_length = data_variable.shape
time_axis = np.linspace(0, 100, target_length)

# 6) PLOT ALL CYCLES
fig, ax = plt.subplots(figsize=(10, 6))
# A 2-D y of shape (target_length, n_cycles) draws one line per column, i.e. per cycle.
ax.plot(time_axis, data_variable.T, alpha=0.25, linewidth=1)
ax.set_title(f"All Cycles – {variable_name}")
ax.set_xlabel("Normalized Cycle (%)")
ax.set_ylabel(variable_name)
ax.grid(True)
fig.tight_layout()
plt.show()