In [1]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")

Mounted at /content/MyDrive


In [2]:
# @title Initialize Config

import torch
class Config:
    def __init__(self, **kwargs):
        self.batch_size = kwargs.get('batch_size', 64)
        self.epochs = kwargs.get('epochs', 50)
        self.lr = kwargs.get('lr', 0.001)
        self.input_format = kwargs.get('input_format', 'csv')
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_acc_test = kwargs.get('channels_imu_acc_test', [])
        self.channels_imu_gyr_test = kwargs.get('channels_imu_gyr_test', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.dataset_train_name = kwargs.get('dataset_train_name', 'train')
        self.dataset_test_name = kwargs.get('dataset_test_name', 'test')
        self.window_length = kwargs.get('window_length', 100)
        self.window_overlap = kwargs.get('window_overlap', 0)
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')
        self.train_subjects = kwargs.get('train_subjects', [])
        self.test_subjects = kwargs.get('test_subjects', [])

        self.dataset_name = self.generate_dataset_name()

    def generate_dataset_name(self):
        name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{self.train_subjects}_test{self.test_subjects}"
        return name

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection',
    dataset_root='/content/datasets',
    dataset_train_name='train',
    dataset_test_name='test',
    batch_size=64,
    epochs=150,
    lr=0.001,
    window_length=100,
    window_overlap=75,
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1', 'ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_acc_test=['ACCX1', 'ACCY1', 'ACCZ1', 'ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_imu_gyr_test=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
    train_subjects=['subject_2','subject_3','subject_4','subject_6','subject_7','subject_8', 'subject_9','subject_10', 'subject_11','subject_12','subject_13'],
    test_subjects=['subject_1']
)

In [3]:
# @title Pre Window and save for reuse
import os
import shutil
import numpy as np
import pandas as pd

#import twdm for collab
from tqdm.notebook import tqdm


class DataSharder:
    def __init__(self, config, save_h5=False):
        self.config = config
        self.input_format = config.input_format
        self.data_folder_path = config.data_folder_name
        self.window_length = int(config.window_length)
        self.window_overlap = int(config.window_overlap)
        self.save_h5 = save_h5

    def load_data(self):
        print(f"Training subjects: {self.config.train_subjects}")
        print(f"Testing subjects: {self.config.test_subjects}")

        if self.input_format == 'csv':
            self._process_and_save_patients_csv(self.config.train_subjects, "train")
            self._process_and_save_patients_csv(self.config.test_subjects, "test")
        else:
            raise ValueError(f"Unsupported input format: {self.input_format}")

    def _resample_data(self, data, sample_rate):
        if sample_rate != self.sample_rate:
            data = torch.nn.functional.interpolate(data.unsqueeze(0), size=self.sample_rate, mode='linear').squeeze(0)
        return data

    def _process_and_save_patients_csv(self, patient_id_list, split):
        column_names = None
        for patient_id in tqdm(patient_id_list, desc=f"Processing {split} patients"):
            combined_path = os.path.join(self.data_folder_path, patient_id, "combined")
            if not os.path.exists(combined_path):
                print(f"Directory {combined_path} does not exist. Skipping patient {patient_id}.")
                continue

            patient_files = os.listdir(combined_path)
            for session_file in tqdm(patient_files, desc=f"Processing sessions for {patient_id}", leave=False):
                data = pd.read_csv(os.path.join(combined_path, session_file))
                if column_names is None:
                    column_names = data.columns.tolist()  # Convert Index to list
                data_np = data.to_numpy()
                # Pad array to fit columns if necessary
                if data_np.shape[1] < len(column_names):
                    data_np = np.pad(data_np, ((0, 0), (0, len(column_names) - data_np.shape[1])), mode='constant')
                elif data_np.shape[1] > len(column_names):
                    # Extend column names to match the data shape
                    extra_columns = [f"extra_{i}" for i in range(data_np.shape[1] - len(column_names))]
                    column_names.extend(extra_columns)

                self._save_windowed_data(pd.DataFrame(data_np, columns=column_names), patient_id, session_file.split('.')[0], split, is_csv=True)

    def _save_windowed_data(self, data, patient_id, session_id, split, is_csv=False):
        dataset_folder = os.path.join(self.config.dataset_root, self.config.dataset_name, self.config.dataset_train_name if split == "train" else self.config.dataset_test_name)
        os.makedirs(dataset_folder, exist_ok=True)

        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        data_info_list = []

        for i in tqdm(range(0, len(data) - window_size + 1, step_size), desc=f"Windowing data for {patient_id}_{session_id}", leave=False):
            windowed_data = data.iloc[i:i+window_size] if is_csv else data[i:i+window_size]
            if windowed_data.shape[0] < window_size:
                continue

            windowed_data_np = windowed_data.to_numpy() if is_csv else windowed_data.cpu().numpy()

            file_name = f"{patient_id}_session_{session_id}_window_{i}_ws{window_size}_ol{overlap}.csv"
            file_path = os.path.join(dataset_folder, file_name)
            pd.DataFrame(windowed_data_np, columns=data.columns if is_csv else None).to_csv(file_path, index=False)
            data_info_list.append({"file_name": file_name, "file_path": file_path})

        data_info_df = pd.DataFrame(data_info_list)
        data_info_df.to_csv(os.path.join(self.config.dataset_root, self.config.dataset_name, f"{split}_info.csv"), index=False, mode='a', header=not os.path.exists(os.path.join(self.config.dataset_root, self.config.dataset_name, f"{split}_info.csv")))



reshard = True

if reshard:

  data_sharder = DataSharder(config)
  data_sharder.load_data()

  #copy new dataset to drive
  # shutil.copytree(os.path.join("/content/datasets",config.dataset_name),os.path.join("/content/MyDrive/MyDrive/datasets",config.dataset_name))

if not os.path.exists("/content/datasets"):
    #copy over
    shutil.copytree(os.path.join("/content/MyDrive/MyDrive/datasets",config.dataset_name),os.path.join("/content/datasets",config.dataset_name))


Training subjects: ['subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13']
Testing subjects: ['subject_1']


Processing train patients:   0%|          | 0/11 [00:00<?, ?it/s]

Processing sessions for subject_2:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_veryfast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_max_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_90_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_45_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Processing sessions for subject_3:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreach45_combined:   0%|          | 0/126 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreach90_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_fast_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_slow_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T002_elbowrotation_slow_combined:   0%|          | 0/129 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowrotation_fast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_normal_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_veryfast_combined:   0%|          | 0/128 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_fast_combined:   0%|          | 0/129 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowrotation_normal_combined:   0%|          | 0/78 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T003_armswing_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_fast_combined:   0%|          | 0/128 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreachMax_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_slow_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Processing sessions for subject_4:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_veryfast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_45_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_slow_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_slow_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_fast_combined:   0%|          | 0/142 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_90_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_fast_combined:   0%|          | 0/140 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_normal_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_max_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_fast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Processing sessions for subject_6:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_slow_combined:   0%|          | 0/112 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_veryfast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_overheadreach_max_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_overheadreach_45_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_overheadreach_90_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Processing sessions for subject_7:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T002_armSwing_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_veryfast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_max_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_90_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_normal_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_45_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Processing sessions for subject_8:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_fast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_veryfast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_max_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_45_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_fast_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_90_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_fast_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_normal_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Processing sessions for subject_9:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_normal_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_elbowflexion_normal_combined:   0%|          | 0/127 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_fast_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_90_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_veryfast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_elbowflexion_slow_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_max_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_fast_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_45_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowflexion_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_normal_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_slow_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_fast_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Processing sessions for subject_10:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_45_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_90_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_max_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_normal_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_veryfast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Processing sessions for subject_11:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_90_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_max_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_45_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_veryfast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Processing sessions for subject_12:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s…

Windowing data for subject_12_P0012_T001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_slow_combined:   0%|          | 0/126 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_90_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_veryfast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_fast_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_45_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_slow_combined:   0%|          | 0/128 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_fast_combined:   0%|          | 0/129 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_slow_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_max_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_crossbody_fast_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_crossbody_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Processing sessions for subject_13:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_max_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_veryfast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_normal_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_90_combined:   0%|          | 0/119 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_slow_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_slow_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_slow_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_fast_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_fast_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_45_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_fast_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_normal_combined:   0%|          | 0/120 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_normal_combined:   0%|          | 0/121 [00:00<?, ?it/s]

Processing test patients:   0%|          | 0/1 [00:00<?, ?it/s]

Processing sessions for subject_1:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_fast_combined:   0%|          | 0/130 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_fast_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_fast_combined:   0%|          | 0/124 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_fast_combined:   0%|          | 0/126 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_veryfast_combined:   0%|          | 0/123 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_max_combined:   0%|          | 0/130 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_slow_combined:   0%|          | 0/128 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_90_combined:   0%|          | 0/130 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_slow_combined:   0%|          | 0/122 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_normal_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_slow_combined:   0%|          | 0/131 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_45_combined:   0%|          | 0/125 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_normal_combined:   0%|          | 0/127 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_normal_combined:   0%|          | 0/131 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_slow_combined:   0%|          | 0/129 [00:00<?, ?it/s]

In [4]:
from torch.utils.data import Dataset, DataLoader

class ImuJointPairDataset(Dataset):
    def __init__(self, config, split='train'):
        self.config = config
        self.split = split
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_acc_test = config.channels_imu_acc_test
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_imu_gyr_test = config.channels_imu_gyr_test
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        dataset_name = self.config.dataset_name
        self.root_dir_train = os.path.join(self.config.dataset_root, dataset_name, self.config.dataset_train_name)
        self.root_dir_test = os.path.join(self.config.dataset_root, dataset_name, self.config.dataset_test_name)

        train_info_path = os.path.join(self.config.dataset_root, dataset_name, "train_info.csv")
        test_info_path = os.path.join(self.config.dataset_root, dataset_name, "test_info.csv")
        self.data = pd.read_csv(train_info_path) if split == 'train' else pd.read_csv(test_info_path)



    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.split == "train":
            file_path = os.path.join(self.root_dir_train, self.data.iloc[idx, 0])
        else:
            file_path = os.path.join(self.root_dir_test, self.data.iloc[idx, 0])


        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        windows = self._apply_windowing(imu_data_acc, imu_data_gyr, joint_data, emg_data, self.config.window_length, self.config.window_overlap)

        acc_concat = np.concatenate([w[0] for w in windows], axis=0)
        gyr_concat = np.concatenate([w[1] for w in windows], axis=0)
        joint_concat = np.concatenate([w[2] for w in windows], axis=0)
        emg_concat = np.concatenate([w[3] for w in windows], axis=0)

        return acc_concat, gyr_concat, joint_concat, emg_concat

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        combined_data = np.concatenate([imu_data_acc, imu_data_gyr, joint_data, emg_data], axis=1)
        scaled_data = combined_data

        imu_data_acc = scaled_data[:, :imu_data_acc.shape[1]]
        imu_data_gyr = scaled_data[:, imu_data_acc.shape[1]:imu_data_acc.shape[1] + imu_data_gyr.shape[1]]
        joint_data = scaled_data[:, imu_data_acc.shape[1] + imu_data_gyr.shape[1]:imu_data_acc.shape[1] + imu_data_gyr.shape[1] + joint_data.shape[1]]
        emg_data = scaled_data[:, imu_data_acc.shape[1] + imu_data_gyr.shape[1] + joint_data.shape[1]:]

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        if isinstance(channels, slice):
            return combined_data.iloc[:, channels].values if self.input_format == "csv" else combined_data[:, channels]
        else:
            return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def _apply_windowing(self, imu_data_acc, imu_data_gyr, joint_data, emg_data, window_length, window_overlap):
        num_samples = imu_data_acc.shape[0]
        step = window_length - window_overlap
        windows = []

        for start in range(0, num_samples - window_length + 1, step):
            end = start + window_length
            window = (
                imu_data_acc[start:end],
                imu_data_gyr[start:end],
                joint_data[start:end],
                emg_data[start:end]
            )
            windows.append(window)

        return windows

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        data = torch.tensor(data, dtype=torch.float32)
        return data


# Create datasets
train_dataset = ImuJointPairDataset(config, split='train')
test_dataset = ImuJointPairDataset(config, split='test')

# Setup validation dataset
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [int(0.9 * len(train_dataset)), len(train_dataset) - int(0.9 * len(train_dataset))])

# Setup dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

acc,gyro,target,emg=next(iter(train_loader))
print(acc.shape)
print(gyro.shape)
print(target.shape)
print(emg.shape)

torch.Size([64, 100, 18])
torch.Size([64, 100, 18])
torch.Size([64, 100, 3])
torch.Size([64, 100, 3])


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np


class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss
#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7



############################################################################################################################################################################################################################################################################################################################################################################################################################################################################




In [17]:
import torch
import numpy as np
import os
import time
from tqdm.notebook import tqdm

import statistics

# Define the GHM-MSE Loss
class GHMMSELoss(torch.nn.Module):
    def __init__(self, bins=10, momentum=0.75):
        super(GHMMSELoss, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.edges = torch.linspace(0, 1, bins + 1)  # Initially on CPU
        self.acc_sum = torch.zeros(bins)  # Initially on CPU

    def forward(self, input, target):
        device = input.device  # Ensure we use the same device as input tensor
        self.edges = self.edges.to(device)  # Move edges to the same device
        self.acc_sum = self.acc_sum.to(device)  # Move acc_sum to the same device

        # Calculate the gradient norm (g)
        g = torch.sigmoid(2 * (input - target))
        g = g.detach()

        # Compute the gradient density
        n = input.size(0)
        inds = torch.bucketize(g, self.edges)
        total = torch.zeros_like(self.acc_sum)

        # Counting in each bin and calculate gradient harmonizing parameter (beta)
        for i in range(self.bins):
            total[i] = (inds == i + 1).float().sum()

        if self.training:
            self.acc_sum = self.momentum * self.acc_sum + (1 - self.momentum) * total
        acc_sum = self.acc_sum.clone().detach()

        # Beta calculation based on GD
        beta = torch.zeros_like(inds, dtype=torch.float32)
        for i in range(self.bins):
            beta[inds == i + 1] = n / (acc_sum[i] + 1e-6)

        loss = (beta * (input - target) ** 2).mean()
        return loss


# Define OHEM-MSE Loss
class OHEMMSELoss(torch.nn.Module):
    def __init__(self, ratio=0.7):
        super(OHEMMSELoss, self).__init__()
        self.ratio = ratio
        self.mse_loss = torch.nn.MSELoss(reduction='none')

    def forward(self, input, target):
        losses = self.mse_loss(input, target)
        num_hard_examples = int(self.ratio * losses.numel())
        hard_losses, _ = torch.topk(losses.view(-1), num_hard_examples)
        return hard_losses.mean()

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(loader):
            output, _ = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            loss = criterion(output, target.to(device).float())
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse

# Save checkpoint function
def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None, channelwise_metrics=None):
    """Saves the model, optimizer state, and losses (including channel-wise) to a checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }
    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")

# Training function
def train_teacher(device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename,loss_function):

    model.to(device)
    criterion = loss_function
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    train_losses = []
    val_losses = []
    test_losses = []

    train_pccs = []
    val_pccs = []
    test_pccs = []

    train_rmses = []
    val_rmses = []
    test_rmses = []

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()

        epoch_train_loss = np.zeros(len(config.channels_joints))
        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()
            output, _ = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            loss = criterion(output, target.to(device).float())
            loss.backward()
            optimizer.step()

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            epoch_train_loss += loss.detach().cpu().numpy()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_loss)
        train_pccs.append(avg_train_pcc)
        train_rmses.append(avg_train_rmse)

        # Evaluate on validation set every epoch
        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(device, model, val_loader, criterion)
        val_losses.append(avg_val_loss)
        val_pccs.append(avg_val_pcc)
        val_rmses.append(avg_val_rmse)

        # Evaluate on test set and checkpoint every epoch
        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(device, model, test_loader, criterion)
        test_losses.append(avg_test_loss)
        test_pccs.append(avg_test_pcc)
        test_rmses.append(avg_test_rmse)

        print(f"Epoch: {epoch + 1}, Training Loss: {np.mean(avg_train_loss)}, Validation Loss: {np.mean(avg_val_loss):.4f}", f"Test Loss: {np.mean(avg_test_loss):.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse)}, Validation RMSE: {np.mean(avg_val_rmse):.4f}", f"Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc)}, Validation PCC: {np.mean(avg_val_pcc):.4f}", f"Test PCC: {np.mean(avg_test_pcc):.4f}")

        if not os.path.exists(f"/content/MyDrive/MyDrive/models/{filename}"):
            os.makedirs(f"/content/MyDrive/MyDrive/models/{filename}")

        # Save checkpoint, including channel-wise metrics
        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"/content/MyDrive/MyDrive/models/{filename}/{filename}_epoch_{epoch+1}.pth",
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            test_loss=avg_test_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            }
        )

        # Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses


class CrossAttention(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CrossAttention, self).__init__()
        self.query = nn.Linear(output_dim, output_dim)  # Query: Joint angles or features from state-space model
        self.key = nn.Linear(input_dim, output_dim)     # Key: Sensor data (IMU/EMG)
        self.value = nn.Linear(input_dim, output_dim)   # Value: Sensor data (IMU/EMG)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, state_features, sensor_data):
        # state_features: Predicted joint angles (batch_size, output_dim)
        # sensor_data: IMU/EMG data (batch_size, num_channels, sequence_length)

        # Reshape sensor_data to 2D for Linear layer: [batch_size, num_channels * sequence_length]
        batch_size, num_channels, sequence_length = sensor_data.shape
        sensor_data_flat = sensor_data.reshape(batch_size, -1)  # Flatten last two dimensions

        query = self.query(state_features).unsqueeze(1)  # (batch_size, 1, output_dim)
        key = self.key(sensor_data_flat).unsqueeze(1)    # (batch_size, 1, output_dim)
        value = self.value(sensor_data_flat).unsqueeze(1)  # (batch_size, 1, output_dim)

        # Calculate attention scores (correlation between query and key)
        attention_scores = torch.bmm(query, key.transpose(1, 2)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float))
        attention_weights = self.softmax(attention_scores)  # (batch_size, 1, num_sensors)

        # Weighted sum of the values (sensor data)
        attention_output = torch.bmm(attention_weights, value)  # (batch_size, 1, output_dim)
        attention_output = attention_output.squeeze(1)  # Remove singleton dimension

        return attention_output, attention_weights




# Convolutional Model with Attention-Determined Filter Size
class ConvFeatureExtractorWithAttention(nn.Module):
    def __init__(self, input_size, num_filters, output_size, state_output_dim):
        super(ConvFeatureExtractorWithAttention, self).__init__()
        self.attention = CrossAttention(input_dim=input_size, output_dim=state_output_dim)
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=input_size, out_channels=num_filters, kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_channels=input_size, out_channels=num_filters, kernel_size=5, stride=1, padding=2),
            nn.Conv1d(in_channels=input_size, out_channels=num_filters, kernel_size=7, stride=1, padding=3),
        ])
        self.fc = nn.Linear(num_filters, output_size)

    def forward(self, state_features, sensor_data):
        # Get the attention output and attention weights
        attention_output, attention_weights = self.attention(state_features, sensor_data)

        # Use the attention weights to determine which convolutional filter to apply
        # Example: If attention_weight for sensor modality is higher, we use the larger filter size
        conv_output = 0
        for i, conv_layer in enumerate(self.conv_layers):
            filter_weight = attention_weights[:, :, i].unsqueeze(2)  # Weight for the current filter
            conv_out = F.relu(conv_layer(sensor_data))  # Apply convolution
            conv_output += filter_weight * conv_out  # Weighted sum based on attention

        # Pooling and fully connected layer to output the feature vector
        conv_output = torch.mean(conv_output, dim=-1)  # Pooling across the time dimension
        output = self.fc(conv_output)

        return output, attention_weights

# State-Space Model that uses convolutional features with attention
class StateSpaceModel(nn.Module):
    def __init__(self, feature_size, hidden_size, output_size, Ts, alpha=0.5):
        super(StateSpaceModel, self).__init__()
        self.fc1 = nn.Linear(feature_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.Ts = Ts
        self.alpha = alpha

    def state_transition(self, theta, theta_dot, theta_ddot, Ts):
        new_theta = theta + Ts * theta_dot + 0.5 * (Ts ** 2) * theta_ddot
        new_theta_dot = theta_dot + Ts * theta_ddot
        new_theta_ddot = self.alpha * theta_ddot
        return new_theta, new_theta_dot, new_theta_ddot

    def forward(self, features):
        h = torch.relu(self.fc1(features))
        out = self.fc2(h)
        return out

# Cross-Attention Convolutional Model with Dynamic Input Sizes
class CrossAttentionConvStateSpaceModel(nn.Module):
    def __init__(self, imu_acc_input_size, imu_gyr_input_size, emg_input_size, num_filters, output_size, state_output_dim, hidden_size, Ts):
        super(CrossAttentionConvStateSpaceModel, self).__init__()

        # Calculate total input size based on all sensor channels
        total_input_size = imu_acc_input_size + imu_gyr_input_size + emg_input_size

        # Update the convolutional model to reflect the correct input sizes
        self.conv_feature_extractor = ConvFeatureExtractorWithAttention(
            input_size=total_input_size,
            num_filters=num_filters,
            output_size=state_output_dim,
            state_output_dim=state_output_dim
        )

        # The state-space model's input size is now dynamically computed
        self.state_space_model = StateSpaceModel(
            feature_size=state_output_dim * total_input_size,  # Adjust for feature size from the convolution
            hidden_size=hidden_size,
            output_size=output_size,
            Ts=Ts
        )

    def forward(self, imu_acc_data, imu_gyr_data, emg_data):
        # Concatenate IMU accelerometer, gyroscope, and EMG data along the feature dimension
        sensor_data = torch.cat([imu_acc_data, imu_gyr_data, emg_data], dim=2)  # Concatenate along the feature dimension (dim=2)

        # Transpose sensor_data to match the expected input for Conv1d: [batch_size, num_channels, sequence_length]
        sensor_data = sensor_data.transpose(1, 2)  # Change from [batch_size, sequence_length, num_channels] to [batch_size, num_channels, sequence_length]

        # Initialize state features (joint angles) with the correct dimension size for the attention mechanism
        state_features = torch.zeros(imu_acc_data.shape[0], 64, device=imu_acc_data.device)  # Batch size x 64 (matching the convolution output)

        # Extract features with convolutional filtering and attention
        features, _ = self.conv_feature_extractor(state_features, sensor_data)

        # Print the shape of features before flattening
        print(f"Shape of features before flattening: {features.shape}")

        # Flatten or reshape the features for the fully connected layers
        features = features.view(features.size(0), -1)  # Flatten to [batch_size, num_features]

        # Print the shape of features after flattening
        print(f"Shape of features after flattening: {features.shape}")

        # Pass features into the state-space model for joint angle prediction
        predicted_angles = self.state_space_model(features)

        return predicted_angles, features






# Example to dynamically create the model with the number of channels
imu_acc_channels = 18  # Acceleration channels
imu_gyr_channels = 18  # Gyroscope channels
emg_channels = 3       # EMG channels

# Update the model_configs to use this new model with dynamic input sizes
model_configs = {
    'CrossAttentionConvStateSpaceModel_RMSELoss': {
        'model': CrossAttentionConvStateSpaceModel(
            imu_acc_input_size=imu_acc_channels,
            imu_gyr_input_size=imu_gyr_channels,
            emg_input_size=emg_channels,
            num_filters=64,
            output_size=3,  # Joint angles
            state_output_dim=64,
            hidden_size=128,
            Ts=1/128  # Sampling period
        ),
        'loss': RMSELoss()  # Use RMSELoss for this model
    },
    'CrossAttentionConvStateSpaceModel_GHMMSELoss': {
        'model': CrossAttentionConvStateSpaceModel(
            imu_acc_input_size=imu_acc_channels,
            imu_gyr_input_size=imu_gyr_channels,
            emg_input_size=emg_channels,
            num_filters=64,
            output_size=3,  # Joint angles
            state_output_dim=64,
            hidden_size=128,
            Ts=1/128  # Sampling period
        ),
        'loss': GHMMSELoss()  # Use GHM-MSE Loss for this model
    },
    'CrossAttentionConvStateSpaceModel_OHMELoss': {
        'model': CrossAttentionConvStateSpaceModel(
            imu_acc_input_size=imu_acc_channels,
            imu_gyr_input_size=imu_gyr_channels,
            emg_input_size=emg_channels,
            num_filters=64,
            output_size=3,  # Joint angles
            state_output_dim=64,
            hidden_size=128,
            Ts=1/128  # Sampling period
        ),
        'loss': OHEMMSELoss()  # Use OHEM-MSE Loss for this model
    }
}
# Train the model using your existing training loop
config.epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loop over each model and run the training with the corresponding loss function
for model_name, model_config in model_configs.items():
    model = model_config['model']
    loss_function = model_config['loss']  # Get the corresponding loss function

    print(f"Starting training for {model_name} using {loss_function.__class__.__name__}...")

    # Train the model using the common loaders, learning rate, and epochs from config
    model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses = train_teacher(
        device,
        train_loader,
        val_loader,
        test_loader,
        config.lr,  # Learning rate from config
        config.epochs,  # Epochs from config
        model,
        model_name,  # Save checkpoint file named after the model
        loss_function  # Pass the specific loss function
    )

    print(f"Finished training for {model_name}.")


Starting training for CrossAttentionConvStateSpaceModel_RMSELoss using RMSELoss...


Epoch 1/100 Training:   0%|          | 0/299 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x3900 and 39x64)