# ***** Adjust the input and output folders at the end of each module, inside the `if __name__ == "__main__":` block *****
Tip: keep the path names identical across modules and use Ctrl+H (find and replace) whenever you change a path, so that every other occurrence of the same path is updated as well.

# EMG Load and Parse

In [None]:
import os
import glob
import pandas as pd
import numpy as np
from io import StringIO
import matplotlib.pyplot as plt


########################################
# Function: Read sensor data with metadata and update column names with sensor group and mode identifiers
########################################
def read_sensor_data_with_metadata(file_path, debug=False):
    """
    Load a sensor CSV export whose leading rows hold metadata instead of samples.

    Expected layout (0-based line numbers):
      - Lines 0-2: "key, value" metadata rows (application, date/time, collection length in seconds).
        A row without a comma is stored whole under a default key.
      - Line 3: sensor group identifiers (e.g., "FDS (81770), , ... , FCU (81728), ..., FCR (81745)").
      - Line 4: sensor mode information for each group.
      - Line 5: column header.
      - Line 6: sample rate row (skipped).
      - Line 7 onward: sensor samples.

    When line 3 contains a comma the file is treated as a multi-sensor ("updated") export:
    blank cells on lines 3 and 4 inherit the closest non-blank cell to their left, and each
    header name gets " - <sensor group>" appended. The sensor modes are kept in the metadata
    only. Otherwise (legacy export) lines 3 and 4 are stored as single metadata values.

    Parameters:
      file_path (str): Path to the CSV file.
      debug (bool): If True, prints detailed debug output; otherwise a one-line completion notice.

    Returns:
      df (pd.DataFrame): Sensor samples with (possibly renamed) columns, one constant column per
                         scalar metadata entry, and a "Timestamp" column spread evenly from the
                         start time over the collection length.
      metadata (dict): Dictionary of parsed metadata.
    """
    def _forward_fill(raw_line):
        # Split on commas and let blank cells take the last non-blank value seen so far.
        filled = []
        current = None
        for piece in raw_line.split(','):
            piece = piece.strip()
            if piece:
                current = piece
            filled.append("" if current is None else current)
        return filled

    with open(file_path, 'r') as handle:
        lines = handle.readlines()

    # --- Lines 0-2: "key, value" rows, with a default key when no comma is present ---
    metadata = {}
    default_keys = ('Application', 'Date/Time', 'Collection Length (seconds)')
    for line_no, default_key in enumerate(default_keys):
        text = lines[line_no].strip()
        name, comma, rest = text.partition(',')
        if comma:
            metadata[name.strip().rstrip(':')] = rest.strip()
        else:
            metadata[default_key] = text

    # --- Lines 3-4: sensor groups / modes (updated export) or single values (legacy export) ---
    group_line = lines[3].strip()
    mode_line = lines[4].strip()
    if ',' in group_line:
        metadata['SensorGroups'] = _forward_fill(group_line)
        metadata['SensorModes'] = _forward_fill(mode_line)
        if debug:
            print(f"[read_sensor_data_with_metadata] SensorGroups: {metadata['SensorGroups']}")
            print(f"[read_sensor_data_with_metadata] SensorModes: {metadata['SensorModes']}")
    else:
        metadata['Sensor'] = group_line
        metadata['Sensor Mode'] = mode_line

    # --- Line 5: header row, optionally tagged with the sensor group of each column ---
    header_cols = [name.strip() for name in lines[5].strip().split(',')]
    column_names = header_cols
    if 'SensorGroups' in metadata:
        groups = metadata['SensorGroups']
        if len(groups) >= len(header_cols):
            column_names = [f"{name} - {group}" for name, group in zip(header_cols, groups)]
            if debug:
                print("[read_sensor_data_with_metadata] New column names set (updated dataset, sensor group only).")
        elif debug:
            print("[read_sensor_data_with_metadata] Warning: Not enough sensor group entries; using original column names.")

    # --- Line 7 onward: the samples themselves (line 6, the sample rate row, is skipped) ---
    df = pd.read_csv(StringIO(''.join(lines[7:])), header=None, names=column_names)

    # Broadcast every scalar metadata entry into its own column; the per-column lists stay out.
    for key, value in metadata.items():
        if key in ('SensorGroups', 'SensorModes'):
            continue
        df[key] = value

    # Spread the samples evenly between the start time and start time + collection length.
    duration = float(metadata.get('Collection Length (seconds)', 0))
    started_at = pd.to_datetime(metadata.get('Date/Time', None))
    offsets = np.linspace(0, duration, len(df))
    df['Timestamp'] = started_at + pd.to_timedelta(offsets, unit='s')

    if debug:
        print(f"[read_sensor_data_with_metadata] Final DataFrame shape: {df.shape}")
        print(f"[read_sensor_data_with_metadata] Final column names: {df.columns.tolist()}")
    else:
        print("read_sensor_data_with_metadata completed.")

    return df, metadata


########################################
# Function: Compute EMG extreme flag using a fixed time window
########################################
def compute_emg_extreme_flag_window(df, window_time=1.3, column='EMG 1 (mV)', 
                                    threshold_high=1.0, threshold_low=-0.5, debug=False):
    """
    Flags every row whose surrounding time window contains both a high and a low EMG extreme.

    For a row with timestamp t, the window is the closed interval
    [t - window_time, t + window_time]. The row is flagged 1 when that window holds at least
    one value strictly above 'threshold_high' AND at least one value strictly below
    'threshold_low'; otherwise it is flagged 0.

    Implementation note: instead of building a boolean mask over the whole column for every
    row (quadratic in the number of samples), the timestamps are sorted once, running counts
    of high/low samples are accumulated, and each window is resolved with two binary searches.
    The result does not depend on the row order of 'df'. Rows with a missing timestamp (NaT)
    belong to no window and are flagged 0.

    Parameters:
      df (pd.DataFrame): DataFrame containing a datetime 'Timestamp' column.
      window_time (float): Half-width of the window in seconds (applied on both sides of each row).
      column (str): Column name with EMG values.
      threshold_high (float): High threshold.
      threshold_low (float): Low threshold.
      debug (bool): If True, prints detailed debug information.
      
    Returns:
      pd.Series: Series of 0/1 flags, indexed like 'df'.
    """
    timestamps = df['Timestamp']

    if debug:
        # The frame estimate is informational only; it does not influence the flags.
        time_diffs = timestamps.diff().dropna().dt.total_seconds()
        median_dt = time_diffs.median() if not time_diffs.empty else 0
        frame_count = int(round(window_time / median_dt)) if median_dt > 0 else 0
        print(f"[compute_emg_extreme_flag_window] Using a time window of {window_time} sec (~{frame_count} frames)")

    # Work in integer nanoseconds so window edges are compared exactly.
    stamps = timestamps.to_numpy(dtype='datetime64[ns]')
    usable = ~np.isnat(stamps)
    t_ns = stamps.astype('int64')[usable]

    values = df[column]
    # Missing values compare as "not extreme" on either side.
    is_high = (values > threshold_high).fillna(False).to_numpy(dtype=bool)[usable]
    is_low = (values < threshold_low).fillna(False).to_numpy(dtype=bool)[usable]

    # Sort once; prefix sums give "how many extremes up to position k" in O(1).
    order = np.argsort(t_ns, kind='stable')
    sorted_t = t_ns[order]
    high_running = np.concatenate(([0], np.cumsum(is_high[order])))
    low_running = np.concatenate(([0], np.cumsum(is_low[order])))

    half_width_ns = pd.Timedelta(seconds=window_time).value
    # side='left' / side='right' make both window edges inclusive.
    first = np.searchsorted(sorted_t, t_ns - half_width_ns, side='left')
    past_last = np.searchsorted(sorted_t, t_ns + half_width_ns, side='right')

    has_high = (high_running[past_last] - high_running[first]) > 0
    has_low = (low_running[past_last] - low_running[first]) > 0

    flags = np.zeros(len(df), dtype=np.int64)
    flags[usable] = has_high & has_low

    flag_series = pd.Series(flags, index=df.index)
    if debug:
        print(f"[compute_emg_extreme_flag_window] Output flags shape: {flag_series.shape}")
    else:
        print("compute_emg_extreme_flag_window completed.")
    return flag_series


########################################
# Function: Compute EMG extreme flag using a dynamic time window
########################################
def compute_emg_extreme_flag_dynamic_window(df, column='EMG 1 (mV)', threshold_high=1.0, 
                                              threshold_low=-0.5, debug=False):
    """
    Computes an extreme flag per row using a window whose width adapts to nearby extreme events.

    Steps:
      1. Run the fixed-window detector (window_time=1.3) to find the rows flagged as extreme.
      2. For each row, take the time distance to the closest flagged timestamp on either side
         (the smaller of the two when both exist, 0 when no row was flagged at all).
      3. Re-test the row with a window of ± that distance: flag 1 when the window holds a value
         above 'threshold_high' and a value below 'threshold_low'.

    Parameters:
      df (pd.DataFrame): DataFrame containing a 'Timestamp' column.
      column (str): Column name with EMG values.
      threshold_high (float): High threshold.
      threshold_low (float): Low threshold.
      debug (bool): If True, prints detailed debug information.
      
    Returns:
      pd.Series: Series of 0/1 flags.
    """
    def _gap_seconds(later, earlier):
        # Difference between two numpy datetimes, expressed as float seconds.
        return (later - earlier).astype('timedelta64[ns]').astype(float) / 1e9

    seed_flags = compute_emg_extreme_flag_window(df, window_time=1.3, column=column,
                                                 threshold_high=threshold_high,
                                                 threshold_low=threshold_low, debug=debug)
    # Sorted numpy array of the timestamps flagged by the fixed-window pass.
    event_times = df.loc[seed_flags == 1, 'Timestamp'].sort_values().values
    event_count = len(event_times)

    stamps = df['Timestamp']
    signal = df[column]
    half_widths = []  # per-row window half-width in seconds
    flags = []

    for _, now in stamps.items():
        now64 = np.datetime64(now)
        slot = np.searchsorted(event_times, now64)

        # Distances to the neighbouring events: previous one first, then the next one.
        gaps = []
        if slot > 0:
            gaps.append(_gap_seconds(now64, event_times[slot - 1]))
        if slot < event_count:
            gaps.append(_gap_seconds(event_times[slot], now64))
        half_width = min(gaps) if gaps else 0
        half_widths.append(half_width)

        reach = pd.Timedelta(seconds=half_width)
        inside = (stamps >= now - reach) & (stamps <= now + reach)
        nearby = signal[inside]
        flags.append(int((nearby > threshold_high).any() and (nearby < threshold_low).any()))

    result = pd.Series(flags, index=df.index)
    avg_dynamic_window = np.mean(half_widths) if half_widths else 0
    if debug:
        print(f"[compute_emg_extreme_flag_dynamic_window] Average dynamic window size: {avg_dynamic_window:.2f} sec")
    else:
        print("compute_emg_extreme_flag_dynamic_window completed.")
    return result


########################################
# Function: Build global min/max dictionary
########################################
def build_global_min_max(df, columns_to_analyze, debug=False):
    """
    Collects the overall minimum and maximum of each requested column.
    
    Parameters:
      df (pd.DataFrame): DataFrame holding the columns.
      columns_to_analyze (list): Column names to summarize.
      debug (bool): If True, prints the resulting dictionary; otherwise a completion notice.
      
    Returns:
      dict: {column: {'min': <column minimum>, 'max': <column maximum>}} for every requested column.
    """
    extremes = {
        name: {'min': df[name].min(), 'max': df[name].max()}
        for name in columns_to_analyze
    }
    if not debug:
        print("build_global_min_max completed.")
    else:
        print(f"[build_global_min_max] Global min/max for columns: {extremes}")
    return extremes


########################################
# Function: Compute window metrics
########################################
def compute_window_metrics(window_df, columns_to_analyze, global_min_max=None, debug=False):
    """
    Summarizes a slice of data with the mean, minimum and maximum of each requested column.

    When 'global_min_max' has an entry for a column, two extra booleans report whether the
    window minimum / maximum equals the global minimum / maximum stored there.
    
    Parameters:
      window_df (pd.DataFrame): DataFrame slice.
      columns_to_analyze (list): List of column names.
      global_min_max (dict, optional): {column: {'min': ..., 'max': ...}} reference values.
      debug (bool): If True, prints the computed statistics.
      
    Returns:
      dict: Keys 'avg_<col>', 'min_<col>', 'max_<col>' and, where a reference exists,
            'is_global_min_<col>' and 'is_global_max_<col>'.
    """
    reference = {} if global_min_max is None else global_min_max
    summary = {}
    for name in columns_to_analyze:
        series = window_df[name]
        lowest, highest = series.min(), series.max()
        summary.update({
            f'avg_{name}': series.mean(),
            f'min_{name}': lowest,
            f'max_{name}': highest,
        })
        if name in reference:
            bounds = reference[name]
            summary[f'is_global_min_{name}'] = (lowest == bounds['min'])
            summary[f'is_global_max_{name}'] = (highest == bounds['max'])
    if debug:
        print(f"[compute_window_metrics] Computed stats: {summary}")
    return summary


########################################
# Function: Analyze spikes in a given column
########################################
def analyze_spikes(df, col, window=50, global_min_max=None, debug=False):
    """
    Builds one row of window statistics for every spike flagged in a column.

    A spike is any row where '<col>_spike_flag' equals 1. Around each spike, rows labelled
    [spike - window, spike + window] (clamped to 0 and len(df) - 1) are selected with
    label-based slicing and summarized by compute_window_metrics over a fixed list of
    EMG / ACC / GYRO columns.

    NOTE(review): the clamping arithmetic treats index labels as row numbers, so this
    presumably expects a default 0..n-1 integer index — confirm against callers.
    
    Parameters:
      df (pd.DataFrame): DataFrame holding the signal columns and '<col>_spike_flag'.
      col (str): Column whose spikes are analyzed.
      window (int): Number of rows taken on each side of the spike.
      global_min_max (dict, optional): Dictionary for global min/max comparison.
      debug (bool): If True, prints debug info.
      
    Returns:
      pd.DataFrame: One row per spike with the window statistics plus 'spike_index',
                    'spike_column', 'spike_value', 'window_start' and 'window_end'.
    """
    metric_columns = [
        'EMG 1 (mV)', 'ACC X (G)', 'ACC Y (G)', 'ACC Z (G)',
        'GYRO X (deg/s)', 'GYRO Y (deg/s)', 'GYRO Z (deg/s)'
    ]
    spike_labels = df.index[df[f'{col}_spike_flag'] == 1]
    last_row = len(df) - 1

    records = []
    for spike in spike_labels:
        lower = max(0, spike - window)
        upper = min(last_row, spike + window)
        record = compute_window_metrics(df.loc[lower:upper], metric_columns,
                                        global_min_max=global_min_max, debug=debug)
        # Describe the spike itself and the window it was summarized over.
        record.update({
            'spike_index': spike,
            'spike_column': col,
            'spike_value': df.loc[spike, col],
            'window_start': lower,
            'window_end': upper,
        })
        records.append(record)

    if debug:
        print(f"[analyze_spikes] Processed {len(spike_labels)} spikes for column {col}.")
    return pd.DataFrame(records)


########################################
# Function: Compare spike windows from EMG and ACC/GYRO
########################################
def compare_spike_windows(emg_spikes_df, acc_gyro_spikes_df, debug=False):
    """
    Pairs every EMG spike with the ACC/GYRO spikes that fall inside its window.

    An ACC/GYRO spike matches when its 'spike_index' lies within the EMG spike's
    ['window_start', 'window_end'] range (both ends inclusive).
    
    Parameters:
      emg_spikes_df (pd.DataFrame): DataFrame from analyze_spikes for EMG.
      acc_gyro_spikes_df (pd.DataFrame): DataFrame from analyze_spikes for ACC/GYRO.
      debug (bool): If True, prints the number of pairs found.
      
    Returns:
      pd.DataFrame: One row per (EMG spike, matching ACC/GYRO spike) pair with the spike
                    indices, the EMG spike value, the ACC/GYRO column name and the window
                    averages of 'EMG 1 (mV)' and 'ACC X (G)'.
    """
    pairs = []
    other_positions = acc_gyro_spikes_df['spike_index']
    for _, emg in emg_spikes_df.iterrows():
        inside = other_positions.between(emg['window_start'], emg['window_end'])
        for _, other in acc_gyro_spikes_df[inside].iterrows():
            pairs.append({
                'emg_spike_index': emg['spike_index'],
                'emg_spike_value': emg['spike_value'],
                'acc_gyro_spike_index': other['spike_index'],
                'acc_gyro_spike_column': other['spike_column'],
                'emg_window_avg': emg['avg_EMG 1 (mV)'],
                'acc_window_avg': other.get('avg_ACC X (G)', None)
            })
    if debug:
        print(f"[compare_spike_windows] Merged {len(pairs)} overlapping spike events.")
    return pd.DataFrame(pairs)


########################################
# Function: Mark throwing motion based on extreme flag windows
########################################
def mark_throwing_motion(df, extreme_flag_col='EMG_extreme_flag', window_time=1.3, debug=False):
    """
    Labels the rows that surround each extreme event as part of a throwing motion.

    Every row whose 'extreme_flag_col' equals 1 is an event. All rows whose timestamp lies
    within ±(window_time / 2) seconds of an event (both ends inclusive) get
    'ThrowingMotion' = 1; every other row gets 0. The input DataFrame is left untouched.
    
    Parameters:
      df (pd.DataFrame): DataFrame with a 'Timestamp' column.
      extreme_flag_col (str): The column name that holds the extreme flag.
      window_time (float): Total duration (in seconds) of the window centred on each event.
      debug (bool): If True, prints detailed debug information.
      
    Returns:
      pd.DataFrame: Copy of the DataFrame with an added 'ThrowingMotion' column.
    """
    marked = df.copy()
    marked['ThrowingMotion'] = 0

    # Half the total duration on each side, e.g. 0.65 s for a 1.3 s window.
    half_window = window_time / 2
    reach = pd.Timedelta(seconds=half_window)

    event_times = marked.loc[marked[extreme_flag_col] == 1, 'Timestamp']
    if debug:
        print(f"[mark_throwing_motion] Found {len(event_times)} extreme events. Using half window = {half_window} sec.")

    for t in event_times:
        start, end = t - reach, t + reach
        marked.loc[marked['Timestamp'].between(start, end), 'ThrowingMotion'] = 1
        if debug:
            print(f"[mark_throwing_motion] Marking event at {t} (window: {start} to {end}).")

    if not debug:
        print("mark_throwing_motion completed.")
    else:
        total_marked = marked['ThrowingMotion'].sum()
        print(f"[mark_throwing_motion] Total rows marked as ThrowingMotion: {total_marked}")

    return marked


########################################
# Function: Process a single CSV file
########################################
def process_file(file_path, debug=False):
    """
    Processes a single sensor CSV file end to end.

    Pipeline:
      1. Read the file and its metadata (read_sensor_data_with_metadata).
      2. Locate the ACC/GYRO columns by name prefix, drop rows where any of them is a blank
         string, and convert them to numeric (conversion errors are printed and re-raised).
      3. Add 0/1 spike flags: ACC/GYRO values > 1 or < -0.5; EMG values > 1.0.
      4. Add 'EMG_high_flag' (> 1.0) and 'EMG_low_flag' (< -0.5).
      5. Add the fixed-window ('EMG_extreme_flag') and dynamic-window
         ('EMG_extreme_flag_dynamic') extreme flags for the EMG column.
      6. Add 'ThrowingMotion' from the fixed-window extreme flags.

    The EMG column is the first column whose name starts with 'EMG 1 (mV)'.
    
    Parameters:
      file_path (str): Path to the CSV file.
      debug (bool): If True, prints detailed debug output.
    
    Returns:
      pd.DataFrame: Processed DataFrame.

    Raises:
      KeyError: If no column name starts with 'EMG 1 (mV)'.
    """
    if debug:
        print(f"\n[process_file] Processing file: {file_path}")
    else:
        print(f"Processing file: {os.path.basename(file_path)}")

    # Step 1: Read data and metadata.
    df, metadata = read_sensor_data_with_metadata(file_path, debug=debug)
    if debug:
        print(f"[process_file] DataFrame shape after reading: {df.shape}")
    else:
        print("Data read completed.")

    # Step 2: Display minimal summary if in debug mode.
    if debug:
        print(f"[process_file] Descriptive Statistics:\n{df.describe()}")
        print(f"[process_file] Data types:\n{df.dtypes}")
    else:
        print("Basic summary displayed.")

    # Step 3: Dynamically identify numeric sensor columns.
    # Prefix matching picks up names that carry a sensor-group suffix as well.
    base_names = ['ACC X (G)', 'ACC Y (G)', 'ACC Z (G)', 
                  'GYRO X (deg/s)', 'GYRO Y (deg/s)', 'GYRO Z (deg/s)']
    numeric_cols = []
    for base in base_names:
        matches = [col for col in df.columns if col.startswith(base)]
        numeric_cols.extend(matches)
    if debug:
        print(f"[process_file] Identified numeric sensor columns: {numeric_cols}")

    # Clean data: Remove rows with blank numeric values.
    mask = df[numeric_cols].apply(lambda col: col.astype(str).str.strip() == '').any(axis=1)
    if debug:
        print(f"[process_file] Rows with blank numeric values: {mask.sum()}")
    # Take an explicit copy: the columns assigned below would otherwise be written onto a
    # filtered slice of the original frame, which triggers pandas' SettingWithCopyWarning.
    df = df[~mask].copy()

    # Convert identified numeric columns to numeric type.
    for col in numeric_cols:
        try:
            df[col] = pd.to_numeric(df[col], errors='raise')
        except Exception as e:
            print(f"[process_file] Error converting column {col}: {e}")
            raise
    if debug:
        print(f"[process_file] Data shape after cleaning: {df.shape}")

    # Step 4: (Optional) Subset data; here we use the full dataset.
    print(f"[process_file] Data subset: {df.shape[0]} rows (full data used).")
    
    # (Optional) Compute overall min/max summary.
    min_max_df = pd.DataFrame({'min': df.min(), 'max': df.max()})
    if debug:
        print(f"[process_file] Overall min/max summary:\n{min_max_df}")
    else:
        print("Min/Max summary computed.")

    # Step 5: Create spike flags for ACC/GYRO columns.
    for col in numeric_cols:
        spike_flag_col = f'{col}_spike_flag'
        df[spike_flag_col] = ((df[col] > 1) | (df[col] < -0.5)).astype(int)
    print("Spike flags for ACC/GYRO created.")

    # Create spike flag for EMG (value > 1.0).
    emg_base = 'EMG 1 (mV)'
    emg_matches = [col for col in df.columns if col.startswith(emg_base)]
    if emg_matches:
        # With several matching columns, only the first one is analyzed.
        emg_col = emg_matches[0]
    else:
        raise KeyError(f"No column found starting with '{emg_base}'")
    emg_spike_flag_col = f'{emg_col}_spike_flag'
    df[emg_spike_flag_col] = (df[emg_col] > 1.0).astype(int)
    
    # Additional EMG flags.
    df['EMG_high_flag'] = (df[emg_col] > 1.0).astype(int)
    df['EMG_low_flag'] = (df[emg_col] < -0.5).astype(int)
    if debug:
        print(f"[process_file] EMG_high_flag, EMG_low_flag added. Count >1.0: {df['EMG_high_flag'].sum()}, "
              f"Count <-0.5: {df['EMG_low_flag'].sum()}")

    # Step 6: Compute fixed-window extreme flag for EMG.
    df['EMG_extreme_flag'] = compute_emg_extreme_flag_window(df, window_time=1.3, column=emg_col, debug=debug)
    if debug:
        print(f"[process_file] Fixed-window extreme flag count: {df['EMG_extreme_flag'].sum()}")

    # Step 7: Count unique extreme events in fixed window.
    # An event starts wherever the flag is 1 and the previous row's flag is not 1.
    unique_extreme_count = ((df['EMG_extreme_flag'] == 1) &
                            (df['EMG_extreme_flag'].shift(1).fillna(0) != 1)).sum()
    if debug:
        print(f"[process_file] Unique extreme events (fixed window): {unique_extreme_count}")

    # Step 8: Compute dynamic-window extreme flag for EMG.
    df['EMG_extreme_flag_dynamic'] = compute_emg_extreme_flag_dynamic_window(df, column=emg_col, debug=debug)
    if debug:
        print(f"[process_file] Dynamic-window extreme flag count: {df['EMG_extreme_flag_dynamic'].sum()}")
        unique_dynamic_extreme_count = ((df['EMG_extreme_flag_dynamic'] == 1) &
                                        (df['EMG_extreme_flag_dynamic'].shift(1).fillna(0) != 1)).sum()
        print(f"[process_file] Unique extreme events (dynamic window): {unique_dynamic_extreme_count}")

    # Step 9: Mark throwing motion based on fixed-window extreme flags.
    df = mark_throwing_motion(df, extreme_flag_col='EMG_extreme_flag', window_time=1.3, debug=debug)
    if debug:
        print(f"[process_file] ThrowingMotion rows count: {df['ThrowingMotion'].sum()}")

    print("File processing completed.\n")
    return df


########################################
# Main function: Process all files in a folder and output a single Parquet file
########################################
def main(debug=False, input_folder='./data/raw/', output_file='./data/processed/processed_pitch_data.parquet'):
    """
    Processes all CSV files in the specified folder, stacks them into one DataFrame,
    and writes the output to a Parquet file.

    Files are processed in sorted path order so the row order of the output is reproducible
    from run to run. Each file's rows carry a 'SourceFile' column with the file's base name.
    The output directory is created when it does not exist; an 'output_file' without a
    directory component is written to the current working directory.
    
    Parameters:
      debug (bool): If True, prints detailed debug information.
      input_folder (str): Folder containing the CSV files.
      output_file (str): Path for the output Parquet file.
    
    Returns:
      pd.DataFrame: Final processed DataFrame.

    Raises:
      FileNotFoundError: If 'input_folder' does not exist or contains no '*.csv' file.
    """
    # Ensure input folder exists.
    if not os.path.isdir(input_folder):
        raise FileNotFoundError(f"Input folder '{input_folder}' does not exist.")
    
    # Find all CSV files in the folder. glob returns them in arbitrary order, so sort.
    csv_files = sorted(glob.glob(os.path.join(input_folder, '*.csv')))
    if not csv_files:
        raise FileNotFoundError("No CSV files found in the input folder.")
    
    if debug:
        print(f"[main] Found {len(csv_files)} CSV files in '{input_folder}'.")
    else:
        print(f"Found {len(csv_files)} CSV file(s).")

    processed_dfs = []
    for file in csv_files:
        df = process_file(file, debug=debug)
        # Record which file each row came from.
        df['SourceFile'] = os.path.basename(file)
        processed_dfs.append(df)
    
    # Stack all DataFrames (row-wise).
    final_df = pd.concat(processed_dfs, ignore_index=True)
    if debug:
        print(f"[main] Final stacked DataFrame shape: {final_df.shape}")
    else:
        print("All files processed and stacked.")

    # Save final DataFrame to Parquet.
    output_dir = os.path.dirname(output_file)
    # dirname is '' for a bare file name, and os.makedirs('') raises; nothing to create then.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    final_df.to_parquet(output_file, index=False)
    print(f"Final processed data saved to: {output_file}")
    
    return final_df


# Script entry point: runs the whole EMG pipeline when this module is executed directly.
if __name__ == "__main__":
    processed_df = main(
        # Folder that holds the raw CSV exports; change this to your input folder.
        input_folder='../../data/raw/three_sensored_emg_data/',
        # Destination of the stacked Parquet file; change this to your output path.
        output_file='../../data/processed/emg_pitch_data_processed.parquet',
        # True prints detailed progress output, False prints minimal output.
        debug=True,
    )


# Granular Biomechanics Dataset from Theia

In [None]:
import pandas as pd
import mysql.connector
import logging
from datetime import datetime, timedelta
import numpy as np
from dotenv import load_dotenv
import os 
from pathlib import Path


# Check scipy version
import scipy
print(f"SciPy version: {scipy.__version__}")
from scipy.integrate import cumulative_trapezoid


# Locate the project's .env file, which sits two directory levels above the
# current working directory.
base_path = Path.cwd()
env_path = base_path.parents[1] / '.env'

# Load the variables defined in that file and show which path was used.
load_dotenv(dotenv_path=env_path)
print(env_path)

# Module-wide logger, reporting at INFO level and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_database_connection():
    """
    Open and return a MySQL connection configured from environment variables.

    Reads DB_HOST, DB_USER, DB_PASSWORD and DB_DATABASE; a variable that is not set is
    passed to the connector as None.
    """
    settings = {
        'host': os.getenv("DB_HOST"),
        'user': os.getenv("DB_USER"),
        'password': os.getenv("DB_PASSWORD"),
        'database': os.getenv("DB_DATABASE"),
    }
    return mysql.connector.connect(**settings)

def check_data_completeness(df, table_name, key_columns):
    """
    Log a completeness report for a DataFrame.

    The report contains:
      - the number of nulls per column (as a warning, only for columns that have any);
      - for every column, its dtype and its count of unique non-null values;
      - when both a 'time' and a 'session_trial' column are present, descriptive statistics
        of the frame-to-frame time differences, computed per session trial and then
        averaged across trials.

    Parameters:
      df (pd.DataFrame): Data to inspect.
      table_name (str): Name used to label the log messages.
      key_columns: Currently unused; accepted so existing callers keep working.

    Returns:
      None. All output goes to the module logger.
    """
    logger.info(f"\nChecking data completeness for {table_name}:")
    null_counts = df.isnull().sum()
    if null_counts.any():
        logger.warning(f"Null values found in {table_name}:")
        logger.warning(null_counts[null_counts > 0])
    else:
        logger.info("No null values found.")
    
    for col in df.columns:
        unique_count = df[col].nunique(dropna=True)
        col_type = df[col].dtype
        logger.info(f"Column '{col}': Type = {col_type}, Unique non-null values = {unique_count}")
    
    if 'time' not in df.columns:
        return
    if 'session_trial' not in df.columns:
        # Without the grouping column the per-trial statistics cannot be computed; a
        # diagnostics routine should report that rather than fail with a KeyError.
        logger.warning(f"Skipping time series statistics for {table_name}: no 'session_trial' column.")
        return

    # Gaps are measured within each trial so the jump between trials is not counted.
    time_stats = df.groupby('session_trial')['time'].apply(lambda x: x.diff().describe())
    logger.info(f"Time series statistics for {table_name} (averaged across sessions):\n{time_stats.mean()}")

from scipy.signal import welch

def calculate_spectral_features(signal_window, fs=100):
    """
    Calculate spectral features for a given window of signal data using Welch's method.

    NaN values in the input are replaced with 0 before the spectrum is estimated
    (np.nan_to_num also maps infinities to very large finite numbers).
    
    Parameters:
      signal_window (array-like): 1D array of signal values (e.g., valgus_torque over a time window)
      fs (int): Sampling frequency in Hz (default is 100)
      
    Returns:
      dict: A dictionary containing:
            - 'peak_freq': Frequency with the maximum power spectral density.
            - 'total_power': Total power, i.e. the PSD integrated over frequency with the
              trapezoidal rule.
    """
    # np.trapz is deprecated since NumPy 2.0; SciPy's trapezoid performs the same
    # trapezoidal integration and SciPy is already required for welch.
    from scipy.integrate import trapezoid

    # Replace NaNs with 0 in the signal
    clean_signal = np.nan_to_num(signal_window, nan=0.0)
    freqs, psd = welch(clean_signal, fs=fs)
    peak_freq = freqs[np.argmax(psd)]
    total_power = trapezoid(psd, freqs)
    return {
        'peak_freq': peak_freq,
        'total_power': total_power
    }



def _normalize_sql_date(value, label, fmt='%Y-%m-%d'):
    """Parse *value* as a date and re-render it in *fmt* so only digits and dashes reach the SQL text."""
    try:
        return pd.Timestamp(value).strftime(fmt)
    except (ValueError, TypeError) as exc:
        raise ValueError(f"{label} is not a valid date: {value!r}") from exc


def _build_session_date_predicate(filter_type, specific_date, specific_month, start_date, end_date):
    """Return the SQL predicate on `s.date` that implements the requested filter_type."""
    if filter_type == 'LAST_DAY':
        return "s.date = (SELECT MAX(date) FROM sessions)"
    if filter_type == 'LAST_5_DAYS':
        return "s.date >= DATE_SUB((SELECT MAX(date) FROM sessions), INTERVAL 4 DAY)"
    if filter_type == 'LAST_MONTH':
        return "s.date >= DATE_SUB((SELECT MAX(date) FROM sessions), INTERVAL 30 DAY)"
    if filter_type == 'SPECIFIC_DATE':
        if not specific_date:
            raise ValueError("specific_date is required for SPECIFIC_DATE filter_type")
        specific_date = _normalize_sql_date(specific_date, 'specific_date')
        return f"s.date = '{specific_date}'"
    if filter_type == 'SPECIFIC_MONTH':
        if not specific_month:
            raise ValueError("specific_month is required for SPECIFIC_MONTH filter_type")
        specific_month = _normalize_sql_date(specific_month, 'specific_month', fmt='%Y-%m')
        return f"DATE_FORMAT(s.date, '%Y-%m') = '{specific_month}'"
    if filter_type == 'DATE_RANGE':
        if not start_date or not end_date:
            raise ValueError("Both start_date and end_date are required for DATE_RANGE filter_type")
        start_date = _normalize_sql_date(start_date, 'start_date')
        end_date = _normalize_sql_date(end_date, 'end_date')
        return f"s.date BETWEEN '{start_date}' AND '{end_date}'"
    raise ValueError(f"Invalid filter_type: {filter_type}")


def get_granular_time_series_data(filter_type='LAST_DAY', specific_date=None, specific_month=None, start_date=None, end_date=None):
    """
    Retrieve granular, frame-level time series data from 'theia_pitching_db'
    with flexible date filtering options.

    The query joins trials/sessions/users with per-frame joint angles, joint
    velocities, energetics, joint forces and joint moments, plus per-trial
    points of interest and limb lengths. Each frame is labelled with a pitch
    phase derived from the event times in `events`, and `ongoing_timestamp`
    is computed in SQL as the session datetime plus the frame time in seconds.

    Parameters:
      filter_type (str): One of 'LAST_DAY', 'LAST_5_DAYS', 'LAST_MONTH',
          'SPECIFIC_DATE', 'SPECIFIC_MONTH' or 'DATE_RANGE'.
      specific_date: Date (e.g. 'YYYY-MM-DD') required for 'SPECIFIC_DATE'.
      specific_month: Month (e.g. 'YYYY-MM') required for 'SPECIFIC_MONTH'.
      start_date, end_date: Inclusive bounds required for 'DATE_RANGE'.

    Returns:
      pd.DataFrame: One row per frame. The 'time' column is converted to float
      (unparseable values become 0) to avoid Decimal type conflicts downstream.

    Raises:
      ValueError: If filter_type is unknown, a required date argument is
          missing, or a date argument cannot be parsed as a date.
    """
    # Validate the arguments and build the whole query before opening a
    # connection, so a bad argument never leaves a connection dangling.
    date_predicate = _build_session_date_predicate(
        filter_type, specific_date, specific_month, start_date, end_date
    )

    # Build the dynamic filtering part of the date_filtered_sessions CTE.
    date_filter_cte = """
    WITH date_filtered_sessions AS (
        SELECT
            t.session_trial,
            t.trial,
            t.time AS trial_time,
            TIMESTAMP(s.date, t.time) AS session_datetime,
            t.pitch_type,
            t.handedness,
            s.date,
            s.session,
            s.level,
            s.lab,
            s.height_meters,
            s.mass_kilograms,
            u.name AS athlete_name,
            u.dob AS athlete_dob,
            u.traq AS athlete_traq
        FROM `trials` t
        JOIN `sessions` s ON t.session = s.session
        JOIN `users` u ON s.user = u.user
        WHERE 
    """
    date_filter_cte += date_predicate
    date_filter_cte += ")"
    
    # CTE to retrieve event markers for the pitch phases
    pitch_phases_cte = """
    , pitch_phases AS (
        SELECT 
            e.session_trial,
            e.PKH_time,
            e.FP_v5_time,
            e.MER_time,
            e.`BR_time` as ball_release_time,
            e.`i_BR` as ball_release_frame,
            e.MAD_time
        FROM `events` e
        INNER JOIN date_filtered_sessions dfs 
            ON e.session_trial = dfs.session_trial
    )
    """
    
    # CTE for frame-level joint data with added phase_marker
    frame_level_data_cte = """
    , frame_level_data AS (
        SELECT 
            ja.session_trial,
            ja.time,
            ja.shoulder_angle_x,
            ja.shoulder_angle_y,
            ja.shoulder_angle_z,
            ja.elbow_angle_x,
            ja.elbow_angle_y,
            ja.elbow_angle_z,
            ja.torso_angle_x,
            ja.torso_angle_y,
            ja.torso_angle_z,
            ja.pelvis_angle_x,
            ja.pelvis_angle_y,
            ja.pelvis_angle_z,
            COALESCE(jv.shoulder_velo_x, 0.0) AS shoulder_velo_x,
            COALESCE(jv.shoulder_velo_y, 0.0) AS shoulder_velo_y,
            COALESCE(jv.shoulder_velo_z, 0.0) AS shoulder_velo_z,
            jv.elbow_velo_x,
            jv.elbow_velo_y,
            jv.elbow_velo_z,
            jv.torso_velo_x,
            jv.torso_velo_y,
            jv.torso_velo_z,
            ABS(ja.torso_angle_z - ja.pelvis_angle_z) AS trunk_pelvis_dissociation,
            pp.ball_release_time,
            -- Determine pitch phase using event markers
            CASE 
                WHEN ja.time <= pp.PKH_time THEN 'Wind-Up'
                WHEN ja.time <= pp.FP_v5_time THEN 'Stride'
                WHEN ja.time <= pp.MER_time THEN 'Arm Cocking'
                WHEN ja.time <= pp.ball_release_time THEN 'Arm Acceleration'
                WHEN ja.time <= pp.MAD_time THEN 'Arm Deceleration'
                ELSE 'Follow Through'
            END AS pitch_phase,
            pp.PKH_time AS phase_marker  -- added key phase marker from events
        FROM `joint_angles` ja
        INNER JOIN date_filtered_sessions dfs ON ja.session_trial = dfs.session_trial
        INNER JOIN `joint_velos` jv ON ja.session_trial = jv.session_trial AND ja.time = jv.time
        LEFT JOIN pitch_phases pp ON ja.session_trial = pp.session_trial
    )
    """
    
    # Build final query with additional joins and updated timestamp calculation.
    query = (
        date_filter_cte +
        pitch_phases_cte +
        frame_level_data_cte +
        """
    SELECT 
        dfs.athlete_name,
        dfs.athlete_dob,
        dfs.athlete_traq,
        dfs.height_meters,
        dfs.mass_kilograms,
        dfs.level AS athlete_level,
        dfs.date AS session_date,
        dfs.trial_time AS session_time,
        dfs.lab,
        dfs.session,
        dfs.trial,
        dfs.pitch_type,
        dfs.handedness,
        -- Updated ongoing_timestamp calculation with millisecond resolution
        DATE_ADD(dfs.session_datetime, INTERVAL fld.time SECOND) AS ongoing_timestamp,
        fld.*,
        en.shoulder_energy_transfer,
        en.shoulder_energy_generation,
        en.elbow_energy_transfer,
        en.elbow_energy_generation,
        en.lead_knee_energy_transfer,
        en.lead_knee_energy_generation,

        /*-- Force plate data
        fp.lead_force_x,
        fp.lead_force_y,
        fp.lead_force_z,
        fp.lead_force_mag,
        fp.rear_force_x,
        fp.rear_force_y,
        fp.rear_force_z,
        fp.rear_force_mag,*/

        jf.elbow_force_x,
        jf.elbow_force_y,
        jf.elbow_force_z,
        jf.shoulder_upper_arm_force_x,
        jf.shoulder_upper_arm_force_y,
        jf.shoulder_upper_arm_force_z,
        COALESCE(jm.elbow_moment_x, 0.0) AS elbow_moment_x,
        COALESCE(jm.elbow_moment_y, 0.0) AS elbow_moment_y,
        COALESCE(jm.elbow_moment_z, 0.0) AS elbow_moment_z,
        jm.shoulder_thorax_moment_x,
        jm.shoulder_thorax_moment_y,
        jm.shoulder_thorax_moment_z,
        p.pitch_speed_mph,
        p.max_shoulder_internal_rotational_velo,
        p.elbow_varus_moment,
        ll.forearm_length
    FROM frame_level_data fld
    LEFT JOIN `energetics` en ON fld.session_trial = en.session_trial AND fld.time = en.time
    LEFT JOIN `force_plates` fp ON fld.session_trial = fp.session_trial AND fld.time = fp.time
    LEFT JOIN `joint_forces` jf ON fld.session_trial = jf.session_trial AND fld.time = jf.time
    LEFT JOIN `joint_moments` jm ON fld.session_trial = jm.session_trial AND fld.time = jm.time
    LEFT JOIN `poi` p ON fld.session_trial = p.session_trial
    LEFT JOIN `limb_lengths` ll ON fld.session_trial = ll.session_trial
    JOIN date_filtered_sessions dfs ON fld.session_trial = dfs.session_trial
    ORDER BY 
        dfs.date,
        dfs.trial_time,
        fld.session_trial,
        fld.time;
    """
    )
    
    logger.info(f"Executing updated query with filter_type: {filter_type}")
    conn = get_database_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            rows = cursor.fetchall()
            columns = [desc[0] for desc in cursor.description]
        finally:
            # Always release the cursor, even when the query fails.
            cursor.close()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()
    
    df = pd.DataFrame(rows, columns=columns)
    
    # Convert 'time' column to numeric (float) to resolve Decimal type issues from SQL import.
    df['time'] = pd.to_numeric(df['time'], errors='coerce').fillna(0)
    
    logger.info(f"Query returned {len(df)} rows and {len(df.columns)} columns.")
    return df

    
def trapz_integration(x):
    """
    Integrate *x* with the trapezoidal rule, assuming unit spacing between samples.

    Parameters:
      x (array-like): Sample values.

    Returns:
      float: The trapezoidal integral of the samples.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to the old name on older NumPy versions.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return integrate(x)

def calculate_dynamic_phase_weights(group):
    """
    Compute the (placeholder) dynamic weight for the pitch phase of a group.

    Four metrics are derived from the group: the integrated valgus torque, the
    time span covered, the peak valgus torque and the mean shoulder angular
    velocity. Each is currently "normalized" by dividing it by itself, so it
    contributes 1 when nonzero and 0 when exactly zero, and the four
    contributions are blended with fixed coefficients 0.4 / 0.3 / 0.2 / 0.1.

    Parameters:
      group (pd.DataFrame): Rows with 'time', 'valgus_torque',
          'shoulder_ang_velo' and 'pitch_phase' columns.

    Returns:
      dict: {pitch phase of the group's first row: weight}.
    """
    def _self_ratio(value):
        # Placeholder normalization: value / value, or 0 when the value is exactly zero.
        return value / value if value != 0 else 0

    # Cast to float up front to avoid Decimal-vs-float arithmetic conflicts.
    seconds = group['time'].astype(float)
    span = seconds.max() - seconds.min()
    span_term = _self_ratio(span)

    integrated_torque = trapz_integration(group['valgus_torque'].astype(float))
    torque_term = _self_ratio(integrated_torque)

    mean_shoulder_velo = group['shoulder_ang_velo'].mean()
    peak_torque = group['valgus_torque'].max()

    # Fixed blend coefficients (placeholders until proper normalization exists).
    blended = (
        0.4 * torque_term +
        0.3 * span_term +
        0.2 * _self_ratio(peak_torque) +
        0.1 * _self_ratio(mean_shoulder_velo)
    )
    phase_label = group['pitch_phase'].iloc[0]
    return {phase_label: blended}




def compute_phase_aware_cumulative(group):
    """
    Accumulate weighted torque changes row by row, restarting at each phase change.

    For every row the increment is torque_diff_lead * weight * time_diff_lead,
    where the weight comes from calculate_dynamic_phase_weights(group) and
    defaults to 1.0 for phases that function did not return. The running total
    is reset to 0 whenever 'pitch_phase' differs from the previous row.

    Parameters:
      group (pd.DataFrame): Rows with 'pitch_phase', 'time_diff_lead' and
          'torque_diff_lead' columns (plus whatever the weight function reads).

    Returns:
      pd.Series: Running totals indexed like the input group.
    """
    weights_by_phase = calculate_dynamic_phase_weights(group)

    running_total = 0
    active_phase = None
    totals = []

    per_row = zip(group['pitch_phase'], group['time_diff_lead'], group['torque_diff_lead'])
    for phase, step_seconds, torque_step in per_row:
        # A new phase starts its own accumulation from zero.
        if phase != active_phase:
            active_phase = phase
            running_total = 0

        phase_weight = weights_by_phase.get(active_phase, 1.0)
        running_total += torque_step * phase_weight * step_seconds
        totals.append(running_total)

    return pd.Series(totals, index=group.index)

def validate_phase_weights(df):
    """
    Check that the automated phase weights fall inside the expected ranges.

    Weights are computed with calculate_dynamic_phase_weights for every
    (session_trial, pitch_phase) group and averaged per phase. Each phase
    listed in the expected pattern is then compared against its (low, high)
    bounds. Phases that do not occur in the data are skipped with a warning.

    NOTE(review): the weight function's coefficients sum to 1.0, so the
    'Arm Acceleration' and 'Arm Deceleration' ranges below cannot be met by
    the current placeholder weights - confirm the intended bounds.

    Parameters:
      df (pd.DataFrame): Frame-level data with 'session_trial', 'pitch_phase'
          and the columns read by calculate_dynamic_phase_weights.

    Returns:
      bool: True when every checked phase is within bounds, False otherwise
      (violations are logged at error level).
    """
    expected_pattern = {
        'Wind-Up': (0.9, 1.2),
        'Arm Acceleration': (1.9, 2.3),  # Should be highest
        'Arm Deceleration': (1.7, 2.0)
    }

    # Collect one weight per (trial, phase) group. The weight function returns
    # a {phase: weight} dict keyed on the group's first phase, so grouping by
    # phase as well gives every phase its own weight.
    weights_by_phase = {}
    for _, phase_group in df.groupby(['session_trial', 'pitch_phase']):
        for phase, weight in calculate_dynamic_phase_weights(phase_group).items():
            weights_by_phase.setdefault(phase, []).append(weight)

    sample_weights = {
        phase: sum(values) / len(values)
        for phase, values in weights_by_phase.items()
    }

    violations = []
    for phase, (lower, upper) in expected_pattern.items():
        if phase not in sample_weights:
            logger.warning(f"No data for phase '{phase}'; skipping its weight validation")
            continue
        mean_weight = sample_weights[phase]
        # A NaN weight fails both comparisons and is reported as a violation.
        if not (lower <= mean_weight <= upper):
            violations.append(f"{phase}: {mean_weight:.2f}")

    if violations:
        logger.error(f"Phase weight violations:\n{', '.join(violations)}")
        return False
    return True


def compute_cumulative_exposure(group):
    """
    Compute the cumulative valgus torque exposure for a group (session_trial).

    The torque is integrated over time with the cumulative trapezoidal rule.
    Time is measured in seconds relative to the group's first timestamp, which
    works for any datetime resolution and avoids the precision loss of
    epoch-sized floats; only time differences enter the integral. Key
    intermediate values are logged, and NaN torque values are replaced with 0.

    Parameters:
      group (pd.DataFrame): Rows with 'ongoing_timestamp' and 'valgus_torque'
          columns. When called through groupby.apply, `group.name` is used in
          log messages.

    Returns:
      pd.Series: Cumulative integral, indexed like the input group.

    Raises:
      ValueError: If any timestamp is missing (NaT).
    """
    # `name` only exists when pandas sets it during groupby.apply.
    group_label = getattr(group, 'name', '<unnamed>')
    logger.info(f"Processing group: {group_label}")
    
    # Log original ongoing_timestamp values and their data type.
    logger.info("Original ongoing_timestamp values:")
    logger.info(group['ongoing_timestamp'].head(5))
    logger.info(f"Data type: {group['ongoing_timestamp'].dtype}")
    
    if len(group) == 0:
        # Nothing to integrate.
        return pd.Series([], index=group.index, dtype=float)
    
    try:
        # Seconds elapsed since the first row; independent of the datetime unit.
        timestamps = pd.to_datetime(group['ongoing_timestamp'])
        times = (timestamps - timestamps.iloc[0]).dt.total_seconds()
    except Exception as e:
        logger.error(f"Error converting ongoing_timestamp in group {group_label}: {e}")
        raise
    logger.info("Converted times (seconds):")
    logger.info(times.head(5))
    
    # Ensure no NaNs in time values (a NaT timestamp becomes NaN seconds).
    if times.isnull().any():
        raise ValueError(f"NaN found in times for group {group_label}")
    
    # Retrieve torque values as floats (np.isnan cannot handle object dtype) and log them.
    torque = pd.to_numeric(group['valgus_torque'], errors='coerce').to_numpy(dtype=float)
    logger.info("Torque values:")
    logger.info(torque[:5])
    
    # If any NaNs are found in torque, log a warning and replace them with 0.
    if np.isnan(torque).any():
        logger.warning(f"NaN detected in torque for group {group_label}; replacing NaNs with 0.")
        torque = np.nan_to_num(torque, nan=0.0)
    
    try:
        # Perform cumulative integration using the trapezoidal rule.
        cum = cumulative_trapezoid(torque, times, initial=0)
    except Exception as e:
        logger.error(f"Error during cumulative integration in group {group_label}: {e}")
        raise
    logger.info("Cumulative integration result:")
    logger.info(cum[:5])
    
    return pd.Series(cum, index=group.index)


def compute_armcock_acc_cumulative(group):
    """
    Accumulate valgus torque over time while in the arm-cocking/acceleration phases.

    Rows whose 'pitch_phase' is 'Arm Cocking' or 'Arm Acceleration' add
    torque * (seconds elapsed since the previous row) to a running total; the
    first row of the group contributes with an elapsed time of 0. Any row
    outside those phases resets the running total to 0.

    Parameters:
      group (pd.DataFrame): Rows with 'ongoing_timestamp' (datetime),
          'valgus_torque' and 'pitch_phase' columns.

    Returns:
      pd.Series: Running totals indexed like the input group.
    """
    in_critical_phase = group['pitch_phase'].isin(['Arm Cocking', 'Arm Acceleration'])

    totals = []
    running_total = 0
    previous_stamp = None

    for is_critical, stamp, torque_value in zip(in_critical_phase,
                                                group['ongoing_timestamp'],
                                                group['valgus_torque']):
        if is_critical:
            if previous_stamp is None:
                elapsed = 0
            else:
                elapsed = (stamp - previous_stamp).total_seconds()
            running_total += torque_value * elapsed
        else:
            running_total = 0
        totals.append(running_total)
        previous_stamp = stamp

    return pd.Series(totals, index=group.index)


def _apply_per_group_series(df, keys, func):
    """Run *func* on each group of *df* and concatenate the per-group Series (indexed by original row labels)."""
    pieces = []
    for group_key, group_frame in df.groupby(keys):
        # Mirror what GroupBy.apply does so callees can keep reading `group.name`.
        object.__setattr__(group_frame, 'name', group_key)
        pieces.append(func(group_frame))
    if not pieces:
        return pd.Series(index=df.index, dtype=float)
    return pd.concat(pieces)


def process_valgus_features(df):
    """
    Process the DataFrame to compute additional valgus torque features for ML.

    The input DataFrame is modified in place and also returned. Features added:
      - valgus_torque: negated elbow_moment_z, with negative results clipped to 0.
      - time_to_br: ball_release_time minus the frame time.
      - valgus_impulse: running sum of torque * lead time step, restarted for
        every session_trial.
      - shoulder_ang_velo / velocity_scaled_torque: magnitude of the shoulder
        velocity vector and its product with the torque.
      - torque_derivative: per-trial torque change divided by elapsed seconds.
      - phase_weighted_cumulative, cumulative_valgus,
        cumulative_valgus_phase_armcock_acc: cumulative calculations delegated
        to the corresponding helper functions.
      - critical_phase: True for 'Arm Cocking' / 'Arm Acceleration' rows.
      - peak_torque_marker: True on rows holding the trial's maximum torque.

    Per-group results are stitched back together by row label instead of
    through groupby.apply, because apply returns a wide DataFrame instead of a
    Series when the data holds a single group (e.g. one session_trial).

    Parameters:
      df (pd.DataFrame): Frame-level data with at least 'ongoing_timestamp',
          'pitch_phase', 'session_trial', 'time', 'elbow_moment_z',
          'ball_release_time' and the three 'shoulder_velo_*' columns.

    Returns:
      pd.DataFrame: The same DataFrame with the feature columns added.

    Raises:
      KeyError: If the 'pitch_phase' column is missing.
    """
    # ----------------- Step A: Validate Input -----------------
    logger.info("Validating ongoing_timestamp column before processing valgus features.")
    df['ongoing_timestamp'] = pd.to_datetime(df['ongoing_timestamp'], errors='raise')
    
    if 'pitch_phase' not in df.columns:
        raise KeyError("Missing required column 'pitch_phase' for phase-aware calculations")
    
    # ----------------- Step B: Compute Valgus Torque -----------------
    df['valgus_torque'] = -pd.to_numeric(df['elbow_moment_z'], errors='coerce').fillna(0)
    df['valgus_torque'] = df['valgus_torque'].mask(df['valgus_torque'] < 0, 0)
    
    df['ball_release_time'] = pd.to_numeric(df['ball_release_time'], errors='coerce')
    # Time to Ball Release: a plain row-wise difference, so no grouping is needed.
    df['time_to_br'] = df['ball_release_time'] - pd.to_numeric(df['time'], errors='coerce')

    trial_groups = df.groupby('session_trial')
    df['time_diff_lead'] = trial_groups['ongoing_timestamp'].shift(-1) - df['ongoing_timestamp']
    df['time_diff_lead'] = df['time_diff_lead'].dt.total_seconds()
    df['torque_diff_lead'] = trial_groups['valgus_torque'].shift(-1) - df['valgus_torque']
    
    # Valgus Impulse: running sum of torque * dt, restarted for each trial so
    # one pitch's impulse does not carry over into the next.
    df['valgus_impulse'] = (
        df['valgus_torque'] * df['time_diff_lead']
    ).groupby(df['session_trial']).cumsum()

    # ----------------- Step C: Validate and Convert Shoulder Velocity Columns -----------------
    vel_cols = ['shoulder_velo_x', 'shoulder_velo_y', 'shoulder_velo_z']
    logger.info("\nShoulder velocity column dtypes before conversion:")
    logger.info(df[vel_cols].dtypes)
    logger.info("\nSample shoulder velocity values before conversion:")
    logger.info(df[vel_cols].head(3))
    
    for col in vel_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    
    logger.info("\nPost-conversion null counts for velocity columns:")
    logger.info(df[vel_cols].isnull().sum())
    logger.info("\nPost-conversion dtypes for velocity columns:")
    logger.info(df[vel_cols].dtypes)
    
    # ----------------- Step D: Compute Angular Velocity and Velocity-Scaled Torque -----------------
    df['shoulder_ang_velo'] = np.sqrt(
        df['shoulder_velo_x']**2 + 
        df['shoulder_velo_y']**2 + 
        df['shoulder_velo_z']**2
    )
    df['velocity_scaled_torque'] = df['valgus_torque'] * df['shoulder_ang_velo']
    
    # ----------------- Step E: Compute Temporal Derivatives -----------------
    trial_groups = df.groupby('session_trial')
    df['torque_derivative'] = trial_groups['valgus_torque'].diff() / \
        trial_groups['ongoing_timestamp'].diff().dt.total_seconds()

    # ----------------- Step F: Apply Cumulative Calculations -----------------
    # Assignment aligns each result on the row index, so group order is irrelevant.
    df['phase_weighted_cumulative'] = _apply_per_group_series(
        df, ['session_trial', 'pitch_phase'], compute_phase_aware_cumulative
    )
    
    df['cumulative_valgus'] = _apply_per_group_series(
        df, 'session_trial', compute_cumulative_exposure
    )
    
    df['critical_phase'] = df['pitch_phase'].isin(['Arm Cocking', 'Arm Acceleration'])
    df['cumulative_valgus_phase_armcock_acc'] = _apply_per_group_series(
        df, 'session_trial', compute_armcock_acc_cumulative
    )
    
    # ----------------- Step G: Identify Peak Torque -----------------
    df['peak_torque_marker'] = df.groupby('session_trial')['valgus_torque'].transform(
        lambda x: x == x.max()
    )
    
    # Remove helper columns used in cumulative calculations
    df.drop(columns=['time_diff_lead', 'torque_diff_lead'], inplace=True)
    
    # # ----------------- Step H: Create Alias for Pitch Phase -----------------
    # # This ensures that downstream plotting (which expects 'pitch_phase_biomech') works properly.
    # if 'pitch_phase_biomech' not in df.columns and 'pitch_phase' in df.columns:
    #     df['pitch_phase_biomech'] = df['pitch_phase']
    
    logger.info("Finished processing valgus features. Sample of updated DataFrame:")
    logger.info(df.head())
    
    return df




def validate_valgus_calculation(df):
    """
    Validate that all required columns for valgus torque calculations are present and reasonable.

    Raises an error if key columns are missing, logs a null report for the key
    torque columns, logs an error if negative valgus torque values remain
    after masking, and logs a warning if more than 10% of the computed valgus
    torque values are zero.

    This is the single definition of the function: an earlier, less complete
    copy with the same name used to be shadowed by this one and was removed.

    Parameters:
      df (pd.DataFrame): Data that has already been through the valgus feature processing.

    Raises:
      ValueError: If any required column is missing.
    """
    required_columns = [
        'elbow_moment_x', 'elbow_moment_y', 'elbow_moment_z',
        'forearm_length', 'mass_kilograms', 'valgus_torque',
        'shoulder_velo_x', 'shoulder_velo_y', 'shoulder_velo_z'
    ]
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise ValueError(f"Missing valgus calculation columns: {missing}")
    
    # Log null counts for key columns.
    null_report = df[['elbow_moment_z', 'valgus_torque']].isnull().sum()
    logger.info(f"Null values report:\n{null_report}")
    
    # Check for negative torque values after masking.
    if (df['valgus_torque'] < 0).any():
        logger.error("Negative valgus_torque values detected after masking")
    
    zero_moment = df[df['valgus_torque'] == 0]
    if len(zero_moment) > 0.1 * len(df):
        logger.warning(">10% zero valgus torque values detected")


# ----------------------- Main Execution Block -----------------------

if __name__ == "__main__":
    # Script entry point: pulls a single-date and a date-range dataset, adds the
    # valgus features to both, logs diagnostics for the single-date frame,
    # saves the date-range frame to parquet and plots the cumulative valgus.
    import os  # local import: only needed here to create the output folder

    # logger.info("Requesting granular time series dataset with LAST_DAY filter...")
    # df_last_day = get_granular_time_series_data(filter_type='LAST_DAY')
    # logger.info("LAST_DAY filter complete. Displaying head:")
    # print(df_last_day.head(10))
    # logger.info("Displaying pitch phase data from last day:")
    # print("\nPitch phase values:")
    # print(df_last_day['pitch_phase'].value_counts().sort_index())

    # logger.info("Requesting granular time series dataset with LAST_5_DAYS filter...")
    # df_last_5 = get_granular_time_series_data(filter_type='LAST_5_DAYS')
    # logger.info("LAST_5_DAYS filter complete. Displaying head:")
    # print(df_last_5.head(10))
    
    # Dates are held in variables so the log messages always match the query.
    specific_date = '2025-02-14'
    range_start = '2025-02-13'
    range_end = '2025-02-20'

    logger.info(f"Requesting granular time series dataset with SPECIFIC_DATE filter ({specific_date})...")
    df_specific = get_granular_time_series_data(filter_type='SPECIFIC_DATE', specific_date=specific_date)
    logger.info("SPECIFIC_DATE filter complete. Displaying head:")
    print(df_specific.head(10))
    
    # New test for DATE_RANGE filter
    logger.info(f"Requesting granular time series dataset with DATE_RANGE filter ({range_start} to {range_end})...")
    df_date_range = get_granular_time_series_data(
        filter_type='DATE_RANGE', 
        start_date=range_start, 
        end_date=range_end
    )
    logger.info("DATE_RANGE filter complete. Displaying head:")
    print(df_date_range.head(10))

    # Process the dataset to compute valgus torque features
    logger.info("Processing valgus torque features...")
    df_specific = process_valgus_features(df_specific)
    df_date_range = process_valgus_features(df_date_range)
    # Validate the computed valgus features (the date-range frame is the one saved below)
    validate_valgus_calculation(df_specific)
    validate_valgus_calculation(df_date_range)
    
    logger.info("\nChecking ongoing_timestamp column:")
    logger.info(f"Number of unique timestamps: {df_specific['ongoing_timestamp'].nunique()}")
    logger.info("\nTimestamp range:")
    logger.info(f"Earliest: {df_specific['ongoing_timestamp'].min()}")
    logger.info(f"Latest: {df_specific['ongoing_timestamp'].max()}")
    logger.info("\nSample of timestamps:")
    print(df_specific['ongoing_timestamp'].head())
    
    logger.info("\nColumn information (including new valgus features):")
    for col in df_specific.columns:
        logger.info(f"\nColumn: {col}")
        logger.info(f"Data type: {df_specific[col].dtype}")
        logger.info(f"Number of unique values: {df_specific[col].nunique()}")
        logger.info(f"Number of null values: {df_specific[col].isnull().sum()}")
        if df_specific[col].dtype in ['object', 'category']:
            logger.info("Sample unique values:")
            print(df_specific[col].unique()[:5])
        elif df_specific[col].dtype in ['int64', 'float64']:
            logger.info("Numeric summary:")
            print(df_specific[col].describe())
    
    #-----------------DROPPING FORCE PLATE COLUMNS (if no force plate data)--------------------------------
    # Match only the force plate columns (lead_force_* / rear_force_*). A plain
    # 'force_' substring test would also drop the joint force columns such as
    # elbow_force_x and shoulder_upper_arm_force_x.
    force_plate_prefixes = ('lead_force_', 'rear_force_')
    force_plate_cols = [col for col in df_specific.columns if col.lower().startswith(force_plate_prefixes)]
    logger.info(f"\nDropping {len(force_plate_cols)} force plate columns:")
    for col in force_plate_cols:
        logger.info(f"- {col}")
    df_specific = df_specific.drop(columns=force_plate_cols)
    logger.info(f"Remaining columns after drop: {len(df_specific.columns)}")
    
    # Save the date range dataframe to parquet
    # NOTE(review): the checks in this block inspect df_specific while
    # df_date_range is what gets written - confirm which frame should be saved.
    output_path = '../../data/processed/ml_datasets/granular/granular_joint_details.parquet'
    logger.info(f"Saving date range data to {output_path}")
    # Log shape and basic checks before saving
    logger.info(f"\nDataFrame shape: {df_specific.shape}")
    logger.info(f"Shape of date range DataFrame being saved: {df_date_range.shape}")
    
    logger.info("\nBasic phase checks:")
    logger.info(f"Number of zeros in pitch_phase: {(df_specific['pitch_phase'] == 0).sum()}")
    logger.info(f"Number of nulls in pitch_phase: {df_specific['pitch_phase'].isnull().sum()}")
    
    logger.info("\nValue counts in pitch_phase:")
    logger.info(df_specific['pitch_phase'].value_counts())
    
    # Check for any null values across all columns
    null_cols = df_specific.columns[df_specific.isnull().any()].tolist()
    if len(null_cols) > 0:
        logger.info("\nColumns containing null values:")
        for col in null_cols:
            logger.info(f"- {col}: {df_specific[col].isnull().sum()} nulls")
    else:
        logger.info("\nNo null values found in any columns")
        
    logger.info("\nChecking for overlapping pitch phases, time gaps and phase durations per trial...")
    # Make sure the target folder exists so the save cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df_date_range.to_parquet(output_path)
    logger.info("Save complete")

    # Check null sums and column list
    logger.info("\nChecking null values across all columns:")
    null_sums = df_specific.isnull().sum()
    logger.info("\nColumns with null values:")
    for col, null_count in null_sums[null_sums > 0].items():
        logger.info(f"- {col}: {null_count} nulls")

    logger.info("\nFull column list:")
    for col in sorted(df_specific.columns):
        logger.info(f"- {col}")


    # Also show summary statistics per phase
    logger.info("\nSummary statistics of cumulative valgus per pitch phase:")
    phase_valgus_stats = df_specific.groupby('pitch_phase')['cumulative_valgus'].describe()
    logger.info(phase_valgus_stats)
    # Plot the progression of cumulative valgus over time
    logger.info("\nCreating plot of cumulative_valgus_phase_armcock_acc progression:")
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12,6))
    
    # Plot lines for each pitch phase
    for phase in df_specific['pitch_phase'].unique():
        phase_data = df_specific[df_specific['pitch_phase'] == phase]
        plt.plot(phase_data['ongoing_timestamp'], 
                phase_data['cumulative_valgus_phase_armcock_acc'],
                label=phase, alpha=0.7)
        
        # Add vertical lines for min and max timestamps
        min_time = phase_data['ongoing_timestamp'].min()
        max_time = phase_data['ongoing_timestamp'].max()
        plt.axvline(x=min_time, color='gray', linestyle='--', alpha=0.3)
        plt.axvline(x=max_time, color='gray', linestyle='--', alpha=0.3)
        
        # Add text labels for min/max lines
        y_pos = plt.ylim()[1]
        plt.text(min_time, y_pos, f'{phase} start', rotation=90, verticalalignment='top')
        plt.text(max_time, y_pos, f'{phase} end', rotation=90, verticalalignment='top')
    
    plt.xlabel('Time')
    plt.ylabel('Cumulative Valgus (Phase Arm Cock Acc)')
    plt.title('Progression of Cumulative Valgus by Pitch Phase')
    plt.legend()
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Show the plot
    plt.tight_layout()
    plt.show()
    
    # Also show summary statistics per phase
    logger.info("\nSummary statistics of cumulative_valgus_phase_armcock_acc per pitch phase:")
    phase_valgus_stats = df_date_range.groupby('pitch_phase')['cumulative_valgus_phase_armcock_acc'].describe()
    logger.info(phase_valgus_stats)


# Combine EMG and Biomech

Granular Datasets:
Resample/interpolate the biomechanics dataset so that it joins evenly onto the EMG data: no changes.
Bio dataset with the EMG rows filtered out: use the is_interpolated filter to keep only non-interpolated data when you want a bio dataset without the interpolated rows that were added for the EMG join (this filters out the EMG-only rows, so the data is back on the bio sampling frequency).
EMG dataset with phases added on: create a list of the interpolated columns so they can be told apart from the non-interpolated ones, and build an EMG dataset with the pitch phases added on (this is the simplest EMG dataset with phases: no fake data involved, and it feeds straight into the muscles dataset).

In [None]:
"""
Module: emg_biomech_inner_join.py

Goal:
    To add biomech data to the EMG data after interpolating the biomech dataset
    granular enough to provide a datapoint for each EMG metric. In addition to
    interpolating the biomech metrics, we add an identifier column indicating
    whether a row was interpolated (new) or originally present. We focus on the
    inner join workflow, retaining only EMG rows that receive complete biomech data.

Usage:
    Run this module as a script to load, process, and join the datasets.
"""

import pandas as pd
import numpy as np


def load_biomech_data(biomech_path, debug=False):
    """
    Read the biomechanics parquet file and expose its timestamps as 'datetime'.

    'ongoing_timestamp' is parsed with pd.to_datetime and copied into a new
    'datetime' column.

    Parameters:
        biomech_path (str): Path to the parquet file.
        debug (bool): If True, prints shape, datetime range and column dtypes;
            otherwise prints a one-line confirmation.

    Returns:
        pd.DataFrame: Loaded biomechanical data with the 'datetime' column added.

    Raises:
        ValueError: If the file has no 'ongoing_timestamp' column.
    """
    biomech = pd.read_parquet(biomech_path)
    if 'ongoing_timestamp' not in biomech.columns:
        raise ValueError("Biomech data missing 'ongoing_timestamp' column.")

    biomech['ongoing_timestamp'] = pd.to_datetime(biomech['ongoing_timestamp'])
    biomech['datetime'] = biomech['ongoing_timestamp']

    if not debug:
        print("Biomech data loaded.")
        return biomech

    earliest = biomech['datetime'].min()
    latest = biomech['datetime'].max()
    time_col_types = {
        col: str(dtype)
        for col, dtype in biomech[['ongoing_timestamp', 'datetime']].dtypes.items()
    }
    print(f"[DEBUG] load_biomech_data: DataFrame shape = {biomech.shape}")
    print(f"[DEBUG] Datetime range: {earliest} to {latest}")
    print(f"[DEBUG] New columns: {time_col_types}")
    return biomech


def load_emg_data(emg_path, debug=False):
    """
    Read the EMG parquet file and derive 'emg_time' and 'datetime' columns.

    The file is read with pd.read_parquet. Which source column is used depends
    on what is present:
      - both 'Date/Time' and 'Timestamp': both are parsed into '*_parsed'
        columns, and 'Timestamp' provides 'emg_time' and 'datetime';
      - only 'Date/Time': it provides 'emg_time' and 'datetime';
      - only 'Timestamp': it provides 'emg_time' and 'datetime'.

    Parameters:
        emg_path (str): Path to the parquet file.
        debug (bool): If True, prints shape and datetime range; otherwise
            prints a one-line confirmation.

    Returns:
        pd.DataFrame: Loaded EMG data with the derived columns added.

    Raises:
        ValueError: If neither 'Date/Time' nor 'Timestamp' is present.
    """
    emg = pd.read_parquet(emg_path)
    has_date_time = 'Date/Time' in emg.columns
    has_timestamp = 'Timestamp' in emg.columns

    if not (has_date_time or has_timestamp):
        raise ValueError("EMG data missing a datetime column ('Date/Time' or 'Timestamp').")

    if has_date_time and has_timestamp:
        # Keep parsed versions of both; 'Timestamp' is the one used for joining.
        emg['Date/Time_parsed'] = pd.to_datetime(emg['Date/Time'])
        emg['Timestamp_parsed'] = pd.to_datetime(emg['Timestamp'])
        emg['emg_time'] = emg['Timestamp']
        emg['datetime'] = emg['Timestamp_parsed']
    else:
        source_col = 'Date/Time' if has_date_time else 'Timestamp'
        emg['emg_time'] = emg[source_col]
        emg['datetime'] = pd.to_datetime(emg[source_col])

    if not debug:
        print("EMG data loaded.")
        return emg

    print(f"[DEBUG] load_emg_data: DataFrame shape = {emg.shape}")
    print(f"[DEBUG] Datetime range: {emg['datetime'].min()} to {emg['datetime'].max()}")
    return emg


def compute_time_steps(df, debug=False):
    """
    Order the rows by 'datetime' and add the gap to the previous row as 'time_step'.

    The input DataFrame is left untouched; a sorted copy with a fresh 0..n-1
    index is returned. The first row's 'time_step' is missing (NaT) because it
    has no predecessor.

    Parameters:
        df (pd.DataFrame): Input DataFrame with a 'datetime' column.
        debug (bool): If True, prints shape and a sample of the new column;
            otherwise prints a one-line confirmation.

    Returns:
        pd.DataFrame: Sorted DataFrame with an added 'time_step' column.
    """
    ordered = df.sort_values('datetime')
    ordered = ordered.reset_index(drop=True)
    ordered['time_step'] = ordered['datetime'].diff()

    if not debug:
        print("Time steps computed.")
        return ordered

    step_dtype = ordered['time_step'].dtype
    first_steps = ordered['time_step'].head(3).tolist()
    print(f"[DEBUG] compute_time_steps: DataFrame shape = {ordered.shape}")
    # Print a summary of the new column
    print(f"[DEBUG] 'time_step' sample dtypes: {step_dtype}, sample values: {first_steps}")
    return ordered


def sort_dataframes(biomech_df, emg_df, debug=False):
    """
    Return chronologically ordered versions of the biomech and EMG frames.

    Parameters:
        biomech_df (pd.DataFrame): Biomech DataFrame with a 'datetime' column.
        emg_df (pd.DataFrame): EMG DataFrame with a 'datetime' column.
        debug (bool): When True, print both shapes instead of the short status line.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: (biomech, EMG), each sorted by
        'datetime' and re-indexed from 0.
    """
    biomech_sorted, emg_sorted = (
        frame.sort_values('datetime').reset_index(drop=True)
        for frame in (biomech_df, emg_df)
    )

    status = (
        f"[DEBUG] sort_dataframes: Biomech shape = {biomech_sorted.shape}, EMG shape = {emg_sorted.shape}"
        if debug
        else "DataFrames sorted."
    )
    print(status)
    return biomech_sorted, emg_sorted


def filter_biomech_by_emg_range(biomech_df, emg_df, buffer_minutes=30, debug=False):
    """
    Keep only the biomech rows that fall inside a per-day EMG time window.

    For every calendar day present in the EMG data, the window is
    [first EMG datetime - buffer, last EMG datetime + buffer]. A biomech row
    is retained when it lies inside at least one such window.

    Each biomech row is returned at most once, even when the buffered windows
    of adjacent days overlap (e.g. a session close to midnight), and rows keep
    the order they have in `biomech_df`. An empty EMG frame yields an empty
    result instead of raising.

    Parameters:
        biomech_df (pd.DataFrame): Biomech DataFrame with a 'datetime' column.
        emg_df (pd.DataFrame): EMG DataFrame with a 'datetime' column.
        buffer_minutes (int): Minutes to buffer on each side.
        debug (bool): If True, prints detailed debug information.

    Returns:
        pd.DataFrame: Filtered biomech DataFrame with a fresh 0..n-1 index.
        A 'date' column, if present on the input, is not part of the result.
    """
    buffer = pd.Timedelta(minutes=buffer_minutes)
    biomech_times = biomech_df['datetime']

    # One boolean per biomech row; windows are OR-ed in so a row that sits in
    # two overlapping windows is still selected only once.
    keep = np.zeros(len(biomech_df), dtype=bool)

    if not emg_df.empty:
        emg_times = emg_df['datetime']
        emg_dates = emg_times.dt.date
        for day in emg_dates.unique():
            day_times = emg_times[emg_dates == day]
            day_min = day_times.min()
            day_max = day_times.max()
            day_min_buffered = day_min - buffer
            day_max_buffered = day_max + buffer
            if debug:
                print(f"[DEBUG] filter_biomech_by_emg_range: Day {day} - EMG range: {day_min} to {day_max}, Buffered: {day_min_buffered} to {day_max_buffered}")
            in_window = (biomech_times >= day_min_buffered) & (biomech_times <= day_max_buffered)
            keep |= in_window.to_numpy()

    filtered_biomech = (
        biomech_df[keep]
        .drop(columns=['date'], errors='ignore')  # output never carried a 'date' column
        .reset_index(drop=True)
    )
    if debug:
        print(f"[DEBUG] filter_biomech_by_emg_range: Filtered biomech shape = {filtered_biomech.shape}")
    else:
        print("Biomech data filtered by EMG range.")
    return filtered_biomech





def check_emg_integrity(original_emg, joined_df, join_description="join", debug=False):
    """
    Report whether a join kept exactly as many rows as the EMG input had.

    Only row counts are compared; nothing is returned or raised.

    Parameters:
        original_emg (pd.DataFrame): EMG DataFrame before the join.
        joined_df (pd.DataFrame): DataFrame produced by the join.
        join_description (str): Label used in the printed messages.
        debug (bool): When True, additionally print both row counts.
    """
    rows_before = original_emg.shape[0]
    rows_after = joined_df.shape[0]

    if debug:
        print(f"[DEBUG] {join_description}: Original EMG rows = {rows_before}, Joined rows = {rows_after}")

    if rows_before == rows_after:
        # Printed regardless of the debug flag
        print(f"[DEBUG] {join_description}: All EMG rows retained.")
        return

    print(f"[WARNING] EMG integrity check failed in {join_description}: original = {rows_before}, joined = {rows_after}.")




def check_interpolation_quality(resampled_df, original_biomech_df, emg_df, debug=False):
    """
    Summarise how trustworthy the interpolated biomech rows are.

    Three things are reported:
    1. For every interpolated row, the time distance to the nearest original
       biomech sample (min/max/mean/median/std in milliseconds).
    2. How many interpolated rows lie within 1, 5, 10, 50 and 100 ms of an
       original sample.
    3. With debug=True, mean/std of up to five numeric columns for original
       versus interpolated rows.
    Finally the share of distinct EMG timestamps that also appear in
    `resampled_df` is printed.

    Distances are computed in one vectorised pass on integer nanoseconds, so
    no precision is lost through a float-seconds round trip.

    Parameters:
        resampled_df (pd.DataFrame): Resampled biomech data with 'datetime'
            and 'is_interpolated' columns.
        original_biomech_df (pd.DataFrame): Original biomech data with a
            'datetime' column.
        emg_df (pd.DataFrame): EMG data with a 'datetime' column.
        debug (bool): If True, prints the per-column value comparison.

    Returns:
        dict: Keys 'distances_ms' (np.ndarray), 'coverage_percent' (float,
        0.0 when there are no EMG timestamps), 'interpolated_count' and
        'original_count'. An empty dict is returned when there is nothing to
        check (no resampled rows, no interpolated rows, or no original
        biomech timestamps).
    """
    if resampled_df.empty:
        print("[WARNING] Empty resampled DataFrame, cannot check interpolation quality")
        return {}

    print("\n=== Interpolation Quality Check ===")

    # Normalise the flag to a real boolean mask so that negating it below is
    # well defined even if the column was stored with object dtype.
    interp_mask = resampled_df['is_interpolated'].astype(bool)
    interpolated_rows = resampled_df[interp_mask]

    if interpolated_rows.empty:
        print("[INFO] No interpolated rows to check")
        return {}

    # Sorted original sample times as integer nanoseconds
    original_ns = np.sort(
        original_biomech_df['datetime'].dropna()
        .to_numpy(dtype='datetime64[ns]')
        .astype('int64')
    )
    if original_ns.size == 0:
        print("[WARNING] Empty original biomech DataFrame, cannot check interpolation quality")
        return {}

    interp_ns = interpolated_rows['datetime'].to_numpy(dtype='datetime64[ns]').astype('int64')

    # Insertion point of each interpolated time; the nearest original sample
    # is either the one just before or the one at that position. Clipping
    # makes both candidates collapse onto the first/last sample at the edges.
    insert_at = np.searchsorted(original_ns, interp_ns)
    last = original_ns.size - 1
    gap_before = np.abs(interp_ns - original_ns[np.clip(insert_at - 1, 0, last)])
    gap_after = np.abs(original_ns[np.clip(insert_at, 0, last)] - interp_ns)
    distances_ns = np.minimum(gap_before, gap_after)

    # Convert distances from nanoseconds to milliseconds
    distances_ms = distances_ns / 1_000_000

    print(f"[INFO] Distance statistics (ms) to nearest original biomech point:")
    print(f"  Min: {distances_ms.min():.3f}, Max: {distances_ms.max():.3f}")
    print(f"  Mean: {distances_ms.mean():.3f}, Median: {np.median(distances_ms):.3f}")
    print(f"  Std Dev: {distances_ms.std():.3f}")

    # Check proportions of points within several tolerance thresholds
    tolerances = [1, 5, 10, 50, 100]  # in milliseconds
    for tol in tolerances:
        within_tol = (distances_ms <= tol).sum()
        percent = (within_tol / len(distances_ms)) * 100
        print(f"  {within_tol} points ({percent:.2f}%) within {tol}ms of an original biomech point")

    # Compare numeric columns between original and interpolated rows (debug only)
    numeric_cols = [
        col for col in resampled_df.select_dtypes(include=['number']).columns
        if col != 'is_interpolated'
    ]

    if debug and numeric_cols:
        print("[DEBUG] Value comparison for selected numeric columns:")
        original_rows = resampled_df[~interp_mask]
        for col in numeric_cols[:5]:  # Limit to first 5 columns for brevity
            original_values = original_rows[col]
            interpolated_values = interpolated_rows[col]

            if original_values.empty or interpolated_values.empty:
                continue

            print(f"  {col}:")
            print(f"    Original - Mean: {original_values.mean():.4f}, Std: {original_values.std():.4f}")
            print(f"    Interpolated - Mean: {interpolated_values.mean():.4f}, Std: {interpolated_values.std():.4f}")

    # Check how many EMG timestamps have a corresponding biomech row
    emg_timestamps = set(emg_df['datetime'])
    resampled_timestamps = set(resampled_df['datetime'])
    common_timestamps = emg_timestamps.intersection(resampled_timestamps)
    # Guard against an EMG frame without rows
    coverage = (len(common_timestamps) / len(emg_timestamps) * 100) if emg_timestamps else 0.0
    print(f"[INFO] {len(common_timestamps)} of {len(emg_timestamps)} EMG timestamps ({coverage:.2f}%) have corresponding biomech data")

    return {
        'distances_ms': distances_ms,
        'coverage_percent': coverage,
        'interpolated_count': len(interpolated_rows),
        'original_count': len(resampled_df) - len(interpolated_rows)
    }


def save_dataframe(df, out_path, step_name, debug=False):
    """
    Write a DataFrame to parquet (without its index) and log the outcome.

    Parameters:
        df (pd.DataFrame): DataFrame to persist.
        out_path (str): Destination file path.
        step_name (str): Name of the processing step, used in the log line.
        debug (bool): When True, the log line includes the output path.
    """
    df.to_parquet(out_path, index=False)

    message = (
        f"[DEBUG] {step_name} saved to: {out_path}"
        if debug
        else f"{step_name} completed and saved."
    )
    print(message)


def deep_analysis(inner_df, debug=False):
    """
    Print a short health report for the inner-join result.

    The report covers the frame's shape, per-column null counts, the number
    of rows whose 'datetime' repeats an earlier row, and (when a
    'time_difference' column exists) descriptive statistics of it in seconds.

    Parameters:
        inner_df (pd.DataFrame): DataFrame resulting from the inner join.
        debug (bool): Accepted for signature parity with the other helpers;
            it does not change the output.
    """
    print("\n=== Deep Analysis ===")
    print(f"Inner join shape: {inner_df.shape}")

    # Null report
    nulls_per_column = inner_df.isnull().sum()
    has_nulls = nulls_per_column.sum() > 0
    if not has_nulls:
        print("No null values found in inner join.")
    else:
        print("[WARNING] Null values found in inner join:")
        print(nulls_per_column[nulls_per_column > 0])

    # Duplicate-timestamp report
    repeated = inner_df.duplicated(subset=["datetime"]).sum()
    if not repeated:
        print("No duplicate datetime rows in inner join.")
    else:
        print(f"[WARNING] Inner join has {repeated} duplicate datetime rows.")

    # Time-difference statistics (header is printed unconditionally)
    print("[INFO] Summary statistics for time differences (seconds):")
    if "time_difference" not in inner_df.columns:
        return
    print(inner_df["time_difference"].dt.total_seconds().describe())


def analyze_temporal_alignment(biomech_df, emg_df, debug=False):
    """
    Measure how close each EMG sample is to its nearest biomech sample.

    The analysis runs per calendar day and only for days present in both
    frames. Within a day, every EMG row is assigned the absolute time distance
    to the closest biomech row of that same day, and the distance is counted
    in exactly one of the bins <=1ms, <=5ms, <=10ms, <=50ms, <=100ms, <=500ms,
    <=1s, <=5s or 'other'. EMG rows in the <=1ms bin are also counted as
    'viable_emg'.

    The nearest-neighbour search and the binning are vectorised on integer
    nanoseconds, so the cost does not depend on a Python-level loop over
    EMG rows. The input frames are not modified.

    Parameters:
        biomech_df (pd.DataFrame): Biomech DataFrame with a 'datetime' column.
        emg_df (pd.DataFrame): EMG DataFrame with a 'datetime' column.
        debug (bool): If True, prints the common days and per-day distance stats.

    Returns:
        dict: {'total_emg': int, 'viable_emg': int,
               'distance_stats': {bin label: int}} where 'total_emg' counts
        EMG rows on the common days only.
    """
    print("\n=== Temporal Alignment Analysis ===")

    # Calendar day of every row, kept as standalone arrays so the (possibly
    # large) input frames do not have to be copied.
    biomech_day_of_row = biomech_df['datetime'].dt.date.to_numpy()
    emg_day_of_row = emg_df['datetime'].dt.date.to_numpy()
    biomech_ns_all = biomech_df['datetime'].to_numpy(dtype='datetime64[ns]').astype('int64')
    emg_ns_all = emg_df['datetime'].to_numpy(dtype='datetime64[ns]').astype('int64')

    # Identify days present in each dataset and their intersection
    biomech_days = set(biomech_day_of_row)
    emg_days = set(emg_day_of_row)
    common_days = biomech_days.intersection(emg_days)

    print(f"[INFO] Days with biomech data: {len(biomech_days)}")
    print(f"[INFO] Days with EMG data: {len(emg_days)}")
    print(f"[INFO] Days with both data types: {len(common_days)}")
    if debug:
        print(f"[DEBUG] Common days: {sorted(common_days)}")

    # Upper edge (inclusive) of every bin except the open-ended 'other' bin
    bin_labels = ('1ms', '5ms', '10ms', '50ms', '100ms', '500ms', '1s', '5s', 'other')
    bin_edges_ns = np.array([1, 5, 10, 50, 100, 500, 1000, 5000], dtype='int64') * 1_000_000
    bin_totals = np.zeros(len(bin_labels), dtype='int64')
    total_emg = 0

    # Process each common day
    for day in sorted(common_days):
        day_biomech_ns = np.sort(biomech_ns_all[biomech_day_of_row == day])
        day_emg_ns = emg_ns_all[emg_day_of_row == day]
        total_emg += len(day_emg_ns)

        # Defensive: a common day always has biomech rows
        if len(day_biomech_ns) == 0:
            continue

        # The nearest biomech sample is either the one just before the
        # insertion point or the one at it; clipping collapses both candidates
        # onto the first/last sample for EMG rows outside the biomech range.
        insert_at = np.searchsorted(day_biomech_ns, day_emg_ns)
        last = len(day_biomech_ns) - 1
        gap_before = np.abs(day_emg_ns - day_biomech_ns[np.clip(insert_at - 1, 0, last)])
        gap_after = np.abs(day_biomech_ns[np.clip(insert_at, 0, last)] - day_emg_ns)
        distances_ns = np.minimum(gap_before, gap_after)

        # side='left' returns the first edge >= distance, i.e. the "<= edge" bin
        bin_index = np.searchsorted(bin_edges_ns, distances_ns, side='left')
        bin_totals += np.bincount(bin_index, minlength=len(bin_labels))

        if debug and len(distances_ns):
            distances_ms = distances_ns / 1_000_000
            print(f"[DEBUG] Day {day}: {len(distances_ms)} EMG points analyzed")
            print(f"[DEBUG] Day {day}: Distance stats (ms) - Min: {distances_ms.min():.2f}, Max: {distances_ms.max():.2f}, Mean: {distances_ms.mean():.2f}")

    counts = {
        'total_emg': total_emg,
        'viable_emg': int(bin_totals[0]),
        'distance_stats': {label: int(n) for label, n in zip(bin_labels, bin_totals)},
    }

    print("\n[INFO] Distance Distribution Summary:")
    for label, count in counts['distance_stats'].items():
        pct = (count / counts['total_emg'] * 100) if counts['total_emg'] > 0 else 0
        print(f"  {label}: {count} points ({pct:.2f}%)")
    print(f"\n[INFO] Total EMG points: {counts['total_emg']}")
    viable_pct = (counts['viable_emg'] / counts['total_emg'] * 100) if counts['total_emg'] > 0 else 0
    print(f"[INFO] EMG points within 1ms of biomech: {counts['viable_emg']} ({viable_pct:.2f}%)")

    return counts


def selective_resample_biomech_by_emg(biomech_df, emg_df, tolerance_ms=1, 
                                     max_analysis_distance_ms=5000,
                                     categorical_cols=None,
                                     categorical_numeric_cols=None,
                                     debug=False):
    """
    Build biomech rows aligned to EMG timestamps, one calendar day at a time.

    Only days present in both frames are processed. For each unique EMG
    timestamp of such a day exactly one of the following happens:
      - Exact match: the biomech row(s) with the same 'datetime' are copied
        and flagged is_interpolated=False.
      - No neighbour: if that day has no biomech row strictly before or no
        biomech row strictly after the timestamp, it is skipped and counted
        as "no suitable biomech data".
      - Outside tolerance: if BOTH the preceding and the following biomech
        row are further away than `tolerance_ms`, it is skipped.
      - Interpolated: otherwise a new row is created from the preceding
        biomech row, stamped with the EMG timestamp and flagged
        is_interpolated=True. Columns listed in `categorical_cols` keep the
        preceding row's value; columns in `categorical_numeric_cols` take the
        value of the nearer neighbour; every other column is linearly
        interpolated between the two neighbours when both values convert to
        numbers, and falls back to the nearer neighbour's value otherwise.

    NOTE(review): `max_analysis_distance_ms` is converted to nanoseconds and
    echoed in the debug output, but no candidate pre-filtering based on it
    happens in this function.
    
    Parameters:
        biomech_df (pd.DataFrame): Filtered biomech DataFrame.
        emg_df (pd.DataFrame): EMG DataFrame.
        tolerance_ms (int): Tolerance in milliseconds for interpolation.
        max_analysis_distance_ms (int): Maximum distance (in ms) to consider for analysis
            (currently only reported in debug output, see note above).
        categorical_cols (list, optional): Columns to treat as categorical.
        categorical_numeric_cols (list, optional): Numeric columns to treat categorically.
        debug (bool): If True, prints detailed debug information.
        
    Returns:
        pd.DataFrame: Selectively resampled biomech DataFrame, including the
        helper 'date' column and the 'is_interpolated' flag. An empty
        DataFrame (no columns) is returned when no row could be produced.
    """
    # Set default categorical columns if not provided
    if categorical_cols is None:
        categorical_cols = [
            'athlete_name', 'athlete_dob', 'athlete_traq',
            'athlete_level', 'lab', 'pitch_type', 'handedness', 'session_date',
            'height_meters', 'mass_kilograms',
            'lab', 'session', 'trial', 'pitch_type', 'handedness',
            'session_trial', 'pitch_speed_mph', 'date', 'time_step',
            'pitch_phase', 'session_date', 'time', 'session_time'
        ]
    # NOTE(review): with the defaults, 'session' and 'trial' appear in both
    # lists; the categorical_cols check runs first in the column loop below,
    # so for them the nearest-neighbour branch is never reached.
    if categorical_numeric_cols is None:
        categorical_numeric_cols = ['session', 'trial']

    # Work on copies and add date columns for daily processing
    biomech_df = biomech_df.copy()
    emg_df = emg_df.copy()
    biomech_df['date'] = biomech_df['datetime'].dt.date
    emg_df['date'] = emg_df['datetime'].dt.date
    
    # Calculate analysis buffer (in ns) and interpolation buffer (in ns)
    # (analysis_buffer_ns is only used in the debug print below)
    analysis_buffer_ns = max_analysis_distance_ms * 1_000_000
    interp_buffer_ns = tolerance_ms * 1_000_000
    
    if debug:
        print(f"[DEBUG] selective_resample: Starting with {len(biomech_df)} biomech rows and {len(emg_df)} EMG rows")
        print(f"[DEBUG] Biomech date range: {biomech_df['date'].min()} to {biomech_df['date'].max()}")
        print(f"[DEBUG] EMG date range: {emg_df['date'].min()} to {emg_df['date'].max()}")
        print(f"[DEBUG] Using interpolation tolerance of {tolerance_ms}ms ({interp_buffer_ns}ns)")
        print(f"[DEBUG] Using analysis distance of {max_analysis_distance_ms}ms ({analysis_buffer_ns}ns)")
    
    # Initialize counters (totals across all days)
    interpolated_count = 0
    exact_match_count = 0
    outside_tolerance_count = 0
    no_biomech_data_count = 0
    resampled_list = []
    
    # Process only the common days
    common_days = set(biomech_df['date']).intersection(set(emg_df['date']))
    
    for day in sorted(common_days):
        emg_day = emg_df[emg_df['date'] == day]
        biomech_day = biomech_df[biomech_df['date'] == day]
        # Per-day counters, folded into the totals at the end of the iteration
        day_interpolated = 0
        day_exact_match = 0
        day_outside_tolerance = 0
        day_no_biomech_data = 0
        
        # Defensive: `day` comes from the intersection, so biomech rows exist
        if biomech_day.empty:
            if debug:
                print(f"[DEBUG] selective_resample: No biomech data for day {day}. Skipping.")
            no_biomech_data_count += len(emg_day)
            continue
        
        if debug:
            print(f"[DEBUG] Day {day}: Processing {len(emg_day)} EMG rows and {len(biomech_day)} biomech rows")
        
        biomech_day_sorted = biomech_day.sort_values('datetime')
        # Duplicate EMG timestamps are collapsed: each is processed once
        emg_timestamps = emg_day['datetime'].sort_values().unique()
        
        if debug:
            print(f"[DEBUG] Day {day}: Found {len(emg_timestamps)} unique EMG timestamps")
        
        day_results = []
        # Set of biomech timestamps for the exact-match test below
        original_timestamps = set(biomech_day_sorted['datetime'])
        processed_count = 0
        
        for emg_ts in emg_timestamps:
            processed_count += 1
            if debug and processed_count % 10000 == 0:
                print(f"[DEBUG] Day {day}: Processed {processed_count}/{len(emg_timestamps)} EMG timestamps")
            
            # Exact match: reuse the biomech row(s) unchanged.
            # NOTE(review): emg_ts comes from Series.unique() while the set was
            # built by iterating a Series; this test relies on both kinds of
            # value hashing/comparing equal — confirm for the pandas/numpy
            # versions in use.
            if emg_ts in original_timestamps:
                exact_row = biomech_day_sorted[biomech_day_sorted['datetime'] == emg_ts].copy()
                exact_row['is_interpolated'] = False
                day_results.append(exact_row)
                day_exact_match += 1
                continue
            
            # A neighbour strictly before AND strictly after is required
            before_mask = biomech_day_sorted['datetime'] < emg_ts
            after_mask = biomech_day_sorted['datetime'] > emg_ts
            if not before_mask.any() or not after_mask.any():
                day_no_biomech_data += 1
                continue
            
            # Positional indices into the sorted day frame: last row before
            # and first row after the EMG timestamp
            before_idx = before_mask.values.nonzero()[0][-1]
            after_idx = after_mask.values.nonzero()[0][0]
            before_ts = biomech_day_sorted.iloc[before_idx]['datetime']
            after_ts = biomech_day_sorted.iloc[after_idx]['datetime']
            
            # Check if either gap is within the strict tolerance
            # (the timestamp is skipped only when BOTH gaps exceed it)
            time_diff_before_ns = abs(int((emg_ts - before_ts).total_seconds() * 1e9))
            time_diff_after_ns = abs(int((after_ts - emg_ts).total_seconds() * 1e9))
            if time_diff_before_ns > interp_buffer_ns and time_diff_after_ns > interp_buffer_ns:
                day_outside_tolerance += 1
                continue
            
            # Start from the preceding row so every column that is skipped
            # below (categorical ones, 'date') keeps that row's value
            before_row = biomech_day_sorted.iloc[before_idx].copy()
            after_row = biomech_day_sorted.iloc[after_idx].copy()
            new_row = before_row.copy()
            new_row['datetime'] = emg_ts
            new_row['is_interpolated'] = True
            
            # Time span between the two neighbours, in nanoseconds
            total_time_diff_ns = int((after_ts - before_ts).total_seconds() * 1e9)
            for col in biomech_day_sorted.columns:
                if col in ['datetime', 'is_interpolated', 'date'] or col in categorical_cols:
                    continue
                if col in categorical_numeric_cols:
                    # Nearest neighbour wins; ties go to the preceding row
                    new_row[col] = before_row[col] if time_diff_before_ns <= time_diff_after_ns else after_row[col]
                else:
                    try:
                        before_val = pd.to_numeric(before_row[col])
                        after_val = pd.to_numeric(after_row[col])
                        # Linear interpolation; `position` is the fraction of
                        # the before→after span that has elapsed at emg_ts.
                        # If the span is not positive the preceding row's
                        # value is left in place.
                        if total_time_diff_ns > 0:
                            position = time_diff_before_ns / total_time_diff_ns
                            new_row[col] = before_val + position * (after_val - before_val)
                    except (ValueError, TypeError):
                        # Value could not be converted to a number: fall back
                        # to the nearer neighbour's value
                        new_row[col] = before_row[col] if time_diff_before_ns <= time_diff_after_ns else after_row[col]
            day_results.append(pd.DataFrame([new_row]))
            day_interpolated += 1
        
        if day_results:
            day_df = pd.concat(day_results, ignore_index=True)
            if debug:
                print(f"[DEBUG] Day {day}: Created {len(day_df)} biomech rows (exact: {day_exact_match}, interp: {day_interpolated})")
                print(f"[DEBUG] Day {day}: Skipped {day_outside_tolerance} (outside tolerance), {day_no_biomech_data} (no suitable biomech data)")
            resampled_list.append(day_df)
        else:
            if debug:
                print(f"[DEBUG] Day {day}: No valid biomech data created")
        interpolated_count += day_interpolated
        exact_match_count += day_exact_match
        outside_tolerance_count += day_outside_tolerance
        no_biomech_data_count += day_no_biomech_data
    
    if resampled_list:
        resampled_biomech = pd.concat(resampled_list, ignore_index=True)
        # Counters are per unique EMG timestamp, not per EMG row
        total_processed = interpolated_count + exact_match_count + outside_tolerance_count + no_biomech_data_count
        total_retained = interpolated_count + exact_match_count
        retention_rate = (total_retained / total_processed * 100) if total_processed > 0 else 0
        
        print(f"\n=== Resampling Summary ===")
        print(f"[INFO] EMG timestamps processed: {total_processed}")
        print(f"[INFO] EMG timestamps retained: {total_retained} ({retention_rate:.2f}%)")
        print(f"[INFO] Exact matches: {exact_match_count}")
        print(f"[INFO] Interpolated points: {interpolated_count}")
        print(f"[INFO] Skipped (outside tolerance): {outside_tolerance_count}")
        print(f"[INFO] Skipped (no suitable biomech data): {no_biomech_data_count}")
        if debug:
            mem_usage = resampled_biomech.memory_usage(deep=True).sum() / (1024**2)
            print(f"[DEBUG] Final resampled biomech shape = {resampled_biomech.shape}")
            print(f"[DEBUG] Resampled biomech memory usage: {mem_usage:.2f} MB")
    else:
        # Nothing matched or could be interpolated on any day
        resampled_biomech = pd.DataFrame()
        print("[WARNING] No biomech data could be resampled for any EMG timestamps within tolerance")
    
    return resampled_biomech


def strict_inner_join(emg_df, resampled_biomech, tolerance_ms=1, 
                     expected_exclusions=None,
                     debug=False):
    """
    Join biomech data onto EMG rows with a strict nearest-time tolerance.

    The EMG frame is the base. It is processed in chunks; each chunk is
    matched to the nearest biomech row via `pd.merge_asof` (direction
    "nearest") and an EMG row is kept only if
      - every expected biomech column is non-null after the merge, and
      - the absolute difference between the EMG 'datetime' and the matched
        biomech timestamp does not exceed the tolerance.
    The result carries a 'time_difference' column with that absolute
    difference.

    Column-name handling:
      - If `resampled_biomech` has no 'biomech_datetime' column, one is
        created from its 'datetime' so the matched biomech time survives the
        merge (the merge key itself keeps the EMG timestamp).
      - Columns present in both frames (other than 'datetime') receive the
        '_x' (EMG) / '_y' (biomech) suffixes; the null check follows the
        biomech columns to their '_y' names.

    Parameters:
        emg_df (pd.DataFrame): EMG DataFrame with a 'datetime' column.
        resampled_biomech (pd.DataFrame): Resampled biomech DataFrame with a
            'datetime' column.
        tolerance_ms (int): Tolerance in milliseconds for joining.
        expected_exclusions (list, optional): Biomech columns to leave out of
            the null check. Defaults to ["datetime", "biomech_datetime"].
        debug (bool): If True, prints detailed debug information.

    Returns:
        pd.DataFrame: Validated inner join DataFrame, or an empty DataFrame
        when there is no biomech data or no EMG row could be matched.
    """
    if expected_exclusions is None:
        expected_exclusions = ["datetime", "biomech_datetime"]

    tolerance = pd.Timedelta(f"{tolerance_ms}ms")
    
    if debug:
        print(f"[DEBUG] strict_inner_join: EMG shape = {emg_df.shape}, Biomech shape = {resampled_biomech.shape}")
        print(f"[DEBUG] strict_inner_join: Using tolerance = {tolerance}")
    
    if resampled_biomech.empty:
        print("[WARNING] No biomech data available for joining. Returning empty DataFrame.")
        return pd.DataFrame()
    
    CHUNK_SIZE = 10000
    chunks = []
    total_emg_rows = len(emg_df)
    total_joined_rows = 0
    total_no_match = 0
    total_missing_data = 0
    
    emg_sorted = emg_df.sort_values('datetime')
    biomech_sorted = resampled_biomech.sort_values('datetime')

    # Keep the matched biomech timestamp available after the merge
    if "biomech_datetime" not in biomech_sorted.columns:
        biomech_sorted = biomech_sorted.assign(biomech_datetime=biomech_sorted['datetime'])

    # merge_asof renames columns that exist on both sides; translate biomech
    # column names to the names they carry in the joined frame.
    shared_cols = (set(emg_sorted.columns) & set(biomech_sorted.columns)) - {"datetime"}

    def _joined_name(col):
        return f"{col}_y" if col in shared_cols else col

    expected_cols = [
        _joined_name(col) for col in biomech_sorted.columns
        if col not in expected_exclusions
    ]
    biomech_time_col = _joined_name("biomech_datetime")
    n_chunks = (len(emg_sorted) + CHUNK_SIZE - 1) // CHUNK_SIZE
    
    for i in range(0, len(emg_sorted), CHUNK_SIZE):
        emg_chunk = emg_sorted.iloc[i:i+CHUNK_SIZE]
        chunk_size = len(emg_chunk)
        chunk_no = i // CHUNK_SIZE + 1
        if debug and i > 0:
            print(f"[DEBUG] strict_inner_join: Processing chunk {chunk_no} of {n_chunks} ({chunk_size} rows)")
        
        # Only biomech rows that could match something in this chunk
        min_ts = emg_chunk['datetime'].min() - tolerance
        max_ts = emg_chunk['datetime'].max() + tolerance
        biomech_subset = biomech_sorted[
            (biomech_sorted['datetime'] >= min_ts) & 
            (biomech_sorted['datetime'] <= max_ts)
        ]
        
        if biomech_subset.empty:
            if debug:
                print(f"[DEBUG] strict_inner_join: No biomech data for EMG chunk {chunk_no}")
            total_no_match += chunk_size
            continue
        
        chunk_joined = pd.merge_asof(
            emg_chunk,
            biomech_subset,
            on="datetime",
            direction="nearest",
            tolerance=tolerance,
            suffixes=("_x", "_y")
        )
        # EMG rows without a biomech partner have nulls in the biomech columns
        valid_chunk = chunk_joined.dropna(subset=expected_cols)
        dropped_missing = len(chunk_joined) - len(valid_chunk)
        total_missing_data += dropped_missing
        
        if debug and dropped_missing > 0:
            print(f"[DEBUG] strict_inner_join: Dropped {dropped_missing} rows in chunk {chunk_no} due to missing biomech data")
        
        if not valid_chunk.empty:
            # assign() builds a new frame rather than writing into the
            # dropna() result
            valid_chunk = valid_chunk.assign(
                time_difference=(valid_chunk["datetime"] - valid_chunk[biomech_time_col]).abs()
            )
            # Only retain rows strictly within the tolerance
            valid_within_tolerance = valid_chunk[valid_chunk["time_difference"] <= tolerance]
            dropped_tolerance = len(valid_chunk) - len(valid_within_tolerance)
            total_no_match += dropped_tolerance
            if debug and dropped_tolerance > 0:
                print(f"[DEBUG] strict_inner_join: Dropped {dropped_tolerance} rows exceeding tolerance in chunk {chunk_no}")
            if not valid_within_tolerance.empty:
                chunks.append(valid_within_tolerance)
                total_joined_rows += len(valid_within_tolerance)
    
    if chunks:
        valid_join = pd.concat(chunks, ignore_index=True)
        if debug:
            print(f"[DEBUG] strict_inner_join: Final join shape = {valid_join.shape}")
            print(f"[DEBUG] strict_inner_join: Time difference statistics (seconds):")
            print(valid_join["time_difference"].dt.total_seconds().describe())
        
        retention_rate = (total_joined_rows / total_emg_rows) * 100
        print(f"\n=== Join Summary ===")
        print(f"[INFO] Total EMG rows: {total_emg_rows}")
        print(f"[INFO] Successfully joined rows: {total_joined_rows} ({retention_rate:.2f}%)")
        print(f"[INFO] Rows with no matching biomech data: {total_no_match}")
        print(f"[INFO] Rows with missing required biomech columns: {total_missing_data}")
        return valid_join
    else:
        print("[WARNING] No valid joined rows could be created with the specified strict tolerance")
        return pd.DataFrame()



if __name__ == "__main__":
    biomech_path=None
    emg_path=None
    output_dir=None
    debug=True
    """
    Main function to run the EMG-Biomech inner join workflow with strict tolerance.
    
    Parameters:
        biomech_path (str): Path to biomech data file.
        emg_path (str): Path to EMG data file.
        output_dir (str): Directory for output files.
        debug (bool): If True, prints detailed debug information.
    """
    # Use default paths if not provided
    if biomech_path is None:
        biomech_path = "../../data/processed/ml_datasets/granular/granular_joint_details.parquet"
    if emg_path is None:
        emg_path = "../../data/processed/combined_emg_data.parquet"
    if output_dir is None:
        output_dir = "../../data/processed/ml_datasets"
    
    # Configurable parameters – strict tolerance version
    tolerance_ms = 1         # Strict tolerance in milliseconds for joining
    buffer_minutes = 30      # Buffer for filtering biomech data around EMG data ranges
    
    print(f"[INFO] === Configuration ===")
    print(f"[INFO] Join tolerance: {tolerance_ms}ms (strict)")
    print(f"[INFO] Buffer around EMG data: {buffer_minutes} minutes")

    # ---------------- Load Data ----------------
    biomech_df = load_biomech_data(biomech_path, debug=debug)
    emg_df = load_emg_data(emg_path, debug=debug)
    
    if debug:
        print("\n[DEBUG] Biomech DataFrame columns and null counts:")
        for col in biomech_df.columns:
            print(f" - {col}: {biomech_df[col].isnull().sum()} nulls")
        print(f"[DEBUG] Total number of columns: {len(biomech_df.columns)}")
        print(f"[DEBUG] Total number of null values: {biomech_df.isnull().sum().sum()}")
    
    print("\n[INFO] Performing initial biomech dataset checks...")
    if biomech_df.isnull().sum().sum() > 0:
        print("[INFO] Dropping rows with null values from biomech data...")
        biomech_df = biomech_df.dropna()
    print(f"[INFO] Biomech data now has {len(biomech_df)} rows.")

    # Drop unnecessary columns from EMG data
    emg_columns_to_drop = [
        'EMG 1 (mV) - FDS', 'ACC X (G) - FDS', 'ACC Y (G) - FDS', 
        'ACC Z (G) - FDS','GYRO X (deg/s) - FDS','GYRO Y (deg/s) - FDS', 
        'GYRO Z (deg/s) - FDS','ACC X (G) - FCU','ACC Y (G) - FCU',
        'ACC Z (G) - FCU','GYRO X (deg/s) - FCU','GYRO Y (deg/s) - FCU',
        'GYRO Z (deg/s) - FCU','ACC X (G) - FCR','ACC Y (G) - FCR',
        'ACC Z (G) - FCR','GYRO X (deg/s) - FCR','GYRO Y (deg/s) - FCR',
        'GYRO Z (deg/s) - FCR'
    ]
    existing_cols_to_drop = [col for col in emg_columns_to_drop if col in emg_df.columns]
    if existing_cols_to_drop:
        print(f"[INFO] Dropping {len(existing_cols_to_drop)} unnecessary columns from EMG data")
        emg_df = emg_df.drop(columns=existing_cols_to_drop)
    
    total_nulls = emg_df.isnull().sum().sum()
    if total_nulls > 0:
        print(f"[INFO] Dropping {total_nulls} null values from EMG data...")
        emg_df = emg_df.dropna()
    print(f"[INFO] EMG data now has {len(emg_df)} rows.")

    # ---------------- Compute Time Steps & Sort ----------------
    biomech_df = compute_time_steps(biomech_df, debug=debug)
    emg_df = compute_time_steps(emg_df, debug=debug)
    biomech_df, emg_df = sort_dataframes(biomech_df, emg_df, debug=debug)

    # ---------------- Filter Biomech Data by EMG Date Ranges ----------------
    filtered_biomech = filter_biomech_by_emg_range(biomech_df, emg_df, buffer_minutes=buffer_minutes, debug=debug)
    
    # ---------------- Analyze Temporal Alignment ----------------
    print("\n[INFO] Analyzing temporal alignment between EMG and biomech data...")
    alignment_stats = analyze_temporal_alignment(filtered_biomech, emg_df, debug=debug)
    if alignment_stats['viable_emg'] == 0:
        print("[ERROR] No EMG points found within the strict tolerance of biomech data.")
        print("[INFO] Consider analyzing the temporal distribution of your data and trying again.")

    
    # ---------------- Selective Resampling ----------------
    print("\n[INFO] Performing selective resampling with strict tolerance...")
    resampled_biomech = selective_resample_biomech_by_emg(
        filtered_biomech, 
        emg_df, 
        tolerance_ms=tolerance_ms,
        debug=debug
    )
    
    if resampled_biomech.empty:
        print("[ERROR] Resampling produced no usable data. No viable EMG-biomech pairs within tolerance.")

    
    resampled_biomech["biomech_datetime"] = resampled_biomech["datetime"].copy()
    resampled_biomech = resampled_biomech.rename(
        columns={col: f"{col}_biomech" for col in resampled_biomech.columns 
                 if col not in ["datetime", "biomech_datetime", "is_interpolated"]}
    )

    # ---------------- Strict Inner Join ----------------
    print("\n[INFO] Performing strict inner join...")
    joined_df = strict_inner_join(
        emg_df, 
        resampled_biomech, 
        tolerance_ms=tolerance_ms,
        debug=debug
    )
    
    if joined_df.empty:
        print("[ERROR] Join produced no valid rows with the strict tolerance.")

    
    final_join_path = f"{output_dir}/final_inner_join_emg_biomech_data.parquet"
    save_dataframe(joined_df, final_join_path, "Final strict inner join dataset", debug=debug)
    
    sample_rows = joined_df.head(5)
    sample_path = f"{output_dir}/sample_emg_biomech_data.parquet"
    save_dataframe(sample_rows, sample_path, "Sample rows from final joined dataset", debug=debug)
    
    deep_analysis(joined_df, debug=debug)
    
    counts = resampled_biomech["is_interpolated"].value_counts()
    interp_count = counts.get(True, 0)
    original_count = counts.get(False, 0)
    print(f"[INFO] Interpolated datapoints: {interp_count}; Original datapoints: {original_count}")
    print("\n[INFO] Process completed successfully.")
    # return joined_df
 





# Checks section

# Checks into the pitch phases to understand them better:

INFO:root:Aggregated Phase Duration Statistics:
INFO:root:Phase: Follow Through -> Avg: 27.623s, Min: 2.036s***, Max: 124.441s****, Std: 22.539s*******
INFO:root:Phase: Wind-Up -> Avg: 0.802s, Min: 0.587s, Max: 0.936s, Std: 0.085s
INFO:root:Phase: Stride -> Avg: 0.628s, Min: 0.567s, Max: 0.752s, Std: 0.047s
INFO:root:Phase: Arm Cocking -> Avg: 0.103s, Min: 0.085s, Max: 0.119s, Std: 0.009s
INFO:root:Phase: Arm Acceleration -> Avg: 0.047s, Min: 0.005s, Max: 0.060s, Std: 0.015s


In [None]:
import logging
# Set up logging
# Send INFO-level records to the notebook output and create this cell's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Read the joined EMG/biomech dataset.
# NOTE(review): the file name carries a " 2" suffix that the join step's output
# path (final_inner_join_emg_biomech_data.parquet) does not have - confirm this
# is the intended copy.
logger.info("Loading final inner join EMG biomech dataset...")
df = pd.read_parquet('../../data/processed/ml_datasets/final_inner_join_emg_biomech_data 2.parquet')

# Schema overview
logger.info("\nColumns in dataset:")
logger.info(list(df.columns))

# Which pitch types are present, and how many rows does each one have?
logger.info("\nUnique pitch types:")
pitch_type_series = df['pitch_type_biomech']
logger.info(pitch_type_series.unique())
logger.info("\nPitch type counts:")
logger.info(pitch_type_series.value_counts())

# Summary statistics and overall size
logger.info("\nDataset description:")
logger.info(df.describe())
logger.info(f"DataFrame shape: {df.shape}")

# # Check ball_released_biomech column
# logger.info("\nChecking ball_released_biomech column:")
# logger.info(f"Number of True values: {df['ball_released_biomech'].sum()}")
# logger.info(f"Number of False values: {(~df['ball_released_biomech']).sum()}")
# logger.info(f"Number of null values: {df['ball_released_biomech'].isnull().sum()}")
# logger.info("\nValue distribution:")
# logger.info(df['ball_released_biomech'].value_counts(normalize=True))

# Sanity checks on the phase label column: zeros, nulls and label frequencies
logger.info("\nBasic phase checks:")
phase_labels = df['pitch_phase_biomech']
zeros_in_phase = phase_labels.eq(0).sum()
logger.info(f"Number of zeros in pitch_phase: {zeros_in_phase}")

nulls_in_phase = phase_labels.isna().sum()
logger.info(f"Number of nulls in pitch_phase_biomech: {nulls_in_phase}")

logger.info("\nValue counts in pitch_phase:")
logger.info(phase_labels.value_counts())

# Null audit across every column; only columns with at least one null are listed
null_counts = df.isna().sum()
columns_with_nulls = null_counts[null_counts.gt(0)]
if columns_with_nulls.empty:
    logger.info("\nNo null values found in any columns")
else:
    logger.info("\nColumns containing null values:")
    logger.info(columns_with_nulls)

# Use the biomech timestamp as the working 'datetime' column and order all
# rows chronologically before the per-trial checks.
df = df.assign(datetime=pd.to_datetime(df['biomech_datetime'])).sort_values('datetime')

logger.info("\nChecking for overlapping pitch phases, time gaps and phase durations per trial...")

# One group per trial; each trial is inspected on its own.
trial_groups = df.groupby('trial_biomech')

for trial, trial_data in trial_groups:
    # Chronological order inside the trial
    trial_data = trial_data.sort_values('datetime')
    phase_groups = trial_data.groupby('pitch_phase_biomech')

    # First/last timestamp and span of every phase, ordered by phase start
    phase_summary = (
        phase_groups
        .agg(phase_start=('datetime', 'min'), phase_end=('datetime', 'max'))
        .reset_index()
        .sort_values('phase_start')
    )
    phase_summary['duration'] = phase_summary['phase_end'] - phase_summary['phase_start']

    timeline = list(zip(
        phase_summary['pitch_phase_biomech'],
        phase_summary['phase_start'],
        phase_summary['phase_end'],
        phase_summary['duration'],
    ))

    logger.info(f"\nTrial: {trial}")
    logger.info("Phase timeline summary:")
    for name, started, ended, span in timeline:
        logger.info(f"  - {name}: start: {started}, end: {ended}, duration: {span}")

    # Walk consecutive phases: an earlier phase ending after the next one
    # starts is an overlap, anything else is reported as a gap.
    for (prev_name, _, prev_end, _), (curr_name, curr_start, _, _) in zip(timeline, timeline[1:]):
        if prev_end > curr_start:
            overlap_duration = prev_end - curr_start
            logger.info(
                f"Overlap detected between '{prev_name}' and "
                f"'{curr_name}': {overlap_duration} (duration)"
            )
        else:
            gap_duration = curr_start - prev_end
            logger.info(
                f"Gap between '{prev_name}' and "
                f"'{curr_name}': {gap_duration} (time between phases)"
            )

    # Point-level check: count the samples of every other phase that fall
    # inside this phase's [start, end] window (both ends inclusive).
    phase_names = trial_data['pitch_phase_biomech'].unique()
    for phase in phase_names:
        own_times = phase_groups.get_group(phase)['datetime']
        phase_start = own_times.min()
        phase_end = own_times.max()

        overlapping = {}
        for other_phase in phase_names:
            if other_phase == phase:
                continue
            other_times = phase_groups.get_group(other_phase)['datetime']
            inside_window = (other_times >= phase_start) & (other_times <= phase_end)
            hits = int(inside_window.sum())
            if hits:
                overlapping[other_phase] = hits

        if not overlapping:
            logger.info(f"In trial {trial}, phase '{phase}' has no overlapping points with other phases.")
            continue

        logger.info(f"In trial {trial}, phase '{phase}' overlaps with:")
        for other_phase, count in overlapping.items():
            logger.info(f"  - '{other_phase}': {count} overlapping points")


import logging
import statistics
from datetime import timedelta
from functools import wraps

def require_array_type(func):
    """
    Decorator that checks the second positional argument (usually the data
    input, e.g. phase_data for _align_phase) has a 'shape' attribute,
    i.e. is array-like.

    Parameters:
        func (callable): Function or method to guard.

    Returns:
        callable: Wrapper that validates the input and then calls ``func``
                  with the original arguments.

    Raises:
        TypeError: If fewer than two positional arguments are passed, or if
                   the second positional argument has no 'shape' attribute.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # The data input must be passed positionally; without this guard a
        # missing argument would surface as an unrelated IndexError.
        if len(args) < 2:
            raise TypeError(
                f"Function {func.__name__} requires array-like input "
                "as its second positional argument"
            )
        if not hasattr(args[1], 'shape'):
            raise TypeError(f"Function {func.__name__} requires array-like input")
        return func(*args, **kwargs)
    return wrapper

# --- Modified/Added Functions ---

def print_phase_duration_statistics(durations_by_phase):
    """
    Computes and logs statistics for each phase:
        - Average duration
        - Minimum duration
        - Maximum duration
        - Standard deviation of durations
    Also logs the smallest and largest standard deviation across phases.

    Phases whose duration list is empty are skipped with a warning, because
    no statistic can be computed for them.

    Parameters:
        durations_by_phase (dict): Dictionary where keys are phase names and values
                                   are lists of durations (in seconds).
    """
    logging.info("Aggregated Phase Duration Statistics:")
    std_devs = {}  # Standard deviation per phase, used for the overall summary

    for phase, durations in durations_by_phase.items():
        if not durations:
            # statistics.mean() and min()/max() all raise on empty input.
            logging.warning(f"Phase: {phase} -> no durations recorded; skipped")
            continue

        avg_duration = statistics.mean(durations)
        min_duration = min(durations)
        max_duration = max(durations)
        # Sample standard deviation needs at least two values; define it as 0 otherwise.
        std_duration = statistics.stdev(durations) if len(durations) > 1 else 0.0
        std_devs[phase] = std_duration

        logging.info(
            f"Phase: {phase} -> Avg: {avg_duration:.3f}s, "
            f"Min: {min_duration:.3f}s, Max: {max_duration:.3f}s, Std: {std_duration:.3f}s"
        )

    if std_devs:
        overall_min_std = min(std_devs.values())
        overall_max_std = max(std_devs.values())
        logging.info(
            f"Overall Std across phases -> Min: {overall_min_std:.3f}s, "
            f"Max: {overall_max_std:.3f}s"
        )

def process_trials(trials):
    """
    Log a phase timeline for every trial and then the aggregated
    per-phase duration statistics.

    Parameters:
        trials (list): Trial dictionaries. Each one may carry an 'id' (logged
                       as "Unknown" when absent) and a 'phases' dictionary that
                       maps a phase name to a dict with 'start' and 'end'
                       values (datetime objects).
    """
    # phase name -> list of durations in seconds, accumulated over all trials
    durations_by_phase = {}

    for trial in trials:
        logging.info(f"Trial: {trial.get('id', 'Unknown')}")
        logging.info("Phase timeline summary:")

        phases = trial.get("phases", {})
        # Visit the phases in order of their start time.
        for phase in sorted(phases, key=lambda name: phases[name]['start']):
            window = phases[phase]
            start = window['start']
            end = window['end']
            elapsed = end - start
            elapsed_seconds = elapsed.total_seconds()
            logging.info(f"  - {phase}: start: {start}, end: {end}, duration: {elapsed}")

            durations_by_phase.setdefault(phase, []).append(elapsed_seconds)

        # Gap reporting between phases, if wanted, would go here, e.g.:
        # logging.info(f"Gap between 'Phase1' and 'Phase2': {gap_td} (time between phases)")

    # Summarise once every trial has contributed its durations.
    print_phase_duration_statistics(durations_by_phase)

def process_trials_from_df(df, trial_id_col='trial_biomech', phase_col='pitch_phase_biomech', datetime_col='datetime'):
    """
    Log a phase timeline for every trial found in a DataFrame and then the
    aggregated per-phase duration statistics.

    A phase's duration within a trial is the span between the earliest and
    the latest timestamp carrying that phase label.

    Parameters:
        df (pd.DataFrame): Rows to analyse.
        trial_id_col (str): Column identifying the trial.
        phase_col (str): Column holding the phase label.
        datetime_col (str): Column holding the timestamps; differences between
                            its values must support ``total_seconds()``.
    """
    # phase label -> list of durations in seconds, accumulated over all trials
    durations_by_phase = {}

    for trial_id, trial_data in df.groupby(trial_id_col):
        logging.info(f"Trial: {trial_id}")
        logging.info("Phase timeline summary:")

        # Earliest and latest timestamp per phase, ordered by phase start
        bounds = (
            trial_data.groupby(phase_col)[datetime_col]
            .agg(['min', 'max'])
            .sort_values('min')
        )

        for phase, start, end in zip(bounds.index, bounds['min'], bounds['max']):
            elapsed = end - start
            elapsed_seconds = elapsed.total_seconds()
            logging.info(f"  - {phase}: start: {start}, end: {end}, duration: {elapsed}")

            durations_by_phase.setdefault(phase, []).append(elapsed_seconds)

        # (Optional: overlap and gap checks between phases could be added here)

    print_phase_duration_statistics(durations_by_phase)



# Log the per-trial phase timelines followed by the aggregated duration statistics.
process_trials_from_df(df, phase_col='pitch_phase_biomech')


# How many distinct trials does the dataset contain?
num_trials = df['trial_biomech'].nunique()
logger.info(f"\nTotal number of unique trials: {num_trials}")

# Spread of the number of rows per trial
trial_counts = df.groupby('trial_biomech').size()
logger.info("\nData points per trial:")
for summary_line in (
    f"Mean: {trial_counts.mean():.2f}",
    f"Min: {trial_counts.min()}",
    f"Max: {trial_counts.max()}",
    f"Std: {trial_counts.std():.2f}",
):
    logger.info(summary_line)

# Number of distinct trials recorded in each session
session_trial_counts = df.groupby('session_biomech')['trial_biomech'].nunique()
logger.info("\nTrials per session:")
logger.info(session_trial_counts)

# Data Checks combine and filter Checks for EMG Signals

Create a dashboard for monitoring this data flow, so that we ensure there are not too many low signals and that we capture all the high signals.


The summarized EMG dataset should show KPIs and validation metrics.

In [None]:
import numpy as np
import pandas as pd
from scipy.signal import welch, butter, filtfilt
import matplotlib.pyplot as plt
# Ensure fft functions are imported for FFT analysis
from scipy.fft import fft, fftfreq

# Butterworth band-pass filter
def butter_bandpass(lowcut, highcut, fs, order=4):
    """
    Design a digital Butterworth band-pass filter.

    Parameters:
        lowcut (float): Lower band edge in Hz.
        highcut (float): Upper band edge in Hz.
        fs (float): Sampling frequency in Hz.
        order (int): Filter order passed to scipy.signal.butter.

    Returns:
        tuple: (b, a) numerator and denominator coefficients.

    Raises:
        ValueError: Unless 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    if low <= 0 or high >= 1 or low >= high:
        raise ValueError(f"Check critical frequencies: got low={lowcut}, high={highcut} Hz.")
    b, a = butter(order, [low, high], btype='band')
    return b, a


def bandpass_filter(emg_signal, lowcut=20, highcut=500, fs=1000, order=4):
    """
    Band-pass filter a signal with a Butterworth design applied forward and
    backward (scipy.signal.filtfilt).

    When ``highcut`` is at or above the Nyquist frequency (fs / 2) - which is
    the case for the defaults highcut=500, fs=1000 - no band-pass design is
    possible. A sampled signal holds no content above Nyquist, so the upper
    edge removes nothing and a high-pass filter at ``lowcut`` is applied
    instead of raising.

    Parameters:
        emg_signal (array-like): Samples to filter.
        lowcut (float): Lower band edge in Hz.
        highcut (float): Upper band edge in Hz.
        fs (float): Sampling frequency in Hz.
        order (int): Butterworth filter order.

    Returns:
        np.ndarray: Filtered signal with the same shape as the input.

    Raises:
        ValueError: If the band edges are invalid for the given fs.
    """
    print(f"Applying bandpass filter from {lowcut}-{highcut} Hz")
    nyquist = 0.5 * fs
    if highcut >= nyquist:
        if not 0 < lowcut < nyquist:
            raise ValueError(f"Check critical frequencies: got low={lowcut}, high={highcut} Hz.")
        print(
            f"highcut={highcut} Hz is at or above the Nyquist frequency ({nyquist} Hz); "
            f"applying a high-pass filter at {lowcut} Hz instead"
        )
        b, a = butter(order, lowcut / nyquist, btype='high')
    else:
        b, a = butter_bandpass(lowcut, highcut, fs, order)
    return filtfilt(b, a, emg_signal)


def compute_emg_features(emg_signal, fs):
    """
    Compute common EMG features from a single segment of EMG data.

    Parameters
    ----------
    emg_signal : 1D array-like
        EMG data samples for a single time window/segment. Lists, NumPy
        arrays and pandas Series are accepted; the samples are converted to
        a float NumPy array before any computation.
    fs : float
        Sampling frequency of the EMG data (in Hz).

    Returns
    -------
    features : dict
        A dictionary of computed EMG features:
          Time-Domain:
            'RMS': Root Mean Square
            'MAV': Mean Absolute Value
            'IEMG': Integrated EMG (sum of absolute values)
            'Variance': Sample variance (ddof=1)
            'ZeroCrossings': Number of sign changes between neighbouring samples
            'WaveformLength': Sum of absolute sample-to-sample differences
          Frequency-Domain:
            'MeanFrequency': Mean Frequency of the power spectrum
            'MedianFrequency': Median Frequency of the power spectrum
            'PeakFrequency': Frequency at which power is highest

    Raises
    ------
    ValueError
        If ``emg_signal`` contains no samples.
    """
    samples = np.asarray(emg_signal, dtype=float)
    if samples.size == 0:
        raise ValueError("emg_signal must contain at least one sample")

    # ---------------------
    # 1. TIME-DOMAIN FEATURES
    # ---------------------
    # (A) Root Mean Square (RMS)
    rms_val = np.sqrt(np.mean(samples**2))
    # (B) Mean Absolute Value (MAV)
    mav_val = np.mean(np.abs(samples))
    # (C) Integrated EMG (IEMG)
    iemg_val = np.sum(np.abs(samples))
    # (D) Variance; ddof=1 makes it the sample variance
    var_val = np.var(samples, ddof=1)
    # (E) Zero Crossing Count: neighbouring samples whose signs differ.
    # A sample that is exactly 0 has sign 0, so it differs from both
    # positive and negative neighbours.
    signs = np.sign(samples)
    zero_crossings = int(np.count_nonzero(signs[:-1] != signs[1:]))
    # (F) Waveform Length (WL)
    waveform_length = np.sum(np.abs(np.diff(samples)))

    # ---------------------
    # 2. FREQUENCY-DOMAIN FEATURES
    # ---------------------
    # Power Spectral Density via scipy.signal.welch. nperseg is set to the
    # full segment length, so the estimate comes from one windowed segment
    # rather than an average over several shorter sub-segments.
    freqs, psd = welch(samples, fs=fs, nperseg=len(samples))
    # psd[i] is the power at frequency freqs[i].
    # (A) Mean Frequency (MNF) = (sum of frequency * PSD) / (sum of PSD)
    total_power = np.sum(psd)
    mean_frequency = np.sum(freqs * psd) / total_power if total_power > 0 else 0
    # (B) Median Frequency (MDF) = first frequency at which the cumulative
    # power reaches half of the total power
    cumulative_power = np.cumsum(psd)
    half_power = cumulative_power[-1] / 2.0
    median_freq_idx = np.flatnonzero(cumulative_power >= half_power)[0]
    median_frequency = freqs[median_freq_idx]
    # (C) Peak Frequency = frequency at which the PSD is maximum
    peak_frequency = freqs[np.argmax(psd)]

    # ---------------------
    # 3. COLLECT AND RETURN FEATURES
    # ---------------------
    features = {
        # Time-Domain
        'RMS': rms_val,
        'MAV': mav_val,
        'IEMG': iemg_val,
        'Variance': var_val,
        'ZeroCrossings': zero_crossings,
        'WaveformLength': waveform_length,
        # Frequency-Domain
        'MeanFrequency': mean_frequency,
        'MedianFrequency': median_frequency,
        'PeakFrequency': peak_frequency,
    }
    return features



# Butterworth high-pass filter to remove low-frequency noise below the specified cutoff
def butter_highpass(lowcut, fs, order=4):
    """
    Design a digital Butterworth high-pass filter.

    Parameters:
        lowcut (float): Cutoff frequency in Hz; must lie strictly between 0
                        and the Nyquist frequency (fs / 2).
        fs (float): Sampling frequency in Hz.
        order (int): Filter order passed to scipy.signal.butter.

    Returns:
        tuple: (b, a) numerator and denominator coefficients.

    Raises:
        ValueError: If lowcut is not strictly between 0 and fs / 2.
    """
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    print(f"Normalized lowcut frequency: {low}")  # Debugging
    # Check both bounds here so the error names the cutoff in Hz instead of
    # surfacing later as a generic message from the filter design routine.
    if not 0 < low < 1:
        raise ValueError(
            f"Critical frequency must satisfy 0 < lowcut < {nyquist} Hz (Nyquist). "
            f"Got lowcut={lowcut} Hz (normalized {low})."
        )
    b, a = butter(order, low, btype='high')
    return b, a

def highpass_filter(emg_signal, lowcut=20, fs=1000, order=4):
    """
    High-pass filter a signal with the Butterworth design from
    butter_highpass, applied forward and backward (scipy.signal.filtfilt).

    Parameters:
        emg_signal (array-like): Samples to filter.
        lowcut (float): Cutoff frequency in Hz.
        fs (float): Sampling frequency in Hz.
        order (int): Butterworth filter order.

    Returns:
        np.ndarray: Filtered signal.
    """
    print(f"Applying highpass filter with lowcut={lowcut} Hz, fs={fs} Hz")  # Debugging
    coefficients = butter_highpass(lowcut, fs, order)
    return filtfilt(*coefficients, emg_signal)




# Example preprocessing and feature extraction pipeline
def process_and_compute_features(emg_signal, fs, highpass=True):
    """
    Screen a raw EMG segment, optionally high-pass filter it at 20 Hz, and
    compute its features.

    Parameters:
        emg_signal (np.ndarray): EMG samples.
        fs (float): Sampling frequency in Hz.
        highpass (bool): When True, filter before computing features and print
                         min/max/mean/std of the signal before and after.

    Returns:
        dict | None: Feature dictionary from compute_emg_features, or None
                     when more than 10% of the samples are exactly zero.
    """
    # 1. Reject segments dominated by exact zeros (more than 10% of samples).
    zero_count = (emg_signal == 0).sum()
    if zero_count > len(emg_signal) * 0.1:
        return None

    # No filtering requested: features come straight from the input signal.
    if not highpass:
        return compute_emg_features(emg_signal, fs)

    def _describe(samples):
        # Summary statistics used to compare the signal before/after filtering
        return {
            'min': np.min(samples),
            'max': np.max(samples),
            'mean': np.mean(samples),
            'std': np.std(samples),
        }

    # 2. High-pass filter to remove low-frequency noise (20 Hz cutoff),
    #    capturing statistics on both sides of the filter.
    pre_filter_stats = _describe(emg_signal)
    filtered_signal = highpass_filter(emg_signal, lowcut=20, fs=fs)
    post_filter_stats = _describe(filtered_signal)

    print("Pre-filter statistics:", pre_filter_stats)
    print("Post-filter statistics:", post_filter_stats)

    # 3. Features of the filtered signal
    return compute_emg_features(filtered_signal, fs)




        

#---------------Checks---------
def plot_psd_before_and_after(emg_signal, fs, lowcut=20, highcut=500):
    """
    Show the Power Spectral Density (PSD) of a signal before and after
    band-pass filtering in one semilog plot.

    Parameters:
    ----------
    emg_signal : 1D numpy array
        Raw EMG signal data.
    fs : float
        Sampling frequency of the EMG data (in Hz).
    lowcut : float
        Lowcut frequency for bandpass filter (Hz).
    highcut : float
        Highcut frequency for bandpass filter (Hz).
    """
    def _psd(samples):
        # Welch estimate using the whole signal as a single segment
        return welch(samples, fs=fs, nperseg=len(samples))

    # Spectrum of the unfiltered input
    freqs_before, psd_before = _psd(emg_signal)

    # Spectrum of the band-pass filtered signal
    filtered_signal = bandpass_filter(emg_signal, lowcut=lowcut, highcut=highcut, fs=fs)
    freqs_after, psd_after = _psd(filtered_signal)

    # The filtered curve is drawn first, the unfiltered one on top of it.
    curves = (
        (freqs_after, psd_after, 'After Filtering', 'red'),
        (freqs_before, psd_before, 'Before Filtering', 'blue'),
    )

    plt.figure(figsize=(12, 6))
    for freqs, psd, label, color in curves:
        plt.semilogy(freqs, psd, label=label, color=color)
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Power Spectral Density (uV^2/Hz)")
    plt.title("Power Spectral Density Before and After Filtering")
    plt.legend()
    plt.grid(True)
    plt.show()



# EMG pitch data, loaded for validation purposes only.
# NOTE(review): this path starts at ../data while the other datasets in this
# notebook are read from ../../data - confirm the relative depth.
emg_pitch_data = pd.read_parquet('../data/processed/emg_pitch_data_processed.parquet')
# print("\nEMG Pitch Data Columns and Unique Values:")
# for col in emg_pitch_data.columns:
#     print(f"\n{col}:")
#     print(f"Unique values: {emg_pitch_data[col].unique()}")



# EMG channels to analyse
emg_signal_columns = ['EMG 1 (mV) - FDS (81770)', 'EMG 1 (mV) - FCU (81728)', 'EMG 1 (mV) - FCR (81745)']
print("EMG Signals here ===========", emg_signal_columns)
fs = 1000  # Sampling frequency (adjust based on your setup)

# Describe every EMG channel: size, nulls, dtype, a peek at the values and,
# where the values are numeric, basic statistics.
for column in emg_signal_columns:
    print(f"\n=== Analysis of {column} ===")

    print("\nBasic Statistics:")
    channel = emg_pitch_data[column]
    print(f"Number of values: {len(channel)}")
    print(f"Number of null values: {channel.isnull().sum()}")
    print(f"Data type: {channel.dtype}")

    print("\nValue Analysis:")
    first_value = channel.iloc[0]
    print(f"First value type: {type(first_value)}")
    shape_text = np.shape(first_value) if hasattr(first_value, 'shape') else 'N/A'
    print(f"Shape (if array): {shape_text}")
    print(f"First few values: {channel.head().values}")

    # The numeric summaries are best-effort: any failure is reported and the
    # loop moves on to the next channel.
    try:
        print("\nNumerical Statistics:")
        for label, statistic in (
            ("Mean", channel.mean),
            ("Std", channel.std),
            ("Min", channel.min),
            ("Max", channel.max),
        ):
            print(f"{label}: {statistic():.4f}")

        print("\nPotential Issues:")
        print(f"Number of zeros: {(channel == 0).sum()}")
        print(f"Number of unique values: {channel.nunique()}")

    except Exception as e:
        print(f"\nCould not compute numerical statistics: {str(e)}")

# Feature containers: one for the high-pass filtered signals and one for the
# raw signals (features_dict is initialised empty and not filled here).
features_dict = {}
filtered_features = {}
unfiltered_features = {}

for signal_col in emg_signal_columns:
    # Raw samples of this channel
    emg_signal = emg_pitch_data[signal_col].values

    # Run the pipeline twice per channel: with the high-pass filter first,
    # then without it. A falsy result means the signal was rejected.
    for use_highpass, store, label in (
        (True, filtered_features, "Filtered"),
        (False, unfiltered_features, "Unfiltered"),
    ):
        channel_features = process_and_compute_features(emg_signal, fs, highpass=use_highpass)
        if channel_features:
            store[signal_col] = channel_features
        else:
            print(f"{label} signal {signal_col} was discarded due to excessive zeros.")

# Compare the filtered and unfiltered feature sets, then inspect each
# channel's frequency content.
for signal_col in emg_signal_columns:
    print(f"\nValidation for {signal_col}:")

    # Only channels that produced features in both runs can be compared.
    if signal_col in filtered_features and signal_col in unfiltered_features:
        filtered = filtered_features[signal_col]
        unfiltered = unfiltered_features[signal_col]

        print("\nFeature comparison (filtered vs unfiltered):")
        for feature, filtered_val in filtered.items():
            unfiltered_val = unfiltered[feature]
            diff = abs(filtered_val - unfiltered_val)
            print(f"{feature}:")
            print(f"  Filtered: {filtered_val:.4f}")
            print(f"  Unfiltered: {unfiltered_val:.4f}")
            print(f"  Difference: {diff:.4f}")

            # Flag differences larger than half of the filtered value
            if diff > 0.5 * abs(filtered_val):
                print("  WARNING: Large difference between filtered and unfiltered values")

    # Visualize frequency content
    signal = emg_pitch_data[signal_col].values
    N = len(signal)
    T = 1.0 / fs

    # FFT analysis: single-sided amplitude spectrum (first N//2 bins)
    yf = fft(signal)
    xf = fftfreq(N, T)[:N//2]

    plt.figure(figsize=(10,6))
    plt.plot(xf, 2.0/N * np.abs(yf[0:N//2]))
    plt.grid()
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    plt.title(f"Frequency Spectrum - {signal_col}")
    plt.show()

    # Power spectral density comparison.
    # The plotting helper's default upper band edge (500 Hz) equals the
    # Nyquist frequency at fs = 1000 Hz, which the band-pass design rejects
    # with a ValueError. Pass an edge strictly below Nyquist instead.
    psd_highcut = min(500, 0.45 * fs)
    plot_psd_before_and_after(signal, fs=fs, highcut=psd_highcut)
