In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import decimate
from gait_events import gait_events_HC_JA
from dtaidistance import dtw
from tqdm import tqdm
from tslearn.metrics import dtw
import json

Install h5py to use hdf5 features: http://docs.h5py.org/
  warn(h5py_msg)


In [11]:
#2. Function to segment and downsample gait cycles
def segment_downsamp(df,
                   signal_col='Ankle Dorsiflexion RT (deg)',
                   min_length=20,# make sure this is long enough to avoid spurious segments
                   downsample_factor=4):
    """
    Split one recording into normalized, downsampled gait cycles.

    Steps:
    1. Detect right heel strikes with ``gait_events_HC_JA`` on the original data.
    2. Slice ``signal_col`` between consecutive strikes (one slice per cycle).
    3. Drop cycles shorter than ``min_length`` samples.
    4. Z-score each remaining cycle (zero mean, unit variance).
    5. Downsample each cycle by ``downsample_factor`` with a zero-phase
       ``decimate`` (200 Hz -> 50 Hz for the default factor of 4).

    Parameters
    ----------
    df : pandas.DataFrame
        Recording passed to ``gait_events_HC_JA``; must contain ``signal_col``.
    signal_col : str
        Column holding the signal to segment.
    min_length : int
        Cycles with fewer samples than this are dropped silently.
    downsample_factor : int
        Decimation factor handed to ``scipy.signal.decimate``.

    Returns
    -------
    list of numpy.ndarray
        One normalized, downsampled array per accepted cycle. Cycles that are
        too short for decimation, flat, or contain non-finite samples are
        reported with ``print`` and left out.
    """
    # Cycles must be longer than this for the zero-phase decimation
    # (padding requirement in decimate; threshold kept from the original code).
    min_decimate_len = 27

    # 1) Detect right heel strikes
    hs_R, _, _, _ = gait_events_HC_JA(df)
    series = df[signal_col].interpolate().values

    # Strike indices in ascending order so consecutive pairs delimit cycles
    hs = sorted(int(i) for i in hs_R)

    cycles = []
    # 2) & 3) Extract raw cycles
    for start, end in zip(hs, hs[1:]):
        cycle = series[start:end]
        if len(cycle) < min_length:
            continue
        if len(cycle) <= min_decimate_len:
            # Skip cycles too short for downsampling
            print(f"Skipping cycle of length {len(cycle)} (too short for downsampling).")
            continue

        spread = np.std(cycle)
        # A flat cycle (std == 0) or one with NaN/inf samples cannot be
        # z-scored: dividing would fill the cycle with NaN and poison any
        # distance computed from it later.
        if not np.isfinite(spread) or spread == 0:
            print(f"Skipping cycle of length {len(cycle)} (constant or non-finite signal).")
            continue

        # 4) Z-score normalization
        cycle = (cycle - np.mean(cycle)) / spread
        # 5) Downsample
        cycle = decimate(cycle, downsample_factor, zero_phase=True)
        cycles.append(cycle)

    return cycles

In [15]:
# === 1. Load patient data and define output folder ===
def load_patient_data(patient_folder, patient_id):
    """
    Read every trial CSV available for one patient.

    Files are looked up inside ``patient_folder`` as
    ``<patient_id>_G01_<D>_<B>_<T>.csv`` for D in D01-D02, B in B01-B03 and
    T in T01-T03. Missing files are reported and skipped. Files with no data
    rows are reported and skipped as well; this includes zero-byte files, for
    which ``pd.read_csv`` raises ``EmptyDataError`` instead of returning an
    empty frame.

    Parameters
    ----------
    patient_folder : str
        Directory containing the patient's CSV files.
    patient_id : str
        Identifier used as the file-name prefix.

    Returns
    -------
    (df_list, file_list) : tuple of lists
        The loaded DataFrames and, in the same order, their file paths.
    """
    df_list, file_list = [], []
    # Iterate over directory structure D01/D02 × B01–B03 × T01–T03
    for d in ['D01', 'D02']:
        for b in ['B01', 'B02', 'B03']:
            for t in ['T01', 'T02', 'T03']:
                fname = os.path.join(patient_folder, f"{patient_id}_G01_{d}_{b}_{t}.csv")
                if not os.path.exists(fname):
                    print(f"File not found: {fname}")
                    continue

                try:
                    df = pd.read_csv(fname)
                except pd.errors.EmptyDataError:
                    # Nothing to parse at all: treat like a file without rows
                    df = None

                if df is None or df.empty:
                    print(f"Empty file skipped: {fname}")
                    continue

                df_list.append(df)
                file_list.append(fname)
    return df_list, file_list

# === Main analysis parameters ===
# Input root (one sub-folder per patient is read from here) and results folder.
base_folder = './young adults (19–35 years old)'
output_base = './DTW_downsampled2'

# Make sure the results folder is there before anything gets written to it
os.makedirs(output_base, exist_ok=True)

# Patient IDs: the four-character entries of the base folder that start with 'S'
patient_ids = sorted(
    entry
    for entry in os.listdir(base_folder)
    if len(entry) == 4 and entry.startswith('S')
)
print('Detected patients:', patient_ids)

# Empty container; nothing is added to it in this cell
summary = []

Detected patients: ['S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S010', 'S012', 'S014', 'S015', 'S016', 'S017', 'S018', 'S020', 'S021', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S030', 'S031', 'S032', 'S033', 'S034', 'S035', 'S036', 'S037', 'S038', 'S039', 'S063']


In [12]:
# 3. Pairwise DTW between the gait cycles of each trial, patient by patient.
# Note: `dtw` is tslearn.metrics.dtw here, because that import comes after
# (and therefore shadows) `from dtaidistance import dtw` in the import cell.
#
# Result layout: dtw_all[pid]['i|j'] is the list of DTW scores between cycle i
# and cycle j, with one entry appended per trial that has both cycles.
dtw_all = {}
# Top-level bar over patients
for pid in tqdm(patient_ids, desc='Patients', position=0):
    tqdm.write(f"Processing patient {pid}...")
    folder = os.path.join(base_folder, pid)
    dfs, files = load_patient_data(folder, pid)
    if not dfs:
        tqdm.write(f"  No data for {pid}, skipping.")
        continue

    # Placeholders kept from the original cell; nothing below fills them.
    rep_cycles, trial_names = [], []

    # Bar over trials
    for df, fname in tqdm(zip(dfs, files),
                          total=len(dfs),
                          desc=f'Trials {pid}',
                          position=1,
                          leave=False):
        cycles = segment_downsamp(df)
        # Pairs are selected by index, not by comparing array contents: two
        # different cycles with equal samples are still a valid pair, and a
        # cycle is never compared with itself.
        # DTW distance is symmetric, so each unordered pair is scored once and
        # stored under both 'i|j' and 'j|i', the key layout readers of
        # dtw_all expect.
        for idx1, c1 in enumerate(cycles):
            for idx2 in range(idx1 + 1, len(cycles)):
                # float() keeps the stored value a plain JSON-friendly number
                dtw_score = float(dtw(c1, cycles[idx2]))
                for key in (f'{idx1}|{idx2}', f'{idx2}|{idx1}'):
                    # The patient entry is created lazily, so patients without
                    # any scored pair stay absent from dtw_all.
                    dtw_all.setdefault(pid, {}).setdefault(key, []).append(dtw_score)

Patients:   0%|          | 0/1 [00:00<?, ?it/s]

Processing patient S002...




Skipping cycle of length 21 (too short for downsampling).




Skipping cycle of length 24 (too short for downsampling).




Skipping cycle of length 26 (too short for downsampling).


Patients: 100%|██████████| 1/1 [01:42<00:00, 102.74s/it]


In [3]:
# Save the DTW scores as JSON. The file has to be opened for writing ('w');
# the default mode is read-only. The `with` block also closes the handle so
# the data is flushed to disk.
with open(os.path.join(output_base, 'dtw_all.json'), 'w') as f:
    json.dump(dtw_all, f)

NameError: name 'output_base' is not defined

In [4]:
# Read the stored DTW results back in, then write the same content to the
# same file again.
file_path = 'DTW_downsampled2/dtw_all.json'

with open(file_path, 'r') as f:
    dtw_all = json.loads(f.read())

with open(file_path, 'w') as f:
    f.write(json.dumps(dtw_all))

In [None]:


# Load the DTW results from the JSON file
with open("DTW_downsampled2/dtw_all.json", "r") as f:
    dtw_all = json.loads(f.read())

# Compute a patient's overall mean DTW distance from the per-key score lists
def calculate_representative_cycle_dtw(patient_dtw):
    """
    Return the mean of the per-key mean DTW distances for one patient.

    Each list in ``patient_dtw`` is averaged first, and those averages are
    then averaged, so every key carries the same weight regardless of how
    many scores it holds.

    Parameters
    ----------
    patient_dtw : dict
        Maps a key such as ``"0|1"`` to a list of DTW distances.

    Returns
    -------
    float
        The mean of the per-key means, or NaN when ``patient_dtw`` is empty.
    """
    if not patient_dtw:
        # Nothing to average; return NaN explicitly instead of letting
        # np.mean([]) emit a RuntimeWarning.
        return np.nan

    # One mean per key. Looping over every ordered pair of keys would only
    # repeat each of these values, which leaves their average unchanged while
    # costing a quadratic number of np.mean calls.
    per_key_means = [np.mean(scores) for scores in patient_dtw.values()]
    return np.mean(per_key_means)

# Per-patient analysis
def analyze_patient_dtw(patient_id, dtw_all):
    """
    Build a table of mean DTW distances for one patient.

    Parameters
    ----------
    patient_id : str
        Key of the patient inside ``dtw_all``.
    dtw_all : dict
        Maps patient id -> {key such as ``"0|1"`` -> list of DTW distances}.

    Returns
    -------
    pandas.DataFrame or None
        A frame with a single ``'DTW'`` column holding one row per key (the
        mean of that key's distances, in the dictionary's key order), or
        ``None`` when the patient is missing or has no scores.
    """
    patient_dtw = dtw_all.get(patient_id, {})

    if not patient_dtw:
        print(f"No data for patient {patient_id}")
        return None

    print(f"Analyzing patient {patient_id}...")

    # Exactly one value per key. Pairing every key with every other key would
    # only duplicate each mean, inflating the histogram counts drawn from
    # this table and making the computation quadratic.
    dtw_values = [np.mean(scores) for scores in patient_dtw.values()]

    df_dtw = pd.DataFrame(dtw_values, columns=['DTW'])
    return df_dtw

# Visualization: plot a histogram of the DTW values for each patient
def plot_dtw_analysis(patient_dtw_df, patient_id):
    """
    Show a 20-bin histogram of the ``'DTW'`` column for one patient.

    Parameters
    ----------
    patient_dtw_df : pandas.DataFrame or None
        Table with a ``'DTW'`` column; when it is ``None`` nothing is drawn.
    patient_id : str
        Identifier shown in the figure title.
    """
    # Nothing to plot for patients without a result table
    if patient_dtw_df is None:
        return

    plt.figure(figsize=(10, 6))
    plt.hist(
        patient_dtw_df['DTW'],
        bins=20,
        color='skyblue',
        edgecolor='black',
    )
    plt.title(f"DTW Distribution for Patient {patient_id}")
    plt.xlabel('DTW Distance')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

# Main function: iterate over all patients
def analyze_all_patients(dtw_all):
    """
    Analyze and plot the DTW results of every patient in ``dtw_all``.

    For each patient id, the table returned by ``analyze_patient_dtw`` is
    handed to ``plot_dtw_analysis``.
    """
    for patient_id in dtw_all:
        # Per-patient table first, then its plot
        plot_dtw_analysis(analyze_patient_dtw(patient_id, dtw_all), patient_id)

# Run the analysis (and plotting) for every patient found in dtw_all
analyze_all_patients(dtw_all)


Analyzing patient S002...
