In [1]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [2]:
# @title Initialize Config

import torch
import numpy
class Config:
    def __init__(self, **kwargs):
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

#set seeds
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [3]:
class DataSharder:
    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]
                session_keys = list(subject_data.keys())  # Sessions for this subject

                for session_id in session_keys:
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id,sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        # Path to the CSV log file
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # Prepare CSV log headers (ensure the columns are 'file_name' and 'file_path')
        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Determine the total data length based on the minimum length across the data sources
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Adjust the starting point for windows based on total data length
            start = 2000 if total_data_length > 4000 else 0

            # Ensure that each window across imu_data, emg_data, and joint_data has the same shape before concatenation
            for i in range(start, total_data_length - window_size + 1, step_size):
                imu_window = imu_data[:, i:i + window_size]
                emg_window = emg_data[:, i:i + window_size]
                joint_window = joint_data[:, i:i + window_size]

                # Check if the window sizes are valid
                if imu_window.shape[1] == window_size and emg_window.shape[1] == window_size and joint_window.shape[1] == window_size:
                    # Convert windowed data to pandas DataFrame



                    imu_df = pd.DataFrame(imu_window.T, columns=imu_columns)
                    emg_df = pd.DataFrame(emg_window.T, columns=emg_columns)
                    joint_df = pd.DataFrame(joint_window.T, columns=joint_columns)



                    # Concatenate the data along the column axis
                    combined_df = pd.concat([imu_df, emg_df, joint_df], axis=1)

                    # Save the combined windowed data as a CSV file
                    file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                    file_path = os.path.join(dataset_folder, file_name)
                    combined_df.to_csv(file_path, index=False)

                    # Log the file name and path in the CSV (in the correct columns)
                    writer.writerow([file_name, file_path])
                else:
                    print(f"Skipping window {i} due to mismatched window sizes.")

    def _extract_channel_data(self, session_data, channels):
      extracted_data = []
      new_column_names = []  # Initialize here

      if isinstance(session_data, h5py.Dataset):
          if session_data.dtype.names:
              # Compound dataset
              column_names = session_data.dtype.names
              for channel in channels:
                  if channel in column_names:
                      channel_data = session_data[channel][:]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)  # Populate here
                  else:
                      print(f"Channel {channel} not found in compound dataset.")
          else:
              # Simple dataset
              column_names = list(session_data.attrs.get('column_names', []))
              assert len(column_names) > 0, "column_names not found in dataset attributes"
              for channel in channels:
                  if channel in column_names:
                      col_idx = column_names.index(channel)
                      channel_data = session_data[:, col_idx]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)
                  else:
                      print(f"Channel {channel} not found in session data.")

      return np.array(extracted_data), new_column_names


In [4]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Convert the list of subjects to a string that is path-safe
        subjects_str = "_".join(map(str, subjects)).replace('subject', '').replace('__', '_')

        # Use dataset_train_name or dataset_test_name based on split
        if split == 'train':
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
        else:
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"

        self.dataset_name = dataset_name

        # Define the root directory based on dataset name
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Ensure sharded data exists, if not, reshard
        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        info_path = os.path.join(self.root_dir, f"{split}_info.csv")
        self.data = pd.read_csv(info_path)

    def ensure_resharded(self, subjects, dataset_name):
        if not os.path.exists(self.root_dir):
            print(f"Sharded data not found at {self.root_dir}. Resharding...")
            data_sharder = DataSharder(self.config,self.split)
            # Pass dynamic parameters to sharder
            data_sharder.load_data(subjects, window_length=self.window_length, window_overlap=self.window_overlap, dataset_name=self.dataset_name)
        else:
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir,self.split, self.data.iloc[idx, 0])

        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        return torch.tensor(data, dtype=torch.float32)

class ImuJointPairSubjectDataset(ImuJointPairDataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Create a mapping from subject strings (e.g., 'subject_1') to class indices
        self.subject_mapping = {subject: i for i, subject in enumerate(sorted(subjects))}

    def __getitem__(self, idx):
        # Retrieve the original data from the parent class
        imu_data_acc, imu_data_gyr, joint_data, emg_data = super().__getitem__(idx)

        # Get the filename from the data index
        filename = self.data.iloc[idx, 0]

        # Extract the subject ID from the filename
        filename_base = os.path.basename(filename)
        filename_without_ext = os.path.splitext(filename_base)[0]
        parts = filename_without_ext.split('_')

        # Construct subject string in the format 'subject_x'
        try:
            subject_index = parts.index('subject')
            subject_str = f"subject_{parts[subject_index + 1]}"  # Create string like 'subject_1'
        except ValueError:
            raise ValueError(f"'subject' not found in filename: {filename}")
        except IndexError:
            raise ValueError(f"Subject ID not found after 'subject' in filename: {filename}")

        # Map subject_str to class index
        if subject_str not in self.subject_mapping:
            raise ValueError(f"Subject ID {subject_str} not found in training set.")

        mapped_class = self.subject_mapping[subject_str]

        # Return class index instead of one-hot encoding
        return imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class


def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairSubjectDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairSubjectDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=0,  # Typically no overlap in test set
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


In [5]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_1, self).__init__()
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.lstm_2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, (h_1, _) = self.lstm_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, (h_2, _) = self.lstm_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)

class Encoder_2(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_2, self).__init__()
        self.gru_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.gru_2 = nn.GRU(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, h_1 = self.gru_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, h_2 = self.gru_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)


class GatingModule(nn.Module):
    def __init__(self, input_size):
        super(GatingModule, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(2*input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # Apply gating mechanism
        gate_output = self.gate(torch.cat((input1,input2),dim=-1))

        # Scale the inputs based on the gate output
        gated_input1 = input1 * gate_output
        gated_input2 = input2 * (1 - gate_output)

        # Combine the gated inputs
        output = gated_input1 + gated_input2
        return output
#variable w needs to be checked for correct value, stand-in value used
class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100):
        super(teacher, self).__init__()

        self.w = w
        self.encoder_1_acc = Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr = Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg = Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc = Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr = Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg = Encoder_2(input_emg, drop_prob)

        self.BN_acc = nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr = nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg = nn.BatchNorm1d(input_emg, affine=False)

        self.fc = nn.Linear(2 * 3 * 128 + 128, 3)
        self.dropout = nn.Dropout(p=0.05)

        self.gate_1 = GatingModule(128)
        self.gate_2 = GatingModule(128)
        self.gate_3 = GatingModule(128)

        self.fc_kd = nn.Linear(3 * 128, 2 * 128)

        # Gating and attention networks
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.attention = nn.MultiheadAttention(3 * 128, 4, batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128 * 3, 3 * 128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2 * 3 * 128 + 128, 2 * 3 * 128 + 128), nn.Sigmoid())

        # Pooling for embeddings
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.embedding_fc = nn.Linear(2 * 3 * 128 + 128, 128)  # Embedding layer for prototypical learning

    def forward(self, x_acc, x_gyr, x_emg):
        # Preprocess inputs with batch normalization
        x_acc_1 = x_acc.view(x_acc.size(0) * x_acc.size(1), x_acc.size(-1))
        x_gyr_1 = x_gyr.view(x_gyr.size(0) * x_gyr.size(1), x_gyr.size(-1))
        x_emg_1 = x_emg.view(x_emg.size(0) * x_emg.size(1), x_emg.size(-1))

        x_acc_1 = self.BN_acc(x_acc_1)
        x_gyr_1 = self.BN_gyr(x_gyr_1)
        x_emg_1 = self.BN_emg(x_emg_1)

        x_acc_2 = x_acc_1.view(-1, self.w, x_acc_1.size(-1))
        x_gyr_2 = x_gyr_1.view(-1, self.w, x_gyr_1.size(-1))
        x_emg_2 = x_emg_1.view(-1, self.w, x_emg_1.size(-1))

        # Encoder passes
        x_acc_1, (h_acc_1, _) = self.encoder_1_acc(x_acc_2)
        x_gyr_1, (h_gyr_1, _) = self.encoder_1_gyr(x_gyr_2)
        x_emg_1, (h_emg_1, _) = self.encoder_1_emg(x_emg_2)

        x_acc_2, (h_acc_2, _) = self.encoder_2_acc(x_acc_2)
        x_gyr_2, (h_gyr_2, _) = self.encoder_2_gyr(x_gyr_2)
        x_emg_2, (h_emg_2, _) = self.encoder_2_emg(x_emg_2)

        # Gating layers
        x_acc = self.gate_1(x_acc_1, x_acc_2)
        x_gyr = self.gate_2(x_gyr_1, x_gyr_2)
        x_emg = self.gate_3(x_emg_1, x_emg_2)

        x = torch.cat((x_acc, x_gyr, x_emg), dim=-1)
        x_kd = self.fc_kd(x)

        # Attention and feature gating
        out_1, _ = self.attention(x, x, x)
        gating_weights = self.gating_net(x)
        out_2 = gating_weights * x

        # Weighted features
        weights_1 = self.weighted_feat(x[:, :, 0:128])
        weights_2 = self.weighted_feat(x[:, :, 128:2 * 128])
        weights_3 = self.weighted_feat(x[:, :, 2 * 128:3 * 128])
        x_1 = weights_1 * x[:, :, 0:128]
        x_2 = weights_2 * x[:, :, 128:2 * 128]
        x_3 = weights_3 * x[:, :, 2 * 128:3 * 128]
        out_3 = x_1 + x_2 + x_3

        # Gating and final output
        out = torch.cat((out_1, out_2, out_3), dim=-1)
        gating_weights_1 = self.gating_net_1(out)
        out = gating_weights_1 * out
        output = self.fc(out)

        # Prototypical embedding: Average across time dimension for fixed-length embeddings
        embeddings = self.embedding_fc(out.mean(dim=1))

        return output, embeddings, x_kd





In [6]:
# @title Loss Functions
import statistics

class PrototypicalLoss(nn.Module):
    def forward(self, embeddings, labels):
        # Get unique classes and calculate class prototypes
        unique_labels = torch.unique(labels)
        prototypes = torch.stack([embeddings[labels == label].mean(0) for label in unique_labels])

        # Calculate distance of each sample embedding to each prototype
        distances = torch.cdist(embeddings, prototypes)  # Shape: (batch_size, num_classes)

        # Ensure labels are correctly offset
        label_map = {label.item(): i for i, label in enumerate(unique_labels)}
        mapped_labels = torch.tensor([label_map[label.item()] for label in labels], device=distances.device)

        # Calculate prototypical loss
        loss = F.log_softmax(-distances, dim=1).gather(1, mapped_labels.view(-1, 1)).mean()
        return -loss  # Maximize log probability of the correct prototype


class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss

#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7

In [7]:





# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG,_) in enumerate(loader):
            output= model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())

            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())
            else:
                loss = criterion(output, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse




def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(
    device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename, loss_function,
    proto_loss_function, optimizer=None, l1_lambda=None, train_from_last_epoch=False, curriculum_loader=None
):
    model.to(device)
    criterion = loss_function
    proto_criterion = proto_loss_function  # Prototypical loss function

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)


    # Metrics tracking
    train_losses, val_losses, test_losses = [], [], []
    train_proto_losses, val_proto_losses, test_proto_losses = [], [], []
    train_pccs, val_pccs, test_pccs = [], [], []
    train_rmses, val_rmses, test_rmses = [], [], []

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    # Load checkpoint if resuming from last epoch
    if train_from_last_epoch and os.path.exists(checkpoint_path):
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(os.path.join(checkpoint_path, latest_checkpoint))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            last_epoch = checkpoint['epoch']
            # Load history
            train_losses = checkpoint['history']['train_losses']
            val_losses = checkpoint['history']['val_losses']
            test_losses = checkpoint['history']['test_losses']
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10
    start_time = time.time()
    for epoch in range(last_epoch, epochs):
        model.train()
        epoch_train_loss = 0.0
        epoch_train_proto_loss = 0.0

        for i, (data_acc, data_gyr, target, data_EMG, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            output = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            if isinstance(model, teacher):
                output, embeddings, _ = output  # Get embeddings for prototypical loss
                primary_loss = criterion(output, target.to(device).float())
                proto_loss = proto_criterion(embeddings, labels.to(device))  # Calculate prototypical loss
            else:
                primary_loss = criterion(output, target.to(device).float())
                proto_loss = 0  # No prototypical loss for non-teacher models

            # Total loss combines primary and prototypical loss
            total_loss = primary_loss + proto_loss
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss += l1_lambda * l1_norm

            # Backpropagate and optimize
            total_loss.backward()
            optimizer.step()

            epoch_train_loss += primary_loss.item()
            epoch_train_proto_loss += proto_loss.item()

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_proto_loss = epoch_train_proto_loss / len(train_loader)

        train_losses.append(avg_train_loss)
        train_proto_losses.append(avg_train_proto_loss)

        # Validation
        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(
            device, model, val_loader, criterion
        )
        val_losses.append(avg_val_loss)
        val_pccs.append(avg_val_pcc)
        val_rmses.append(avg_val_rmse)

        # Test evaluation
        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(
            device, model, test_loader, criterion
        )
        test_losses.append(avg_test_loss)
        test_pccs.append(avg_test_pcc)
        test_rmses.append(avg_test_rmse)

        print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Test Loss: {avg_test_loss:.4f},Train Proto Loss: {avg_train_proto_loss:.4f}")


        # Save checkpoint
        # save_checkpoint(model, optimizer, epoch, f"{checkpoint_path}/{filename}_epoch_{epoch + 1}.pth", train_loss=avg_train_loss, val_loss=avg_val_loss, test_loss=avg_test_loss)

        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    print(f"Total training time: {time.time() - start_time:.2f} seconds")
    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses







In [8]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100):
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w)

    if base_weights_path:
        # Load the initial weights from the base model
        model.load_state_dict(torch.load(base_weights_path))

    return model




In [9]:
import matplotlib.pyplot as plt
import numpy as np
import os
import h5py
from tqdm.notebook import tqdm
import pandas as pd
import csv

all_subjects= [f"subject_{x}" for x in range(1,14)]
input_acc, input_gyr, input_emg = 18,18,3
batch_size = 64

# Placeholder for storing best RMSEs
best_rmse_per_subject = []
best_pcc_per_subject = []

train_flag = True
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

for test_subject in all_subjects:



    print(f"Running training with {test_subject} as the test subject.")

    # Set up the training subjects (all except the test subject)
    train_subjects = [subject for subject in all_subjects if subject != test_subject]

    model_name = f'TeacherModel_RMSELoss_test_{test_subject}_wl{100}_ol{75}_prototypical'
    print(f"Model: {model_name}")

    # Load the model configuration and data loaders
    model_config = {
        'model': create_teacher_model(input_acc, input_gyr, input_emg, w=100),
        'loss': RMSELoss(),
        'loaders': create_base_data_loaders(
            config=config,
            train_subjects=train_subjects,
            test_subjects=[test_subject],
            window_length=100,
            window_overlap=75,
            batch_size=batch_size
        ),
        'epochs': 10,
        'use_curriculum': False
    }

    model = model_config['model']
    loss_function = model_config['loss']
    epochs = model_config.get("epochs", 10)
    device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    learn_rate = model_config.get("learn_rate", 0.001)
    use_curriculum = model_config.get("use_curriculum", False)

    optimizer = model_config.get("optimizer", None)
    l1_lambda = model_config.get("l1_lambda", None)

    print(f"Running model: {model_name}")

    # Unpack the static loaders tuple (train_loader, val_loader, test_loader)
    train_loader, val_loader, test_loader = model_config['loaders']
    if train_flag:
    # Train the model and save only the best based on validation loss
      model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses = train_teacher(
          device=device,
          train_loader=train_loader,
          val_loader=val_loader,
          test_loader=test_loader,
          learn_rate=learn_rate,
          epochs=epochs,
          model=model,
          filename=model_name,
          loss_function=loss_function,
          optimizer=optimizer,
          l1_lambda=l1_lambda,
          train_from_last_epoch=False,
          proto_loss_function=PrototypicalLoss()
      )
    else:
      #load filename as model
      model.load_state_dict(torch.load(f"{model_name}"))
      model.to(device)
      model.eval()

     #run model on test set and record result
    test_loss, test_pcc, test_rmse = evaluate_model(device, model, test_loader, loss_function)
    print(f"Test Loss: {test_loss:.4f}, Test PCC: {np.mean(test_pcc):.4f}, Test RMSE: {np.mean(test_rmse):.4f}")
    best_rmse_per_subject.append(np.mean(test_rmse))
    best_pcc_per_subject.append(np.mean(test_pcc))


# Compute the average of the best RMSEs across all subjects



Running training with subject_1 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_prototypical
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Running model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.4491, Val Loss: 9.8600, Test Loss: 19.5966,Train Proto Loss: 0.4285


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 2, Train Loss: 9.9532, Val Loss: 7.8405, Test Loss: 18.1182,Train Proto Loss: 0.0636


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.8040, Val Loss: 7.6523, Test Loss: 17.6844,Train Proto Loss: 0.0333


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.2630, Val Loss: 7.0239, Test Loss: 18.1038,Train Proto Loss: 0.0308


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.7596, Val Loss: 6.9004, Test Loss: 20.5294,Train Proto Loss: 0.0206


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.3641, Val Loss: 6.8190, Test Loss: 20.5541,Train Proto Loss: 0.0375


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.1456, Val Loss: 6.3042, Test Loss: 17.9730,Train Proto Loss: 0.0360


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.8427, Val Loss: 5.8364, Test Loss: 19.7625,Train Proto Loss: 0.0212


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.5159, Val Loss: 5.9542, Test Loss: 20.1108,Train Proto Loss: 0.0187


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.3786, Val Loss: 5.5514, Test Loss: 18.2890,Train Proto Loss: 0.0196
Total training time: 1414.83 seconds
Test Loss: 18.2890, Test PCC: 0.7242, Test RMSE: 17.7105
Running training with subject_2 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_prototypical
Sharded data found at /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_2. Skipping resharding.
Running model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 17.0456, Val Loss: 9.1254, Test Loss: 21.6439,Train Proto Loss: 0.4102


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 8.6343, Val Loss: 7.3518, Test Loss: 22.0475,Train Proto Loss: 0.0654


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 7.4291, Val Loss: 7.0562, Test Loss: 23.7571,Train Proto Loss: 0.0270


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 6.9250, Val Loss: 6.3177, Test Loss: 22.5937,Train Proto Loss: 0.0246


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 6.5677, Val Loss: 6.1387, Test Loss: 23.2281,Train Proto Loss: 0.0257


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 6.2283, Val Loss: 6.0412, Test Loss: 22.0145,Train Proto Loss: 0.0256


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 6.1436, Val Loss: 5.8327, Test Loss: 22.6922,Train Proto Loss: 0.0356


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 5.9034, Val Loss: 5.9200, Test Loss: 22.4028,Train Proto Loss: 0.0206


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 5.6304, Val Loss: 5.3947, Test Loss: 22.1517,Train Proto Loss: 0.0177


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 5.5854, Val Loss: 5.3967, Test Loss: 22.8244,Train Proto Loss: 0.0126
Total training time: 1435.46 seconds
Test Loss: 22.8244, Test PCC: 0.6708, Test RMSE: 21.4859
Running training with subject_3 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_3. Resharding...
Processing subjects: ['subject_3'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_3/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_3/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 17.6638, Val Loss: 9.7019, Test Loss: 20.9844,Train Proto Loss: 0.4330


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 9.5053, Val Loss: 7.5444, Test Loss: 20.3940,Train Proto Loss: 0.0579


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.3292, Val Loss: 6.9606, Test Loss: 20.3958,Train Proto Loss: 0.0405


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 7.7289, Val Loss: 6.5268, Test Loss: 19.4395,Train Proto Loss: 0.0269


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.2760, Val Loss: 6.2243, Test Loss: 20.7044,Train Proto Loss: 0.0242


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 6.9327, Val Loss: 6.2929, Test Loss: 20.2123,Train Proto Loss: 0.0342


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 6.6723, Val Loss: 5.7846, Test Loss: 20.0904,Train Proto Loss: 0.0126


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.4253, Val Loss: 5.7759, Test Loss: 19.6713,Train Proto Loss: 0.0128


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.2190, Val Loss: 5.6089, Test Loss: 19.8460,Train Proto Loss: 0.0200


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 5.9492, Val Loss: 5.4351, Test Loss: 20.7206,Train Proto Loss: 0.0146
Total training time: 1435.01 seconds
Test Loss: 20.7206, Test PCC: 0.7525, Test RMSE: 18.8178
Running training with subject_4 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_4. Resharding...
Processing subjects: ['subject_4'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_4/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_4/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.0191, Val Loss: 9.9053, Test Loss: 18.0656,Train Proto Loss: 0.5844


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 2, Train Loss: 9.8724, Val Loss: 8.4525, Test Loss: 18.6033,Train Proto Loss: 0.0827


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.6777, Val Loss: 7.3942, Test Loss: 16.8647,Train Proto Loss: 0.0389


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.0282, Val Loss: 7.3968, Test Loss: 18.6200,Train Proto Loss: 0.0317


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.7593, Val Loss: 6.9623, Test Loss: 17.1005,Train Proto Loss: 0.0185


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.3543, Val Loss: 7.0741, Test Loss: 16.8866,Train Proto Loss: 0.0147


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.0612, Val Loss: 6.6938, Test Loss: 17.4665,Train Proto Loss: 0.0269


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.8373, Val Loss: 6.1105, Test Loss: 17.1544,Train Proto Loss: 0.0252


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.4672, Val Loss: 5.9108, Test Loss: 17.4269,Train Proto Loss: 0.0172


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.2421, Val Loss: 5.9880, Test Loss: 16.8543,Train Proto Loss: 0.0228
Total training time: 1444.25 seconds
Test Loss: 16.8543, Test PCC: 0.7580, Test RMSE: 15.9918
Running training with subject_5 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_5. Resharding...
Processing subjects: ['subject_5'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_5/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_5/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.0294, Val Loss: 11.0655, Test Loss: 25.2566,Train Proto Loss: 0.5640


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 2, Train Loss: 9.9685, Val Loss: 9.6530, Test Loss: 25.3385,Train Proto Loss: 0.0843


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.9610, Val Loss: 8.4052, Test Loss: 25.1039,Train Proto Loss: 0.0422


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.2333, Val Loss: 8.0603, Test Loss: 22.7254,Train Proto Loss: 0.0365


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.8155, Val Loss: 7.4632, Test Loss: 21.9208,Train Proto Loss: 0.0162


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.4863, Val Loss: 7.1832, Test Loss: 21.4899,Train Proto Loss: 0.0216


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.1909, Val Loss: 7.4781, Test Loss: 21.5923,Train Proto Loss: 0.0332


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.9964, Val Loss: 7.1291, Test Loss: 21.1965,Train Proto Loss: 0.0213


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.6584, Val Loss: 6.6742, Test Loss: 21.6286,Train Proto Loss: 0.0136


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.3673, Val Loss: 6.5290, Test Loss: 23.9580,Train Proto Loss: 0.0306
Total training time: 1424.42 seconds
Test Loss: 23.9580, Test PCC: 0.6650, Test RMSE: 22.7182
Running training with subject_6 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_6. Resharding...
Processing subjects: ['subject_6'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_6/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_6/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.3311, Val Loss: 10.5960, Test Loss: 13.4515,Train Proto Loss: 0.4945


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.0028, Val Loss: 8.6804, Test Loss: 13.9916,Train Proto Loss: 0.0895


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.8360, Val Loss: 7.9388, Test Loss: 13.4219,Train Proto Loss: 0.0313


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.1485, Val Loss: 7.5402, Test Loss: 13.4832,Train Proto Loss: 0.0262


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.7914, Val Loss: 7.1604, Test Loss: 13.8118,Train Proto Loss: 0.0205


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.4577, Val Loss: 6.6914, Test Loss: 13.1820,Train Proto Loss: 0.0192


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.0668, Val Loss: 6.7526, Test Loss: 13.4898,Train Proto Loss: 0.0170


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.8231, Val Loss: 6.5779, Test Loss: 13.7739,Train Proto Loss: 0.0351


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.5130, Val Loss: 6.5489, Test Loss: 13.6136,Train Proto Loss: 0.0097


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.3740, Val Loss: 6.8380, Test Loss: 12.6295,Train Proto Loss: 0.0194
Total training time: 1411.83 seconds
Test Loss: 12.6295, Test PCC: 0.8353, Test RMSE: 11.8802
Running training with subject_7 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_7. Resharding...
Processing subjects: ['subject_7'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_7/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_7/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.0152, Val Loss: 10.1858, Test Loss: 17.8815,Train Proto Loss: 0.5443


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.0359, Val Loss: 8.6696, Test Loss: 17.4456,Train Proto Loss: 0.0554


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.9432, Val Loss: 7.7090, Test Loss: 16.2849,Train Proto Loss: 0.0333


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.3091, Val Loss: 7.4321, Test Loss: 16.1678,Train Proto Loss: 0.0291


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.9071, Val Loss: 7.3094, Test Loss: 16.9253,Train Proto Loss: 0.0272


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.5376, Val Loss: 6.6806, Test Loss: 15.3893,Train Proto Loss: 0.0230


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.2512, Val Loss: 6.5790, Test Loss: 13.9346,Train Proto Loss: 0.0236


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.0316, Val Loss: 6.0122, Test Loss: 15.1645,Train Proto Loss: 0.0156


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.7900, Val Loss: 5.9397, Test Loss: 15.0568,Train Proto Loss: 0.0356


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.4807, Val Loss: 6.1229, Test Loss: 13.8262,Train Proto Loss: 0.0207
Total training time: 1452.24 seconds
Test Loss: 13.8262, Test PCC: 0.7963, Test RMSE: 13.1477
Running training with subject_8 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_8. Resharding...
Processing subjects: ['subject_8'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_8/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_8/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.4904, Val Loss: 11.2696, Test Loss: 13.4264,Train Proto Loss: 0.4701


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.2330, Val Loss: 9.5780, Test Loss: 12.6266,Train Proto Loss: 0.0847


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.9677, Val Loss: 8.4879, Test Loss: 11.4731,Train Proto Loss: 0.0337


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.3247, Val Loss: 7.7371, Test Loss: 12.1599,Train Proto Loss: 0.0283


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.9528, Val Loss: 7.2451, Test Loss: 11.1896,Train Proto Loss: 0.0348


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.6267, Val Loss: 7.4450, Test Loss: 11.9827,Train Proto Loss: 0.0397


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.3844, Val Loss: 7.0566, Test Loss: 11.7193,Train Proto Loss: 0.0199


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.0882, Val Loss: 6.6943, Test Loss: 11.2057,Train Proto Loss: 0.0258


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.7803, Val Loss: 6.5152, Test Loss: 11.8107,Train Proto Loss: 0.0110


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.5025, Val Loss: 6.5133, Test Loss: 12.1379,Train Proto Loss: 0.0104
Total training time: 1473.97 seconds
Test Loss: 12.1379, Test PCC: 0.8075, Test RMSE: 11.5827
Running training with subject_9 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_9. Resharding...
Processing subjects: ['subject_9'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_9/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_9/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.6991, Val Loss: 10.6395, Test Loss: 11.1988,Train Proto Loss: 0.4199


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.1573, Val Loss: 9.0944, Test Loss: 10.4217,Train Proto Loss: 0.0893


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 9.1166, Val Loss: 8.3054, Test Loss: 9.7410,Train Proto Loss: 0.0392


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.4455, Val Loss: 7.8926, Test Loss: 10.4409,Train Proto Loss: 0.0317


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.9614, Val Loss: 7.6526, Test Loss: 10.0446,Train Proto Loss: 0.0275


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.5685, Val Loss: 6.6719, Test Loss: 10.5007,Train Proto Loss: 0.0205


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.2693, Val Loss: 6.8542, Test Loss: 11.1375,Train Proto Loss: 0.0250


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.1724, Val Loss: 6.8097, Test Loss: 9.0664,Train Proto Loss: 0.0390


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.8260, Val Loss: 6.2612, Test Loss: 9.5511,Train Proto Loss: 0.0185


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.6085, Val Loss: 6.2044, Test Loss: 9.7444,Train Proto Loss: 0.0099
Total training time: 1467.14 seconds
Test Loss: 9.7444, Test PCC: 0.8218, Test RMSE: 9.6158
Running training with subject_10 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_10. Resharding...
Processing subjects: ['subject_10'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_10/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_10/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.5206, Val Loss: 10.4395, Test Loss: 18.5743,Train Proto Loss: 0.4402


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 9.9688, Val Loss: 8.4204, Test Loss: 19.3469,Train Proto Loss: 0.0680


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.9937, Val Loss: 8.3944, Test Loss: 17.3859,Train Proto Loss: 0.0305


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.3205, Val Loss: 7.5774, Test Loss: 19.4446,Train Proto Loss: 0.0344


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.9150, Val Loss: 7.3904, Test Loss: 18.5799,Train Proto Loss: 0.0190


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.6844, Val Loss: 6.8542, Test Loss: 16.3524,Train Proto Loss: 0.0279


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.3206, Val Loss: 6.9077, Test Loss: 16.8928,Train Proto Loss: 0.0322


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.0769, Val Loss: 6.5439, Test Loss: 19.6480,Train Proto Loss: 0.0250


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.8454, Val Loss: 5.9976, Test Loss: 19.1440,Train Proto Loss: 0.0232


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.5871, Val Loss: 6.0797, Test Loss: 18.7619,Train Proto Loss: 0.0289
Total training time: 1457.62 seconds
Test Loss: 18.7619, Test PCC: 0.7552, Test RMSE: 17.2733
Running training with subject_11 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_11. Resharding...
Processing subjects: ['subject_11'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_11/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_11/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.2905, Val Loss: 10.6179, Test Loss: 17.8112,Train Proto Loss: 0.5123


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.1684, Val Loss: 9.4010, Test Loss: 16.3658,Train Proto Loss: 0.0655


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.9506, Val Loss: 8.1439, Test Loss: 14.4575,Train Proto Loss: 0.0315


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.3745, Val Loss: 7.8644, Test Loss: 15.5811,Train Proto Loss: 0.0224


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 8.0022, Val Loss: 7.1983, Test Loss: 15.6933,Train Proto Loss: 0.0305


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.6101, Val Loss: 7.0608, Test Loss: 14.5918,Train Proto Loss: 0.0289


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.3709, Val Loss: 6.6085, Test Loss: 15.4597,Train Proto Loss: 0.0424


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.0732, Val Loss: 6.4103, Test Loss: 15.4915,Train Proto Loss: 0.0169


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.7806, Val Loss: 6.1860, Test Loss: 14.9552,Train Proto Loss: 0.0188


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.5306, Val Loss: 5.9042, Test Loss: 15.3494,Train Proto Loss: 0.0119
Total training time: 1449.74 seconds
Test Loss: 15.3494, Test PCC: 0.7648, Test RMSE: 14.5211
Running training with subject_12 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_12. Resharding...
Processing subjects: ['subject_12'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_12/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_12/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.7560, Val Loss: 10.3694, Test Loss: 14.5072,Train Proto Loss: 0.4882


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.1479, Val Loss: 8.3633, Test Loss: 13.3459,Train Proto Loss: 0.0706


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 9.0545, Val Loss: 8.2180, Test Loss: 14.7222,Train Proto Loss: 0.0322


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.4400, Val Loss: 7.5501, Test Loss: 12.4329,Train Proto Loss: 0.0231


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.9011, Val Loss: 7.1669, Test Loss: 13.9477,Train Proto Loss: 0.0221


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.5503, Val Loss: 6.8011, Test Loss: 13.7636,Train Proto Loss: 0.0267


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.3402, Val Loss: 6.9944, Test Loss: 14.8824,Train Proto Loss: 0.0255


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 7.0347, Val Loss: 6.4010, Test Loss: 14.9871,Train Proto Loss: 0.0141


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.7402, Val Loss: 6.2367, Test Loss: 14.9991,Train Proto Loss: 0.0193


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.4534, Val Loss: 6.3212, Test Loss: 15.2217,Train Proto Loss: 0.0142
Total training time: 1444.48 seconds
Test Loss: 15.2217, Test PCC: 0.7747, Test RMSE: 14.2038
Running training with subject_13 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_prototypical
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_13. Resharding...
Processing subjects: ['subject_13'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_13/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_13/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_prototypical
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 1, Train Loss: 18.1434, Val Loss: 10.2394, Test Loss: 18.7535,Train Proto Loss: 0.4662


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 2, Train Loss: 10.0108, Val Loss: 9.2404, Test Loss: 16.3812,Train Proto Loss: 0.0805


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 3, Train Loss: 8.8974, Val Loss: 8.0644, Test Loss: 15.9724,Train Proto Loss: 0.0349


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 4, Train Loss: 8.2146, Val Loss: 7.6127, Test Loss: 17.0211,Train Proto Loss: 0.0246


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 5, Train Loss: 7.8565, Val Loss: 7.3018, Test Loss: 16.0658,Train Proto Loss: 0.0292


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 6, Train Loss: 7.4081, Val Loss: 7.2592, Test Loss: 16.4790,Train Proto Loss: 0.0259


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 7, Train Loss: 7.2004, Val Loss: 6.7336, Test Loss: 15.1348,Train Proto Loss: 0.0175


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 8, Train Loss: 6.9482, Val Loss: 6.6834, Test Loss: 16.8926,Train Proto Loss: 0.0224


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 9, Train Loss: 6.6369, Val Loss: 6.3872, Test Loss: 17.0528,Train Proto Loss: 0.0204


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch 10, Train Loss: 6.4076, Val Loss: 6.1139, Test Loss: 16.5545,Train Proto Loss: 0.0283
Total training time: 1461.54 seconds
Test Loss: 16.5545, Test PCC: 0.7516, Test RMSE: 14.7681


In [10]:

average_best_rmse = np.mean(best_rmse_per_subject)
average_best_pcc = np.mean(best_pcc_per_subject)
print(f"Average of best RMSEs across all subjects: {average_best_rmse:.4f}")
print(f"Average of best PCCs across all subjects: {average_best_pcc:.4f}")
print(best_rmse_per_subject)
print(best_pcc_per_subject)

# subjects = [f'Subject {i+1}' for i in range(len(best_rmse_per_subject))]

# print(best_rmse_per_subject)
# # Plot a bar chart with subject labels on the x-axis
# plt.figure(figsize=(10, 6))
# plt.bar(subjects, best_rmse_per_subject, color='blue', edgecolor='black')
# plt.title('Best RMSEs for Each Subject')
# plt.xlabel('Subjects')
# plt.ylabel('Best RMSE')
# plt.xticks(rotation=45, ha='right')
# plt.grid(True, axis='y')
# plt.tight_layout()
# plt.show()

Average of best RMSEs across all subjects: 15.6705
Average of best PCCs across all subjects: 0.7598
[17.710479935010273, 21.485891044139862, 18.817767123381298, 15.99175876379013, 22.71821077664693, 11.880238751570383, 13.147701899210611, 11.582688291867575, 9.61577981710434, 17.273259421189625, 14.521135548750559, 14.203756163517633, 14.768128951390585]
[0.7241782482583293, 0.6707674041331931, 0.7524613191888972, 0.7579811203357275, 0.66497124475815, 0.835292968350339, 0.796261584338935, 0.8075211507520779, 0.8218243771179395, 0.7552485423200483, 0.7647565566956033, 0.7746728981276257, 0.7515742181319682]


In [11]:
import os
import zipfile
from datetime import datetime

notebook_name = 'regression_benchmark_prototypical'

# Create a timestamped folder name based on the notebook name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"{notebook_name}_checkpoints_{timestamp}"

# Make sure the folder exists
os.makedirs(folder_name, exist_ok=True)

checkpoint_dir = '.'

# Zip all checkpoint files and save in the new folder
zip_filename = f"{folder_name}.zip"
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # List files only in the current directory (no subfolders)
    for file in os.listdir(checkpoint_dir):
        if "TeacherModel" in str(file):
          file_path = os.path.join(checkpoint_dir, file)
          zipf.write(file_path, os.path.relpath(file_path, checkpoint_dir))
          print(f"Checkpoint {file} has been added to the zip file.")
print(f"All checkpoints have been zipped and saved as {zip_filename}.")




Checkpoint TeacherModel_RMSELoss_test_subject_11_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_2_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_5_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_8_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_12_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_1_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_10_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_13_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_4_wl100_ol75_prototypical has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_7_wl100_ol75_p

In [12]:
# Download the zip file to your local machine
from google.colab import files
files.download(zip_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>