In [None]:
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal, stats
from tqdm import tqdm

# Movement labels, in the order used to assign plot colours.
MOVEMENT_LIST = [
    "Rest",
    "MassFlexion",
    "HookGrasp",
    "ThumbAdduction",
    "PinchGrasp",
    "PinchGraspMiddle",
    "PinchGraspRing",
    "PinchGraspPinkie",
    "DiameterGrasp",
    "SphereGrasp",
    "MassAdduction",
    "MassExtension",
    "WristVolarFlexion",
    "WristDorsiFlexion",
    "ForearmPronation",
    "ForearmSupination",
]

# Map every movement to one colour sampled at evenly spaced points of the rainbow colormap.
MOVEMENT_COLORS = {
    movement: color
    for movement, color in zip(
        MOVEMENT_LIST,
        plt.cm.rainbow(np.linspace(0, 1, len(MOVEMENT_LIST))),
    )
}
# Root folder of the recordings; override it with the PHYSIOMIO_DATA_DIR environment variable.
DATA_DIR = Path(os.getenv("PHYSIOMIO_DATA_DIR", "data"))

# Recursively collect every parquet file below DATA_DIR. glob returns matches in
# arbitrary (filesystem-dependent) order, so the list is sorted to make slices
# such as all_parquet_files[:3] reproducible across runs and machines.
all_parquet_files = sorted(glob.glob(str(DATA_DIR / "**/*.parquet"), recursive=True))
assert len(all_parquet_files) > 0, f"No parquet files found under {DATA_DIR}"
print(f"Found {len(all_parquet_files)} parquet files under {DATA_DIR}")

In [None]:
def calculate_channel_mean_amplitude(
        df: pd.DataFrame,
        channel: str,
        movement_type_col: str = 'movement_type') -> tuple[dict[str, float], float, float]:
    """
    Calculate SNR statistics for a single channel using mean amplitude.

    For every movement other than 'Rest' the SNR is
    ``20 * log10(mean(|x_movement|) / mean(|x_rest|))``, i.e. the mean absolute
    amplitude during the movement relative to the mean absolute amplitude
    during 'Rest', expressed in dB.

    Rows whose movement label is missing (NaN/None) are ignored: they can never
    match an equality mask, so they would otherwise contribute a NaN SNR entry
    and turn the average into NaN.

    Args:
        df: DataFrame containing the EMG data and movement labels
        channel: Name of the channel column to analyze
        movement_type_col: Name of the column containing movement labels

    Returns:
        tuple containing:
            - Dictionary mapping movement names to SNR values
            - Mean SNR across all movements (excluding Rest)
            - Standard deviation of SNR across movements
        The mean and standard deviation are NaN when 'Rest' is the only movement.

    Raises:
        AssertionError: If no row is labelled 'Rest'.
    """
    labels = df[movement_type_col]
    unique_movements = labels.dropna().unique()

    # Validate before doing any work: 'Rest' is the reference for every ratio.
    assert 'Rest' in unique_movements, "Rest movement must be present in the data"

    rest_ma = df.loc[labels == 'Rest', channel].abs().mean()

    snr_values = {}
    for movement in unique_movements:
        if movement == 'Rest':
            continue
        ma = df.loc[labels == movement, channel].abs().mean()
        # Using 20*log10 for amplitude ratio
        snr_values[movement] = 20 * np.log10(ma / rest_ma)

    if not snr_values:
        # Nothing to average; return NaN explicitly instead of letting numpy
        # warn about the mean of an empty list.
        return snr_values, float('nan'), float('nan')

    snr_values_list = list(snr_values.values())
    avg_snr = float(np.mean(snr_values_list))
    std_snr = float(np.std(snr_values_list))

    return snr_values, avg_snr, std_snr

def calculate_overall_snr(df: pd.DataFrame, snr_calculator,
                          movement_type_col: str = 'movement_type') -> tuple[float, float, list[float], float]:
    """
    Pool per-movement SNR values over every EMG channel of a recording.

    Only columns named ``channel_01`` ... ``channel_64`` that exist in ``df``
    are considered.

    Args:
        df: DataFrame containing the EMG data and movement labels
        snr_calculator: Callable ``(df, channel_name, movement_type_col)`` whose
            first return value is a dict of per-movement SNR values
        movement_type_col: Name of the column containing movement labels

    Returns:
        tuple containing:
            - Mean of the pooled SNR values (all channels, all movements)
            - Standard deviation of the pooled SNR values
            - The pooled list of individual SNR values
            - Mean over channels of each channel's mean absolute amplitude
              (0.0 when no channel column is present)
    """
    # Candidate names channel_01 ... channel_64, restricted to those in the frame.
    present_channels = [
        name
        for name in (f'channel_{number:02d}' for number in range(1, 65))
        if name in df.columns
    ]

    pooled_snr = []
    per_channel_amplitude = []
    for name in present_channels:
        per_movement_snr, _, _ = snr_calculator(df, name, movement_type_col)
        pooled_snr.extend(list(per_movement_snr.values()))
        # Mean absolute amplitude of the channel over the whole recording
        per_channel_amplitude.append(df[name].abs().mean())

    mean_snr = float(np.mean(pooled_snr))
    std_snr = float(np.std(pooled_snr))
    if per_channel_amplitude:
        mean_amplitude = float(np.mean(per_channel_amplitude))
    else:
        mean_amplitude = 0.0

    return mean_snr, std_snr, pooled_snr, mean_amplitude

In [None]:
def get_detailed_snrs_df(parquet_files: list[str]):
    """
    Calculate detailed SNR statistics with one row per recording.

    The patient number is read from the first path component of the form
    ``patient<digits>`` (independent of the path separator), the recording name
    is the file stem, and the arm type is inferred from 'healthy_arm' /
    'impaired_arm' appearing anywhere in the path.

    Args:
        parquet_files: List of paths to parquet files containing EMG data

    Returns:
        A pandas Styler wrapping the per-recording table (columns Patient,
        Recording, Arm Type, SNR Mean (dB), SNR Std Dev (dB), Avg Mean
        Amplitude), sorted by Patient, Arm Type and Recording. The plain
        DataFrame is available through the Styler's ``.data`` attribute.

    Raises:
        ValueError: If a path contains no ``patient<digits>`` component.
    """
    table_columns = [
        'Patient', 'Recording', 'Arm Type',
        'SNR Mean (dB)', 'SNR Std Dev (dB)', 'Avg Mean Amplitude',
    ]
    rows = []

    for file in tqdm(parquet_files, desc="Processing recordings"):
        path = Path(file)

        # Extract patient number from the path components rather than by
        # splitting on '/', so Windows-style paths and parent folders such as
        # 'patients/' do not break the lookup.
        patient_ids = [
            part[len('patient'):] for part in path.parts
            if part.startswith('patient') and part[len('patient'):].isdecimal()
        ]
        if not patient_ids:
            raise ValueError(f"Cannot find a 'patient<N>' folder in path: {file}")
        patient_num = int(patient_ids[0])

        # Recording name is the filename without extension
        recording_name = path.stem

        # Determine arm type
        if 'healthy_arm' in file:
            arm_type = 'Healthy'
        elif 'impaired_arm' in file:
            arm_type = 'Impaired'
        else:
            arm_type = 'Unknown'

        # Calculate SNR for this recording
        df = pd.read_parquet(file)
        snr_mean, snr_std, _, avg_mean_amplitude = calculate_overall_snr(
            df, snr_calculator=calculate_channel_mean_amplitude)

        rows.append({
            'Patient': patient_num,
            'Recording': recording_name,
            'Arm Type': arm_type,
            'SNR Mean (dB)': round(snr_mean, 2),
            'SNR Std Dev (dB)': round(snr_std, 2),
            'Avg Mean Amplitude': round(avg_mean_amplitude, 6)
        })

    # Passing the column list keeps the sort below valid even when no file was given.
    df_detailed = pd.DataFrame(rows, columns=table_columns)
    df_detailed = df_detailed.sort_values(['Patient', 'Arm Type', 'Recording'])

    return df_detailed.style.set_properties(**{'text-align': 'center'}).set_table_styles([
        {'selector': 'th', 'props': [('text-align', 'center')]},
        {'selector': '', 'props': [('border', '1px solid black')]},
        {'selector': 'th,td', 'props': [('padding', '8px')]}
    ])


# Build and display the detailed table with one row per recording
print("\n" + "="*80)
print("DETAILED SNR TABLE - ONE ROW PER RECORDING")
print("="*80)
detailed_snrs_df = get_detailed_snrs_df(all_parquet_files)
display(detailed_snrs_df)

# Save the SNR table to CSV file
snr_csv_filename = 'detailed_snrs_table.csv'
# get_detailed_snrs_df returns a Styler; .data is the underlying plain DataFrame
detailed_snrs_df_plain = detailed_snrs_df.data
detailed_snrs_df_plain.to_csv(snr_csv_filename, index=False)
print(f"\nSNR table saved to: {snr_csv_filename}")

print("\nSNR analysis complete!")
print(f"- Processed {len(all_parquet_files)} files")
print(f"- Generated table with {len(detailed_snrs_df_plain)} recordings")
print(f"- Saved detailed results to {snr_csv_filename}")

In [None]:
def preprocess_emg(data: np.ndarray, fs: float,
                   notch_freq: float = 50.0,
                   quality_factor: float = 10.0,
                   high_pass_freq: float = 20.0,
                   high_pass_order: int = 4) -> np.ndarray:
    """
    Filter an EMG signal: power-line notch followed by a Butterworth high-pass.

    Both filters are applied with ``scipy.signal.filtfilt`` (forward-backward
    filtering along the last axis).

    Args:
        data: Signal samples
        fs: Sampling frequency in Hz
        notch_freq: Centre frequency of the notch filter in Hz
            (default 50.0; use 60.0 for 60 Hz mains)
        quality_factor: Quality factor of the notch filter
        high_pass_freq: Cut-off frequency of the high-pass filter in Hz
        high_pass_order: Order of the Butterworth high-pass design

    Returns:
        The filtered signal, same shape as ``data``.
    """
    # Notch out mains interference
    b_notch, a_notch = signal.iirnotch(notch_freq, quality_factor, fs)
    filtered_data = signal.filtfilt(b_notch, a_notch, data)

    # High-pass; butter expects the cut-off normalised by the Nyquist frequency (fs / 2)
    b_high, a_high = signal.butter(high_pass_order, high_pass_freq/(fs/2), btype='high')
    filtered_data = signal.filtfilt(b_high, a_high, filtered_data)

    return filtered_data


def compute_psd(data: np.ndarray, fs: float, nperseg: int = 1024) -> tuple[np.ndarray, np.ndarray]:
    """
    Estimate the power spectral density of a signal with ``scipy.signal.welch``.

    The segment length is limited to half the signal length, but never drops
    below 256 samples (so a requested ``nperseg`` under 256 is raised to 256).

    Args:
        data: Signal samples
        fs: Sampling frequency in Hz
        nperseg: Requested Welch segment length in samples

    Returns:
        tuple ``(f, Pxx)``: frequency axis and PSD with ``scaling="density"``.
    """
    segment_length = max(256, min(nperseg, len(data) // 2))
    return signal.welch(data, fs, nperseg=segment_length, scaling="density")

# Fixed version of the PSD function
def get_detailed_psd_df_fixed(parquet_files: list[str], num_freq_bins: int = 40):
    """
    Calculate detailed PSD statistics with one row per recording (FIXED VERSION).

    For every recording each ``channel_*`` column except ``channel_49`` is
    scaled by 0.001, filtered with ``preprocess_emg`` and passed to
    ``compute_psd``. The channel PSDs are averaged, and the average PSD is then
    reduced to the mean value inside each of ``num_freq_bins`` equal-width bins
    spanning 20-500 Hz. Bins are half-open ``[start, end)``, so a frequency
    sample at exactly 500 Hz is not counted. A bin containing no frequency
    sample gets 0.0. Recordings without any usable channel are skipped.

    Args:
        parquet_files: List of paths to parquet files containing EMG data
        num_freq_bins: Number of frequency bins to create for PSD features

    Returns:
        pd.DataFrame: Detailed PSD statistics with one row per recording showing
                     Patient, Recording, Arm Type, and PSD values in frequency bins,
                     sorted by Patient, Arm Type and Recording

    Raises:
        ValueError: If a path contains no ``patient<digits>`` component.
    """
    # NOTE(review): the sampling rate is hard-coded; confirm it matches every recording.
    fs = 2048

    # Define frequency bins (20-500 Hz) and the matching column names
    freq_min, freq_max = 20, 500
    freq_bins = np.linspace(freq_min, freq_max, num_freq_bins + 1)
    freq_columns = [
        f'PSD_{freq_bins[i]:.1f}-{freq_bins[i + 1]:.1f}Hz'
        for i in range(num_freq_bins)
    ]
    table_columns = ['Patient', 'Recording', 'Arm Type', *freq_columns]

    rows = []
    for file in tqdm(parquet_files, desc="Processing PSD for recordings"):
        path = Path(file)

        # Extract patient number from the path components rather than by
        # splitting on '/', so Windows-style paths and parent folders such as
        # 'patients/' do not break the lookup.
        patient_ids = [
            part[len('patient'):] for part in path.parts
            if part.startswith('patient') and part[len('patient'):].isdecimal()
        ]
        if not patient_ids:
            raise ValueError(f"Cannot find a 'patient<N>' folder in path: {file}")
        patient_num = int(patient_ids[0])

        # Recording name is the filename without extension
        recording_name = path.stem

        # Determine arm type
        if 'healthy_arm' in file:
            arm_type = 'Healthy'
        elif 'impaired_arm' in file:
            arm_type = 'Impaired'
        else:
            arm_type = 'Unknown'

        # Read data and calculate PSD for all channels
        df = pd.read_parquet(file)
        channel_psds = []
        frequencies = None

        for channel in df.columns:
            # NOTE(review): channel_49 is excluded here; the reason is not visible in this notebook.
            if channel.startswith('channel_') and channel != 'channel_49':
                # NOTE(review): 0.001 is presumably a unit conversion; confirm.
                channel_data = df[channel].values * 0.001
                processed_emg = preprocess_emg(channel_data, fs)
                # All channels of a recording have the same length and fs, so
                # the frequency axis is identical for each of them; keep it
                # from this call instead of recomputing a PSD just to obtain it.
                frequencies, Pxx = compute_psd(processed_emg, fs, nperseg=1024)
                channel_psds.append(Pxx)

        if not channel_psds:
            continue

        # Average PSD across all channels for this recording
        avg_psd = np.mean(channel_psds, axis=0)

        # Calculate mean PSD in each frequency bin
        psd_features = {}
        for i, col_name in enumerate(freq_columns):
            freq_mask = (frequencies >= freq_bins[i]) & (frequencies < freq_bins[i + 1])
            if np.any(freq_mask):
                psd_features[col_name] = np.mean(avg_psd[freq_mask])  # Keep full precision
            else:
                psd_features[col_name] = 0.0

        # Create row with basic info and PSD features
        row = {
            'Patient': patient_num,
            'Recording': recording_name,
            'Arm Type': arm_type,
        }
        row.update(psd_features)
        rows.append(row)

    # Passing the column list keeps the sort below valid even when every
    # recording was skipped or no file was given.
    df_psd = pd.DataFrame(rows, columns=table_columns)
    df_psd = df_psd.sort_values(['Patient', 'Arm Type', 'Recording'])

    return df_psd


# Smoke-test the PSD function on a few files before running it on the whole dataset
print("\n" + "="*80)
print("TESTING FIXED PSD FUNCTION")
print("="*80)

test_files = all_parquet_files[:3]  # Just test with 3 files
psd_df_fixed = get_detailed_psd_df_fixed(test_files, num_freq_bins=10)  # Fewer bins for testing

print(f"Fixed PSD table shape: {psd_df_fixed.shape}")
print(f"\nPSD columns: {[col for col in psd_df_fixed.columns if col.startswith('PSD_')]}")

# Show the actual values with full precision
print("\nFirst row PSD values (showing actual numbers):")
first_row = psd_df_fixed.iloc[0]
for col in psd_df_fixed.columns:
    if col.startswith('PSD_'):
        print(f"  {col}: {first_row[col]:.8e}")

# Display the table
display(psd_df_fixed)

In [None]:
# Build the PSD table for every recording in the dataset
print("\n" + "="*80)
print("CALCULATING PSD TABLE FOR WHOLE DATASET")
print("="*80)

# 20 bins: a compromise between spectral detail and computational cost
psd_df_complete = get_detailed_psd_df_fixed(all_parquet_files, num_freq_bins=20)

print(f"Complete PSD table shape: {psd_df_complete.shape}")
print(f"Number of recordings processed: {len(psd_df_complete)}")

# Recording counts per arm type and number of distinct patients
print("\nDataset breakdown:")
print(f"- Total recordings: {len(psd_df_complete)}")
print(f"- Healthy arm recordings: {(psd_df_complete['Arm Type'] == 'Healthy').sum()}")
print(f"- Impaired arm recordings: {(psd_df_complete['Arm Type'] == 'Impaired').sum()}")
print(f"- Number of patients: {psd_df_complete['Patient'].nunique()}")

# List the frequency-bin columns (those prefixed with 'PSD_')
psd_columns = [name for name in psd_df_complete.columns if name.startswith('PSD_')]
print(f"\nFrequency bins created: {len(psd_columns)}")
print("Frequency ranges:")
for position, column_name in enumerate(psd_columns, start=1):
    print(f"  {position:2d}. {column_name}")

# Show the full table
print("\n" + "="*80)
print("DETAILED PSD TABLE - ALL RECORDINGS")
print("="*80)
display(psd_df_complete)

In [None]:
# Persist the PSD table as CSV
psd_csv_filename = 'detailed_psd_table.csv'
psd_df_complete.to_csv(psd_csv_filename, index=False)
print(f"\nPSD table saved to: {psd_csv_filename}")

# Per-bin summary statistics, split by arm type
print("\n" + "="*40)
print("PSD DATA SUMMARY STATISTICS")
print("="*40)

summary_stats = []
for col in psd_columns:
    # Values of this frequency bin for each arm type
    healthy_values = psd_df_complete.loc[psd_df_complete['Arm Type'] == 'Healthy', col]
    impaired_values = psd_df_complete.loc[psd_df_complete['Arm Type'] == 'Impaired', col]
    healthy_mean = healthy_values.mean()
    impaired_mean = impaired_values.mean()

    # Numbers are stored as pre-formatted strings for display
    summary_stats.append({
        'Frequency_Range': col.replace('PSD_', ''),
        'Healthy_Mean': f"{healthy_mean:.2e}",
        'Healthy_Std': f"{healthy_values.std():.2e}",
        'Impaired_Mean': f"{impaired_mean:.2e}",
        'Impaired_Std': f"{impaired_values.std():.2e}",
        'Ratio_Impaired/Healthy': f"{impaired_mean/healthy_mean:.3f}"
    })

psd_summary_df = pd.DataFrame(summary_stats)
display(psd_summary_df)

print("\nPSD analysis complete!")
print(f"- Processed {len(all_parquet_files)} files")
print(f"- Generated table with {len(psd_df_complete)} recordings")
print(f"- Created {len(psd_columns)} frequency bins from 20-500 Hz")
print(f"- Saved detailed results to {psd_csv_filename}")


In [None]:
# Create FMA scores table
def get_detailed_fma_df(parquet_files: list[str]):
    """
    Extract FMA scores with one row per recording.

    The score of a recording is the mean of its non-missing ``fma`` values,
    rounded to two decimals, or None when every value is missing. Only the
    ``fma`` column is loaded from each parquet file.

    Args:
        parquet_files: List of paths to parquet files containing EMG data

    Returns:
        pd.DataFrame: Detailed FMA scores with one row per recording showing
                     Patient, Recording, Arm Type, and Average FMA Score,
                     sorted by Patient, Arm Type and Recording

    Raises:
        ValueError: If a path contains no ``patient<digits>`` component.
    """
    table_columns = ['Patient', 'Recording', 'Arm Type', 'Average FMA Score']
    rows = []

    for file in tqdm(parquet_files, desc="Processing FMA scores"):
        path = Path(file)

        # Extract patient number from the path components rather than by
        # splitting on '/', so Windows-style paths and parent folders such as
        # 'patients/' do not break the lookup.
        patient_ids = [
            part[len('patient'):] for part in path.parts
            if part.startswith('patient') and part[len('patient'):].isdecimal()
        ]
        if not patient_ids:
            raise ValueError(f"Cannot find a 'patient<N>' folder in path: {file}")
        patient_num = int(patient_ids[0])

        # Recording name is the filename without extension
        recording_name = path.stem

        # Determine arm type
        if 'healthy_arm' in file:
            arm_type = 'Healthy'
        elif 'impaired_arm' in file:
            arm_type = 'Impaired'
        else:
            arm_type = 'Unknown'

        # Only the 'fma' column is needed, so avoid loading the EMG channels.
        df = pd.read_parquet(file, columns=['fma'])

        # Calculate average FMA score for this recording (excluding NaN values)
        fma_values = df['fma'].dropna()
        if len(fma_values) > 0:
            avg_fma = fma_values.mean()
        else:
            avg_fma = None  # No FMA data available

        rows.append({
            'Patient': patient_num,
            'Recording': recording_name,
            'Arm Type': arm_type,
            'Average FMA Score': round(avg_fma, 2) if avg_fma is not None else None
        })

    # Passing the column list keeps the sort below valid even when no file was given.
    df_fma = pd.DataFrame(rows, columns=table_columns)
    df_fma = df_fma.sort_values(['Patient', 'Arm Type', 'Recording'])

    return df_fma

# Build the FMA scores table for every recording
print("\n" + "="*80)
print("CALCULATING FMA SCORES TABLE")
print("="*80)

fma_df = get_detailed_fma_df(all_parquet_files)

print(f"FMA scores table shape: {fma_df.shape}")
print(f"Number of recordings processed: {len(fma_df)}")

# Recording counts per arm type and number of distinct patients
print("\nDataset breakdown:")
print(f"- Total recordings: {len(fma_df)}")
print(f"- Healthy arm recordings: {(fma_df['Arm Type'] == 'Healthy').sum()}")
print(f"- Impaired arm recordings: {(fma_df['Arm Type'] == 'Impaired').sum()}")
print(f"- Number of patients: {fma_df['Patient'].nunique()}")

# Descriptive statistics over the recordings that have a score
fma_scores_available = fma_df['Average FMA Score'].dropna()
print("\nFMA Score Statistics:")
print(f"- Recordings with FMA scores: {len(fma_scores_available)}")
print(f"- Recordings without FMA scores: {len(fma_df) - len(fma_scores_available)}")

if len(fma_scores_available) > 0:
    for label, value in (
        ("Mean", fma_scores_available.mean()),
        ("Min", fma_scores_available.min()),
        ("Max", fma_scores_available.max()),
        ("Std", fma_scores_available.std()),
    ):
        print(f"- {label} FMA score: {value:.2f}")

# Show the full table
print("\n" + "="*80)
print("DETAILED FMA SCORES TABLE - ALL RECORDINGS")
print("="*80)
display(fma_df)

# Persist the FMA table as CSV
fma_csv_filename = 'detailed_fma_table.csv'
fma_df.to_csv(fma_csv_filename, index=False)
print(f"\nFMA scores table saved to: {fma_csv_filename}")

print("\nFMA analysis complete!")
print(f"- Processed {len(all_parquet_files)} files")
print(f"- Generated table with {len(fma_df)} recordings")
print(f"- Saved detailed results to {fma_csv_filename}")


In [None]:
# Create CCN (Correlation Coefficient of Normality) table
def calculate_ccn(data: np.ndarray) -> float:
    """
    Calculate the Correlation Coefficient of Normality (CCN).

    CCN is the Pearson correlation between the histogram of EMG signal amplitudes
    and a normal distribution with the same mean and variance.
    A CCN close to 1 indicates the signal's amplitude distribution is near-Gaussian.

    Non-finite samples (NaN, +/-inf) are ignored, because ``np.histogram`` with
    ``bins='auto'`` raises on them. When the CCN is undefined (fewer than two
    finite samples, a constant signal, or a histogram with a single bin) NaN
    is returned.

    Args:
        data: EMG signal data

    Returns:
        CCN value (correlation coefficient) as a Python float, or NaN when undefined
    """
    samples = np.asarray(data).ravel()
    samples = samples[np.isfinite(samples)]
    if samples.size < 2:
        return float('nan')

    mean = np.mean(samples)
    std = np.std(samples)
    if std == 0:
        # A constant signal has no spread: the reference normal pdf is undefined.
        return float('nan')

    # Empirical density evaluated at the bin centres
    hist, bin_edges = np.histogram(samples, bins='auto', density=True)
    if hist.size < 2:
        # A correlation needs at least two points.
        return float('nan')
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Normal density with the sample mean and standard deviation at the same points
    normal_pdf = stats.norm.pdf(bin_centers, mean, std)

    return float(np.corrcoef(hist, normal_pdf)[0, 1])


def get_detailed_ccn_df(parquet_files: list[str]):
    """
    Calculate detailed CCN statistics with one row per recording.

    ``calculate_ccn`` is applied to the raw values of every ``channel_*``
    column; the per-recording result is the mean and standard deviation of
    those channel values, rounded to four decimals (None when the recording
    has no channel column).

    Args:
        parquet_files: List of paths to parquet files containing EMG data

    Returns:
        pd.DataFrame: Detailed CCN statistics with one row per recording showing
                     Patient, Recording, Arm Type, Average CCN and CCN Std Dev
                     across all channels, sorted by Patient, Arm Type and Recording

    Raises:
        ValueError: If a path contains no ``patient<digits>`` component.
    """
    table_columns = ['Patient', 'Recording', 'Arm Type', 'Average CCN', 'CCN Std Dev']
    rows = []

    for file in tqdm(parquet_files, desc="Processing CCN for recordings"):
        path = Path(file)

        # Extract patient number from the path components rather than by
        # splitting on '/', so Windows-style paths and parent folders such as
        # 'patients/' do not break the lookup.
        patient_ids = [
            part[len('patient'):] for part in path.parts
            if part.startswith('patient') and part[len('patient'):].isdecimal()
        ]
        if not patient_ids:
            raise ValueError(f"Cannot find a 'patient<N>' folder in path: {file}")
        patient_num = int(patient_ids[0])

        # Recording name is the filename without extension
        recording_name = path.stem

        # Determine arm type
        if 'healthy_arm' in file:
            arm_type = 'Healthy'
        elif 'impaired_arm' in file:
            arm_type = 'Impaired'
        else:
            arm_type = 'Unknown'

        # Read data and calculate CCN for all channels
        df = pd.read_parquet(file)
        channel_ccns = [
            calculate_ccn(df[channel].values)
            for channel in df.columns
            if channel.startswith('channel_')
        ]

        # Calculate average CCN across all channels for this recording
        if channel_ccns:
            avg_ccn = np.mean(channel_ccns)
            std_ccn = np.std(channel_ccns)
        else:
            avg_ccn = None
            std_ccn = None

        rows.append({
            'Patient': patient_num,
            'Recording': recording_name,
            'Arm Type': arm_type,
            'Average CCN': round(avg_ccn, 4) if avg_ccn is not None else None,
            'CCN Std Dev': round(std_ccn, 4) if std_ccn is not None else None
        })

    # Passing the column list keeps the sort below valid even when no file was given.
    df_ccn = pd.DataFrame(rows, columns=table_columns)
    df_ccn = df_ccn.sort_values(['Patient', 'Arm Type', 'Recording'])

    return df_ccn


# Calculate CCN table for the whole dataset
print("\n" + "="*80)
print("CALCULATING CCN TABLE FOR WHOLE DATASET")
print("="*80)

ccn_df = get_detailed_ccn_df(all_parquet_files)

print(f"CCN table shape: {ccn_df.shape}")
print(f"Number of recordings processed: {len(ccn_df)}")

# Display basic statistics about the CCN values
print("\nDataset breakdown:")
print(f"- Total recordings: {len(ccn_df)}")
print(f"- Healthy arm recordings: {len(ccn_df[ccn_df['Arm Type'] == 'Healthy'])}")
print(f"- Impaired arm recordings: {len(ccn_df[ccn_df['Arm Type'] == 'Impaired'])}")
print(f"- Number of patients: {ccn_df['Patient'].nunique()}")

# Show CCN statistics
ccn_values_available = ccn_df['Average CCN'].dropna()
print("\nCCN Statistics:")
print(f"- Recordings with CCN values: {len(ccn_values_available)}")

if len(ccn_values_available) > 0:
    print(f"- Mean CCN: {ccn_values_available.mean():.4f}")
    print(f"- Min CCN: {ccn_values_available.min():.4f}")
    print(f"- Max CCN: {ccn_values_available.max():.4f}")
    print(f"- Std CCN: {ccn_values_available.std():.4f}")

    # Compare healthy vs impaired
    healthy_ccn = ccn_df[ccn_df['Arm Type'] == 'Healthy']['Average CCN'].dropna()
    impaired_ccn = ccn_df[ccn_df['Arm Type'] == 'Impaired']['Average CCN'].dropna()

    if len(healthy_ccn) > 0 and len(impaired_ccn) > 0:
        print("\nComparison by Arm Type:")
        print(f"- Healthy arm mean CCN: {healthy_ccn.mean():.4f} ± {healthy_ccn.std():.4f}")
        print(f"- Impaired arm mean CCN: {impaired_ccn.mean():.4f} ± {impaired_ccn.std():.4f}")

# Display the table
print("\n" + "="*80)
print("DETAILED CCN TABLE - ALL RECORDINGS")
print("="*80)
display(ccn_df)

# Save the CCN table to CSV file
ccn_csv_filename = 'detailed_ccn_table.csv'
ccn_df.to_csv(ccn_csv_filename, index=False)
print(f"\nCCN table saved to: {ccn_csv_filename}")

print("\nCCN analysis complete!")
print(f"- Processed {len(all_parquet_files)} files")
print(f"- Generated table with {len(ccn_df)} recordings")
print(f"- Saved detailed results to {ccn_csv_filename}")
