# Preprocess movement data from the Hands and Pose models

## 1) Imports

In [None]:
# libraries
import os

import numpy as np
import pandas as pd
import toml
import tqdm

# modules
from src import data_logging as data_log
import motion_processing
import render_landmarks

## 2) Initialize Logger and Prepare Movement Data Files

In [None]:
# Initialize the batch logger; per-file results below are reported through it.
logger = data_log.setup_batch_logger(log_file_path='batch_preprocessing_logs.txt')

# Read config data. TOML documents must be UTF-8 by specification, so decode
# explicitly instead of relying on the platform's default locale encoding.
with open('config.toml', 'r', encoding='utf-8') as f:
    config = toml.load(f)

# Camera framerate as given in the config.
FRAMERATE: float = config['camera_params']['framerate']

# Source directory scanned for motion CSVs, and output root for processed results.
prepro_src_path: str = config['batch_preprocessing']['prepro_src_path']
prepro_out_path: str = config['batch_preprocessing']['prepro_out_path']

# Participant info on the affected side. Each entry is presumably
# [participant_id, side, ...] with side 'L' or 'R' (confirm against config.toml);
# only the first two fields are used.
affected_side_lst: list[list] = config['participant_info']['affected_side']
affected_side_dict: dict = {item[0]: item[1] for item in affected_side_lst}

## 3) Batch Preprocess Data

In [None]:
# Threshold for the longest tolerated run of consecutive NaN frames (1/5 s of frames).
MAX_NAN_GAP: int = int(FRAMERATE) // 5

# Recursively collect every motion file ('.csv') under the source directory,
# ignoring hidden files, in a deterministic (sorted) order.
motion_data_file_lst: list[str] = sorted(
    os.path.join(path, file)
    for path, _dirs, files in os.walk(prepro_src_path)
    for file in files
    if file.endswith('.csv') and not file.startswith('.')
)

# process each motion file
print(f'Starting preprocessing of {len(motion_data_file_lst)} movement files...')
for motion_data_fpath in tqdm.tqdm(motion_data_file_lst):

    motion_file_name: str = os.path.basename(motion_data_fpath)

    try:
        # Metadata is encoded in '_'-separated file-name fields: participant id is the
        # second field, the trial number is the last two characters of the third-from-last
        # field, and the model type ('hands' or 'pose') is the last field before the extension.
        name_fields: list[str] = motion_file_name.split('_')
        participant_id: str = name_fields[1]
        curr_affected_side: str = affected_side_dict[participant_id]
        trial_id: int = int(name_fields[-3][-2:])
        model_type: str = name_fields[-1].split('.')[0]

    except (IndexError, KeyError, ValueError) as e:
        # too few fields, unknown participant, or non-numeric trial id
        logger.error(f"Failed to parse metadata from {motion_data_fpath}: {e}")
        continue

    # Processed file name: '<stem>_processed.csv'. splitext only strips the real
    # extension, so dots in directory names (e.g. './data') do not truncate the path.
    new_file_name_path: str = os.path.splitext(motion_data_fpath)[0] + '_processed.csv'
    # Mirror the last two directory levels plus the file name under the output root,
    # splitting on the OS separator so this also works on Windows.
    tail_parts: list[str] = os.path.normpath(new_file_name_path).split(os.sep)[-3:]
    new_file_out_path: str = os.path.join(prepro_out_path, *tail_parts)

    # prevent overwriting processed files
    if os.path.exists(new_file_out_path):
        # tqdm.write keeps the progress bar intact, unlike print
        tqdm.tqdm.write(f'Skipping {motion_file_name}: Processed file already exists.')
        continue

    try:
        # read file
        motion_df: pd.DataFrame = pd.read_csv(motion_data_fpath)

        # extract arms or hands of focus for preprocessing
        if model_type == 'hands':
            side_of_focus_df, side_of_focus_id = motion_processing.extract_hand_of_focus(motion_df, curr_affected_side, trial_id)
        elif model_type == 'pose':
            side_of_focus_df, side_of_focus_id = motion_processing.extract_pose_of_focus(motion_df, curr_affected_side, trial_id)
        else:
            raise ValueError(f'Unknown model type: {model_type}')

        # preprocess data and dump filtered motion; the mirrored output
        # subdirectories may not exist yet, so create them first
        processed_motion_df, nan_cnt = motion_processing.preprocess_motion_data(df=side_of_focus_df, framerate=FRAMERATE)
        os.makedirs(os.path.dirname(new_file_out_path) or '.', exist_ok=True)
        processed_motion_df.to_csv(new_file_out_path, index=False)

        # Check if the dataframe contains any valid numerical data. The CSV has already
        # been written at this point; this only flags the file as failed in the log.
        numeric_df: pd.DataFrame = processed_motion_df.select_dtypes(include=np.number)
        if numeric_df.empty or numeric_df.dropna(how='all').empty:
            raise ValueError('Processed DataFrame contains no numerical data (all NaNs). Skipping plotting image and video.')

    except Exception as e:
        # batch boundary: log and keep going with the remaining files
        logger.error(f'Processing failed for {motion_data_fpath}. Error: {e}')
        nan_cnt = 9001  # indicate error in the warning log
        side_of_focus_id = ''

    # log INFO line
    info_log_data = {
        'PID': participant_id,
        'AFCT_SIDE': curr_affected_side,
        'FCS_SIDE': side_of_focus_id,
        'TID': trial_id,
        'FPATH': motion_data_fpath,
    }
    logger.info('File processing finished.', extra=info_log_data)

    # log WARNING line when the NaN run exceeds the threshold (or processing failed)
    if nan_cnt > MAX_NAN_GAP:
        warning_log_data = {
            'REP_NAN': nan_cnt,
        }
        logger.warning('Max consecutive NaNs exceeded threshold.', extra=warning_log_data)

print('\n --------- All preprocessing complete! ---------')

## 4) Batch Render Videos with Hand/Pose Overlay

In [None]:
# Landmark files are located by suffix under the output root. The preprocessing
# step writes '<stem>_processed.csv' where <stem> ends in '_hands' or '_pose', so
# its outputs end in 'hands_processed.csv' / 'pose_processed.csv'. The bare
# 'hands.csv' / 'pose.csv' suffixes are still accepted for backward compatibility.
HANDS_FILE_SUFFIXES: tuple[str, ...] = ('hands.csv', 'hands_processed.csv')
POSE_FILE_SUFFIXES: tuple[str, ...] = ('pose.csv', 'pose_processed.csv')


def _find_landmark_files(root: str, suffixes: tuple[str, ...]) -> list[str]:
    """Return sorted paths of non-hidden files under *root* ending in any of *suffixes*."""
    return sorted(
        os.path.join(path, file)
        for path, _dirs, files in os.walk(root)
        for file in files
        if file.endswith(suffixes) and not file.startswith('.')
    )


def _landmark_base_name(fpath: str) -> str:
    """Return the basename of *fpath* with its last two '_'-separated fields removed.

    This is the key shared by a recording's hands and pose files and the stem of
    its source video.
    """
    return '_'.join(os.path.basename(fpath).split('_')[:-2])


hands_data_file_lst: list[str] = _find_landmark_files(prepro_out_path, HANDS_FILE_SUFFIXES)
pose_data_file_lst: list[str] = _find_landmark_files(prepro_out_path, POSE_FILE_SUFFIXES)

# Pair hands and pose files by base name rather than by list position: positional
# pairing raised IndexError when the lists differed in length and silently skipped
# every pair after the first missing file.
pose_by_base: dict[str, str] = {_landmark_base_name(p): p for p in pose_data_file_lst}

print('Starting batch rendering...')
for hands_fpath in tqdm.tqdm(hands_data_file_lst):

    base_name: str = _landmark_base_name(hands_fpath)
    pose_fpath = pose_by_base.get(base_name)
    if pose_fpath is None:
        tqdm.tqdm.write(f'Skipping {base_name}: no matching pose file found.')
        continue

    # Overlay videos go flat into the output root; the source video is looked up at
    # the top level of the source directory. NOTE(review): motion CSVs are found
    # recursively, so confirm the videos are not stored in subdirectories.
    new_video_out_path: str = os.path.join(prepro_out_path, base_name + '_focused_overlay.mp4')
    orig_video_path: str = os.path.join(prepro_src_path, base_name + '.mp4')
    if not os.path.exists(orig_video_path):
        logger.error(f'Rendering skipped for {hands_fpath}. Source video not found: {orig_video_path}')
        continue

    try:
        # read landmark data and render video overlay
        hand_df: pd.DataFrame = pd.read_csv(hands_fpath)
        pose_df: pd.DataFrame = pd.read_csv(pose_fpath)
        render_landmarks.render_focused_hand_and_arm(orig_video_path, hand_df, pose_df, new_video_out_path)
    except Exception as e:
        # one failing recording must not abort the whole batch
        logger.error(f'Rendering failed for {hands_fpath}. Error: {e}')

print('\n --------- All rendering complete! ---------')