In [None]:
import os
import shutil
import logging
import numpy as np
from tqdm import tqdm
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from collections import defaultdict
import warnings
import re
from concurrent.futures import ProcessPoolExecutor
import os.path as osp

# Silence DeprecationWarning and FutureWarning messages for the rest of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

def extract_time_range_from_trc(trc_file):
    """Return the ``(start, end)`` time range covered by a TRC marker file.

    Only the header is parsed. The third line of the file holds the header
    values, with the data rate in column 0 and the frame count in column 2.
    The returned range starts at 0 and ends at ``frames / data_rate`` seconds.

    Args:
        trc_file: Path to the ``.trc`` marker file.

    Returns:
        Tuple ``(0, time_end)``.

    Raises:
        ValueError: If the header line is missing or its values are malformed.
    """
    # Only the first three lines are needed; avoid reading the whole
    # (potentially large) marker file into memory.
    with open(trc_file, 'r') as file:
        header = [file.readline() for _ in range(3)]

    fields = header[2].split()
    if len(fields) < 3:
        raise ValueError(f"Malformed TRC header (expected values on line 3): {trc_file}")
    try:
        data_rate = float(fields[0])
        num_frames = int(fields[2])
    except ValueError as err:
        raise ValueError(f"Malformed TRC header values in {trc_file}: {header[2].strip()!r}") from err
    if data_rate <= 0:
        raise ValueError(f"Non-positive data rate {data_rate} in TRC header: {trc_file}")

    print(f"Number of frames: {num_frames} for file: {trc_file}")
    time_end = num_frames / data_rate
    return (0, time_end)

def create_ik_setup_file(template_path, trc_file, output_motion_file, ik_setup_file):
    """Write a per-trial IK setup XML derived from a template.

    Parses ``template_path``, then overwrites the text of the first
    ``time_range``, ``marker_file`` and ``output_motion_file`` elements with
    the TRC-derived time range, ``trc_file`` and ``output_motion_file``
    respectively. Finally it saves the result to ``ik_setup_file``.

    Raises:
        ValueError: If the template lacks one of the required elements.
    """
    tree = ET.parse(template_path)
    root = tree.getroot()

    start, end = extract_time_range_from_trc(trc_file)

    replacements = {
        'time_range': f"{start} {end}",
        'marker_file': trc_file,
        'output_motion_file': output_motion_file,
    }
    for tag, text in replacements.items():
        element = root.find(f'.//{tag}')
        # find() returns None for a missing tag; report it clearly instead of
        # failing later with "'NoneType' object has no attribute 'text'".
        if element is None:
            raise ValueError(f"IK setup template {template_path} has no <{tag}> element")
        element.text = text

    # Emit an explicit XML declaration (ElementTree omits it by default).
    tree.write(ik_setup_file, encoding='UTF-8', xml_declaration=True)

def clear_file(file_path):
    """Truncate ``file_path`` to zero length, creating it if it does not exist."""
    # Opening in 'w' mode already truncates the file; nothing else to do.
    with open(file_path, 'w'):
        pass

def run_inverse_kinematics(scaled_model_file, trc_file, ik_setup_file, output_dir):
    """Run OpenSim inverse kinematics for a single trial.

    Steps:
    - Load ``scaled_model_file`` and configure an ``InverseKinematicsTool``
      from ``ik_setup_file``, with ``trc_file`` as the marker data.
    - Write the motion to ``<output_dir>/<trial>.mot``. The trial is skipped
      if that file already exists.
    - After a run, copy ``./opensim.log`` (if present) to
      ``<output_dir>/<trial>_ik.log`` and then empty it.

    Errors are printed and logged rather than raised, so one bad trial does
    not stop a batch. An ImportError of ``opensim`` still propagates.

    Returns:
        True if the trial was processed, or skipped because its output exists.
        False if an error occurred.
    """
    import opensim as osim

    log_file = './opensim.log'
    # Strip only the final extension; str.replace would also rewrite any
    # '.trc' occurring elsewhere in the file name.
    trial_name = os.path.splitext(os.path.basename(trc_file))[0]

    try:
        for label, path in (
            ("Scaled model", scaled_model_file),
            ("TRC", trc_file),
            ("IK setup", ik_setup_file),
        ):
            if not os.path.exists(path):
                raise FileNotFoundError(f"{label} file not found: {path}")

        output_file = os.path.join(output_dir, trial_name + '.mot')

        print(f"Output file: {output_file}")
        if os.path.exists(output_file):
            logging.info(f"Output file already exists, skipping: {output_file}")
            print(f"\033[93mOutput file already exists, skipping: {output_file}\033[0m")
            return True

        model = osim.Model(scaled_model_file)
        ik_tool = osim.InverseKinematicsTool(ik_setup_file)
        ik_tool.setModel(model)
        # Override the paths from the setup file so exactly these files are used.
        ik_tool.setMarkerDataFileName(trc_file)
        ik_tool.setOutputMotionFileName(output_file)

        logging.debug(f"Running IK Tool with model: {scaled_model_file}, TRC: {trc_file}, Setup: {ik_setup_file}, Output: {output_file}")

        ik_tool.run()

        # NOTE(review): './opensim.log' is relative to the working directory, so
        # every worker started by process_subject_trials(parallel=True) shares
        # it. Concurrent trials can interleave messages in it, and clear_file()
        # here may wipe another trial's log before it is copied. Per-trial
        # *_ik.log files can therefore be mixed or incomplete in parallel mode.
        # A per-trial log sink (e.g. OpenSim's Logger API) would fix this;
        # verify availability against the installed OpenSim version.
        if os.path.exists(log_file):
            log_file_new = os.path.join(output_dir, trial_name + '_ik.log')
            shutil.copy2(log_file, log_file_new)
            clear_file(log_file)

        print(f"\033[92mProcessed {trc_file} successfully\033[0m")
        return True
    except Exception as e:
        # Batch boundary: report (with traceback via logging) and move on.
        print(f"\033[91mError processing {trc_file}: {e}\033[0m")
        logging.exception("Inverse kinematics failed for %s", trc_file)
        return False

def process_subject_trials(scaled_model_file, trc_files, ik_setup_template, output_dir, parallel=False):
    """Create per-trial IK setup files and run IK for every TRC file of one subject.

    For each ``<trial>.trc``, a ``<trial>_ik_setup.xml`` is written to
    ``output_dir`` from ``ik_setup_template``. IK is then run for every trial
    whose setup file was created successfully. It runs in a process pool when
    ``parallel`` is True, otherwise sequentially.
    """
    from concurrent.futures import as_completed

    jobs = []
    for trc_file in trc_files:
        stem = os.path.splitext(os.path.basename(trc_file))[0]
        ik_setup_file = os.path.join(output_dir, f"{stem}_ik_setup.xml")
        output_motion_file = os.path.join(output_dir, f"{stem}.mot")
        try:
            create_ik_setup_file(ik_setup_template, trc_file, output_motion_file, ik_setup_file)
        except Exception as e:
            # A malformed TRC/template should not abort the remaining trials;
            # this mirrors run_inverse_kinematics' per-trial error reporting.
            print(f"\033[91mError creating IK setup for {trc_file}: {e}\033[0m")
            continue
        jobs.append((trc_file, ik_setup_file))

    if parallel:
        with ProcessPoolExecutor() as executor:
            futures = [
                executor.submit(run_inverse_kinematics, scaled_model_file, trc_file, ik_setup_file, output_dir)
                for trc_file, ik_setup_file in jobs
            ]
            # as_completed advances the bar as trials finish, not in submit order.
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing trials in parallel"):
                future.result()
    else:
        # A list (unlike zip) has a length, so tqdm can show a real total.
        for trc_file, ik_setup_file in tqdm(jobs, desc="Processing trials sequentially"):
            run_inverse_kinematics(scaled_model_file, trc_file, ik_setup_file, output_dir)

def get_log_files_by_subject(base_dir, subjects):
    """Collect the per-trial IK log files (``*_ik.log``) for each subject.

    Logs are looked up in ``<base_dir>/P<NN>/processed_joint_kinematics``,
    where ``NN`` is the subject id zero-padded to two digits.

    Returns:
        Dict mapping subject -> list of log paths sorted by file name.
        Subjects whose directory does not exist are omitted (a warning is
        printed). A subject whose directory holds no logs maps to ``[]``.
    """
    base_dir = os.path.expanduser(base_dir)
    subject_logs = {}
    for subject in subjects:
        output_dir = os.path.join(base_dir, f'P{str(subject).zfill(2)}', 'processed_joint_kinematics')
        if not os.path.isdir(output_dir):
            print(f"\033[91mWarning: Output directory not found for subject {subject}. Skipping.\033[0m")
            continue
        # os.listdir order is arbitrary; sort so trials appear in a stable order.
        subject_logs[subject] = sorted(
            os.path.join(output_dir, name)
            for name in os.listdir(output_dir)
            if name.endswith('_ik.log')
        )
    return subject_logs

def process_all_subjects(base_dir, subjects, ik_setup_template, parallel=False):
    """Run IK for every trial of every subject, then plot per-trial RMS errors.

    Expected layout under ``base_dir`` (``NN`` = subject id zero-padded to 2):

    - ``subject_<id>.osim`` is the scaled model (id not padded).
    - ``P<NN>/raw_marker/*.trc`` holds the marker trials.
    - ``P<NN>/processed_joint_kinematics/`` receives the outputs and is
      created if missing.

    Subjects missing the marker directory, the scaled model, or any TRC file
    are reported and skipped.
    """
    # Expanding the common prefix once covers every path derived from it.
    base_dir = osp.expanduser(base_dir)

    for subject in subjects:
        subject_dir = osp.join(base_dir, f'P{str(subject).zfill(2)}')
        scaled_model_file = osp.join(base_dir, f'subject_{subject}.osim')
        trc_files_dir = osp.join(subject_dir, 'raw_marker')
        output_dir = osp.join(subject_dir, 'processed_joint_kinematics')

        if not osp.isdir(trc_files_dir):
            print(f"\033[91mError: TRC files directory not found: {trc_files_dir}\033[0m")
            continue

        # Check once here instead of failing identically for every trial.
        if not osp.exists(scaled_model_file):
            print(f"\033[91mError: Scaled model file not found: {scaled_model_file}\033[0m")
            continue

        trc_files = sorted(
            osp.join(trc_files_dir, f) for f in os.listdir(trc_files_dir) if f.endswith('.trc')
        )
        if not trc_files:
            print(f"\033[91mError: No TRC files found in directory: {trc_files_dir}\033[0m")
            continue

        os.makedirs(output_dir, exist_ok=True)

        process_subject_trials(scaled_model_file, trc_files, ik_setup_template, output_dir, parallel=parallel)

        # Plot this subject's results only if log files exist. An empty list
        # would reach plt.subplots(0, ...) inside the plotting code and raise.
        subject_logs = get_log_files_by_subject(base_dir, [subject])
        if subject_logs.get(subject):
            plot_rms_by_subject(subject_logs)

def plot_rms_by_subject(subject_logs, n_cols=8, rms_threshold=0.5, ylim=(0, 0.10)):
    """Plot per-frame IK marker RMS error for each trial, one figure per subject.

    For each trial it also prints, per marker, how many frames had that marker
    as the max-error marker, together with the average of those max errors.

    Args:
        subject_logs: Dict mapping subject -> list of ``*_ik.log`` paths
            (as returned by ``get_log_files_by_subject``).
        n_cols: Number of subplots per row.
        rms_threshold: Frames whose RMS exceeds this are marked with red dots.
        ylim: Y-axis limits applied to every subplot.
    """
    for subject, logs in subject_logs.items():
        n_trials = len(logs)
        if n_trials == 0:
            # plt.subplots(0, n_cols) raises ValueError; nothing to plot anyway.
            print(f"No IK logs to plot for subject {subject}.")
            continue
        n_rows = -(-n_trials // n_cols)  # ceiling division

        # squeeze=False guarantees a 2-D array of Axes for any n_rows/n_cols.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(45 * n_cols / 8, 5 * n_rows), squeeze=False)
        axes = axes.ravel()

        for ax, log_file in zip(axes, logs):
            rms_values, max_errors = extract_rms_from_log(log_file)
            name = os.path.basename(log_file)
            ax.plot(rms_values)
            ax.set_title(name)
            ax.set_xlabel('Frame')
            ax.set_ylabel('RMS Error')
            ax.set_ylim(ylim)
            # NOTE(review): with the defaults the threshold (0.5) lies above the
            # upper y-limit (0.10), so these red markers are never visible.
            # Confirm which of the two values was intended.
            rms = np.asarray(rms_values, dtype=float)
            flagged = np.flatnonzero(rms > rms_threshold)
            if flagged.size:
                ax.plot(flagged, rms[flagged], 'ro')

            for marker, errors in max_errors.items():
                average_error = sum(errors) / len(errors)
                print(f"File: {name}, Marker: {marker}, Count: {len(errors)}, Average Error: {average_error:.6f}")

        # Remove the unused cells of the last row.
        for ax in axes[n_trials:]:
            fig.delaxes(ax)

        fig.suptitle(f'Subject {subject}')
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)
        plt.show()

def extract_rms_from_log(log_file_path):
    """Parse per-frame marker-error lines from an OpenSim IK log.

    Only lines containing ``marker error: RMS`` are considered. They are
    expected to look like
    ``... marker error: RMS = <value>, max = <value> (<marker>)``.
    A frame is recorded only when both the RMS and the max/marker parts parse,
    so the two outputs stay aligned.

    Returns:
        ``(rms_values, max_errors)``:

        - ``rms_values``: list of floats in file order.
        - ``max_errors``: defaultdict mapping marker name -> list of max-error
          values, for the frames where that marker had the largest error.
    """
    # Accept plain decimals, scientific notation and nan/inf, so a value such
    # as 9.5e-05 is not truncated to 9.5.
    number = r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'
    rms_pattern = re.compile(r'RMS = ' + number, re.IGNORECASE)
    # Marker names may contain characters outside \w (e.g. '-'), so take
    # everything up to the closing parenthesis.
    max_pattern = re.compile(r'max = ' + number + r' \(([^)]+)\)', re.IGNORECASE)

    rms_values = []
    max_errors = defaultdict(list)
    with open(log_file_path, 'r') as file:
        for line in file:
            if 'marker error: RMS' not in line:
                continue
            rms_match = rms_pattern.search(line)
            max_match = max_pattern.search(line)
            if rms_match and max_match:
                rms_values.append(float(rms_match.group(1)))
                max_errors[max_match.group(2)].append(float(max_match.group(1)))
    return rms_values, max_errors

if __name__ == "__main__":
    # Root of the data collection. '~' is expanded once here and the IK
    # template path is derived from it.
    base_dir = os.path.expanduser('~/GoogleDrive/sd_datacollection_v4')
    subjects = [1]
    ik_setup_template = os.path.join(base_dir, 'default_ik.xml')

    # True runs trials in a process pool; False runs them one after another.
    parallel = True
    process_all_subjects(
        base_dir,
        subjects,
        ik_setup_template,
        parallel=parallel,
    )

Processing trials in parallel:   0%|          | 0/16 [00:00<?, ?it/s]

Output file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_OR_90.mot
Output file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_ER_S.mot
Output file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_AS_VF.motOutput file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_CB_F.motOutput file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_EF_F.motOutput file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_CB_N.mot



Output file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_ER_F.mot
Output file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_EF_S.motOutput file: /home/oliver/GoogleDrive/sd_datacollection_v4/P12/processed_joint_kinematics/P12_T1_ER_N.mot

Output file: /home/oliver/GoogleDrive/sd_dat