In [180]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [181]:
# @title Initialize Config

import torch
import numpy
class Config:
    def __init__(self, **kwargs):
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

#set seeds
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [182]:
class DataSharder:
    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]
                session_keys = list(subject_data.keys())  # Sessions for this subject

                for session_id in session_keys:
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id,sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        # Path to the CSV log file
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # Prepare CSV log headers (ensure the columns are 'file_name' and 'file_path')
        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Determine the total data length based on the minimum length across the data sources
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Adjust the starting point for windows based on total data length
            start = 2000 if total_data_length > 4000 else 0

            # Ensure that each window across imu_data, emg_data, and joint_data has the same shape before concatenation
            for i in range(start, total_data_length - window_size + 1, step_size):
                imu_window = imu_data[:, i:i + window_size]
                emg_window = emg_data[:, i:i + window_size]
                joint_window = joint_data[:, i:i + window_size]

                # Check if the window sizes are valid
                if imu_window.shape[1] == window_size and emg_window.shape[1] == window_size and joint_window.shape[1] == window_size:
                    # Convert windowed data to pandas DataFrame



                    imu_df = pd.DataFrame(imu_window.T, columns=imu_columns)
                    emg_df = pd.DataFrame(emg_window.T, columns=emg_columns)
                    joint_df = pd.DataFrame(joint_window.T, columns=joint_columns)



                    # Concatenate the data along the column axis
                    combined_df = pd.concat([imu_df, emg_df, joint_df], axis=1)

                    # Save the combined windowed data as a CSV file
                    file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                    file_path = os.path.join(dataset_folder, file_name)
                    combined_df.to_csv(file_path, index=False)

                    # Log the file name and path in the CSV (in the correct columns)
                    writer.writerow([file_name, file_path])
                else:
                    print(f"Skipping window {i} due to mismatched window sizes.")

    def _extract_channel_data(self, session_data, channels):
      extracted_data = []
      new_column_names = []  # Initialize here

      if isinstance(session_data, h5py.Dataset):
          if session_data.dtype.names:
              # Compound dataset
              column_names = session_data.dtype.names
              for channel in channels:
                  if channel in column_names:
                      channel_data = session_data[channel][:]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)  # Populate here
                  else:
                      print(f"Channel {channel} not found in compound dataset.")
          else:
              # Simple dataset
              column_names = list(session_data.attrs.get('column_names', []))
              assert len(column_names) > 0, "column_names not found in dataset attributes"
              for channel in channels:
                  if channel in column_names:
                      col_idx = column_names.index(channel)
                      channel_data = session_data[:, col_idx]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)
                  else:
                      print(f"Channel {channel} not found in session data.")

      return np.array(extracted_data), new_column_names


In [183]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Convert the list of subjects to a string that is path-safe
        subjects_str = "_".join(map(str, subjects)).replace('subject', '').replace('__', '_')

        # Use dataset_train_name or dataset_test_name based on split
        if split == 'train':
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
        else:
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"

        self.dataset_name = dataset_name

        # Define the root directory based on dataset name
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Ensure sharded data exists, if not, reshard
        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        info_path = os.path.join(self.root_dir, f"{split}_info.csv")
        self.data = pd.read_csv(info_path)

    def ensure_resharded(self, subjects, dataset_name):
        if not os.path.exists(self.root_dir):
            print(f"Sharded data not found at {self.root_dir}. Resharding...")
            data_sharder = DataSharder(self.config,self.split)
            # Pass dynamic parameters to sharder
            data_sharder.load_data(subjects, window_length=self.window_length, window_overlap=self.window_overlap, dataset_name=self.dataset_name)
        else:
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir,self.split, self.data.iloc[idx, 0])

        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        return torch.tensor(data, dtype=torch.float32)

def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader



In [184]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
from mamba_ssm import Mamba, Mamba2

class Encoder_Mamba(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_Mamba, self).__init__()

        # Project to 128 so that the rest of the Mamba block
        # will indeed produce 128 features.
        self.input_projection = nn.Linear(input_dim, 128)

        self.mamba_1 = Mamba(
            d_model=128,
            d_state=128,
            d_conv=4,
            expand=2
        )
        self.mamba_2 = Mamba(
            d_model=128,
            d_state=128,
            d_conv=4,
            expand=2
        )

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        # Project input_dim -> 128
        x = self.input_projection(x)

        out_1 = self.mamba_1(x)
        out_1 = self.dropout_1(out_1)

        out_2 = self.mamba_2(out_1)
        out_2 = self.dropout_2(out_2)

        return out_2, out_1


class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100):
        super(teacher, self).__init__()

        self.w = w
        self.encoder_acc = Encoder_Mamba(input_acc, drop_prob)
        self.encoder_gyr = Encoder_Mamba(input_gyr, drop_prob)
        self.encoder_emg = Encoder_Mamba(input_emg, drop_prob)

        self.BN_acc = nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr = nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg = nn.BatchNorm1d(input_emg, affine=False)

        self.fc = nn.Linear(2 * 3 * 128 + 128, 3)
        self.dropout = nn.Dropout(p=0.05)

        self.fc_kd = nn.Linear(3 * 128, 2 * 128)

        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention = nn.MultiheadAttention(3 * 128, 4, batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128 * 3, 3 * 128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)

    def forward(self, x_acc, x_gyr, x_emg):
        x_acc_1 = x_acc.view(x_acc.size(0) * x_acc.size(1), x_acc.size(-1))
        x_gyr_1 = x_gyr.view(x_gyr.size(0) * x_gyr.size(1), x_gyr.size(-1))
        x_emg_1 = x_emg.view(x_emg.size(0) * x_emg.size(1), x_emg.size(-1))

        x_acc_1 = self.BN_acc(x_acc_1)
        x_gyr_1 = self.BN_gyr(x_gyr_1)
        x_emg_1 = self.BN_emg(x_emg_1)

        x_acc_2 = x_acc_1.view(-1, self.w, x_acc_1.size(-1))
        x_gyr_2 = x_gyr_1.view(-1, self.w, x_gyr_1.size(-1))
        x_emg_2 = x_emg_1.view(-1, self.w, x_emg_1.size(-1))

        # Pass through Encoder for each modality
        x_acc_1, x_acc_out1 = self.encoder_acc(x_acc_2)
        x_gyr_1, x_gyr_out1 = self.encoder_gyr(x_gyr_2)
        x_emg_1, x_emg_out1 = self.encoder_emg(x_emg_2)

        # Concatenate features
        x = torch.cat((x_acc_1, x_gyr_1, x_emg_1), dim=-1)
        x_kd = self.fc_kd(x)

        out_1, attn_output_weights = self.attention(x, x, x)

        gating_weights = self.gating_net(x)
        out_2 = gating_weights * x

        weights_1 = self.weighted_feat(x[:, :, 0:128])
        weights_2 = self.weighted_feat(x[:, :, 128:2 * 128])
        weights_3 = self.weighted_feat(x[:, :, 2 * 128:3 * 128])
        x_1 = weights_1 * x[:, :, 0:128]
        x_2 = weights_2 * x[:, :, 128:2 * 128]
        x_3 = weights_3 * x[:, :, 2 * 128 : 3 * 128]
        out_3 = x_1 + x_2 + x_3

        out = torch.cat((out_1, out_2, out_3), dim=-1)

        out = self.fc(out)

        return out, x_kd, (None)




In [185]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss

#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7

In [186]:





# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(loader):
            output= model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())

            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())
            else:
                loss = criterion(output, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename, loss_function, optimizer=None, l1_lambda=None, train_from_last_epoch=False, curriculum_loader=None):
    model.to(device)
    criterion = loss_function

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    train_losses = []
    val_losses = []
    test_losses = []

    train_pccs = []
    val_pccs = []
    test_pccs = []

    train_rmses = []
    val_rmses = []
    test_rmses = []

    train_pccs_channelwise = []
    val_pccs_channelwise = []
    test_pccs_channelwise = []

    train_rmses_channelwise = []
    val_rmses_channelwise = []
    test_rmses_channelwise = []

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        # Scan for the latest saved checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch number
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(os.path.join(checkpoint_path, latest_checkpoint))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            last_epoch = checkpoint['epoch']  # Continue from the next epoch

            # Load the history from checkpoint
            train_losses = checkpoint['history']['train_losses']
            val_losses = checkpoint['history']['val_losses']
            test_losses = checkpoint['history']['test_losses']
            train_pccs = checkpoint['history']['train_pccs']
            val_pccs = checkpoint['history']['val_pccs']
            test_pccs = checkpoint['history']['test_pccs']
            train_rmses = checkpoint['history']['train_rmses']
            val_rmses = checkpoint['history']['val_rmses']
            test_rmses = checkpoint['history']['test_rmses']
            train_pccs_channelwise = checkpoint['history']['train_pccs_channelwise']
            val_pccs_channelwise = checkpoint['history']['val_pccs_channelwise']
            test_pccs_channelwise = checkpoint['history']['test_pccs_channelwise']
            train_rmses_channelwise = checkpoint['history']['train_rmses_channelwise']
            val_rmses_channelwise = checkpoint['history']['val_rmses_channelwise']
            test_rmses_channelwise = checkpoint['history']['test_rmses_channelwise']
            if 'curriculum_schedule' in checkpoint:
                curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']  # Load saved curriculum schedule
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(last_epoch, epochs):
        epoch_start_time = time.time()
        model.train()

        if curriculum_loader:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()
        # Track metrics per channel
        epoch_train_loss = np.zeros(len(config.channels_joints))
        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        # Use epoch starting from `epoch + 1` since we want to reflect actual starting epoch correctly
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            # Ensure inputs are properly sent to device and are of correct type
            output = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())


            # Check if output is a tuple, take the first element if true
            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())

            else:
                loss = criterion(output, target.to(device).float())

            # Compute loss



            # Apply L1 regularization if specified
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss = loss + l1_lambda * l1_norm
            else:
                total_loss = loss

            # Backpropagate the gradients for total_loss
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Detach tensors and move to CPU to prevent issues with gradient computation
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)

            # Accumulate loss, pcc, and rmse without modifying in-place
            epoch_train_loss += loss.detach().cpu().numpy()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_loss)
        train_pccs.append(np.mean(avg_train_pcc))  # Overall average PCC
        train_rmses.append(np.mean(avg_train_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        train_pccs_channelwise.append(avg_train_pcc)  # Per channel
        train_rmses_channelwise.append(avg_train_rmse)  # Per channel

        # Evaluate on validation set every epoch
        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(device, model, val_loader, criterion)
        val_losses.append(avg_val_loss)
        val_pccs.append(np.mean(avg_val_pcc))  # Overall average PCC
        val_rmses.append(np.mean(avg_val_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        val_pccs_channelwise.append(avg_val_pcc)  # Per channel
        val_rmses_channelwise.append(avg_val_rmse)  # Per channel

        # Evaluate on test set and checkpoint every epoch
        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(device, model, test_loader, criterion)
        test_losses.append(avg_test_loss)
        test_pccs.append(np.mean(avg_test_pcc))  # Overall average PCC
        test_rmses.append(np.mean(avg_test_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        test_pccs_channelwise.append(avg_test_pcc)  # Per channel
        test_rmses_channelwise.append(avg_test_rmse)  # Per channel

        print(f"Epoch: {epoch + 1}, Training Loss: {np.mean(avg_train_loss):.4f}, Validation Loss: {np.mean(avg_val_loss):.4f}, Test Loss: {np.mean(avg_test_loss):.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse)}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc)}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Save checkpoint, including curriculum schedule
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

        # Save checkpoint with the curriculum schedule
        history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'test_losses': test_losses,
            'train_pccs': train_pccs,
            'val_pccs': val_pccs,
            'test_pccs': test_pccs,
            'train_rmses': train_rmses,
            'val_rmses': val_rmses,
            'test_rmses': test_rmses,
            'train_pccs_channelwise': train_pccs_channelwise,
            'val_pccs_channelwise': val_pccs_channelwise,
            'test_pccs_channelwise': test_pccs_channelwise,
            'train_rmses_channelwise': train_rmses_channelwise,
            'val_rmses_channelwise': val_rmses_channelwise,
            'test_rmses_channelwise': test_rmses_channelwise
        }

        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"{checkpoint_path}/{filename}_epoch_{epoch + 1}.pth",
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            test_loss=avg_test_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=history,  # Save history in the checkpoint
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader else None # Save curriculum schedule
        )

        # Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    print(f"loading best model from {filename}")
    model.load_state_dict(torch.load(filename))
    model.eval()
    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses






In [187]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100):
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w)

    if base_weights_path:
        # Load the initial weights from the base model
        model.load_state_dict(torch.load(base_weights_path))

    return model




In [188]:
import matplotlib.pyplot as plt
import numpy as np
import os
import h5py
from tqdm.notebook import tqdm
import pandas as pd
import csv

all_subjects= [f"subject_{x}" for x in range(1,14)]
input_acc, input_gyr, input_emg = 18,18,3
batch_size = 64

# Placeholder for storing best RMSEs
best_rmse_per_subject = []
best_pcc_per_subject = []

train_flag = True

for test_subject in all_subjects:



    print(f"Running training with {test_subject} as the test subject.")

    # Set up the training subjects (all except the test subject)
    train_subjects = [subject for subject in all_subjects if subject != test_subject]

    model_name = f'TeacherModel_RMSELoss_test_{test_subject}_wl{100}_ol{75}_mamba'
    print(f"Model: {model_name}")

    # Load the model configuration and data loaders
    model_config = {
        'model': create_teacher_model(input_acc, input_gyr, input_emg, w=100),
        'loss': RMSELoss(),
        'loaders': create_base_data_loaders(
            config=config,
            train_subjects=train_subjects,
            test_subjects=[test_subject],
            window_length=100,
            window_overlap=75,
            batch_size=batch_size
        ),
        'epochs': 10,
        'use_curriculum': False
    }

    model = model_config['model']
    loss_function = model_config['loss']
    epochs = model_config.get("epochs", 10)
    device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    learn_rate = model_config.get("learn_rate", 0.0001)
    use_curriculum = model_config.get("use_curriculum", False)

    optimizer = model_config.get("optimizer", None)
    l1_lambda = model_config.get("l1_lambda", None)

    print(f"Running model: {model_name}")

    # Unpack the static loaders tuple (train_loader, val_loader, test_loader)
    train_loader, val_loader, test_loader = model_config['loaders']
    if train_flag:
    # Train the model and save only the best based on validation loss
      model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses = train_teacher(
          device=device,
          train_loader=train_loader,
          val_loader=val_loader,
          test_loader=test_loader,
          learn_rate=learn_rate,
          epochs=epochs,
          model=model,
          filename=model_name,
          loss_function=loss_function,
          optimizer=optimizer,
          l1_lambda=l1_lambda,
          train_from_last_epoch=False
      )
    else:
      #load filename as model
      model.load_state_dict(torch.load(f"{model_name}"))
      model.to(device)
      model.eval()

     #run model on test set and record result
    test_loss, test_pcc, test_rmse = evaluate_model(device, model, test_loader, loss_function)
    print(f"Test Loss: {test_loss:.4f}, Test PCC: {np.mean(test_pcc):.4f}, Test RMSE: {np.mean(test_rmse):.4f}")
    best_rmse_per_subject.append(np.mean(test_rmse))
    best_pcc_per_subject.append(np.mean(test_pcc))


# Compute the average of the best RMSEs across all subjects



Running training with subject_1 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_mamba
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Running model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 39.9403, Validation Loss: 23.6758, Test Loss: 21.2814
Training RMSE: 38.58785808034141, Validation RMSE: 23.2739, Test RMSE: 19.9437
Training PCC: 0.3832975459711055, Validation PCC: 0.6318, Test PCC: 0.5049
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.9996, Validation Loss: 18.7052, Test Loss: 18.8123
Training RMSE: 20.726007951582176, Validation RMSE: 18.4585, Test RMSE: 17.6504
Training PCC: 0.6553131493053485, Validation PCC: 0.7176, Test PCC: 0.5955
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.8014, Validation Loss: 16.6524, Test Loss: 17.6357
Training RMSE: 17.535300359781306, Validation RMSE: 16.3876, Test RMSE: 16.4134
Training PCC: 0.7621620401077488, Validation PCC: 0.7887, Test PCC: 0.6184
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.9462, Validation Loss: 15.0527, Test Loss: 17.9886
Training RMSE: 15.707438360450453, Validation RMSE: 14.7921, Test RMSE: 16.6903
Training PCC: 0.8174701144854412, Validation PCC: 0.8542, Test PCC: 0.6772
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 15.0231, Validation Loss: 13.1312, Test Loss: 17.3425
Training RMSE: 14.638218684402425, Validation RMSE: 12.8439, Test RMSE: 16.3120
Training PCC: 0.8703284995385969, Validation PCC: 0.9003, Test PCC: 0.6932
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.2310, Validation Loss: 12.8948, Test Loss: 16.9791
Training RMSE: 12.903854206470557, Validation RMSE: 12.5212, Test RMSE: 16.0420
Training PCC: 0.9040735715613422, Validation PCC: 0.9130, Test PCC: 0.7065
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.4168, Validation Loss: 11.6775, Test Loss: 17.7475
Training RMSE: 12.061421559923447, Validation RMSE: 11.3374, Test RMSE: 16.9682
Training PCC: 0.9200544729113914, Validation PCC: 0.9348, Test PCC: 0.7027
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 11.8116, Validation Loss: 11.1137, Test Loss: 17.2655
Training RMSE: 11.446219039591687, Validation RMSE: 10.6955, Test RMSE: 16.6695
Training PCC: 0.9294827541010093, Validation PCC: 0.9409, Test PCC: 0.6872
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.5337, Validation Loss: 10.6454, Test Loss: 18.1813
Training RMSE: 11.147780878321406, Validation RMSE: 10.2531, Test RMSE: 17.5917
Training PCC: 0.9345251930134877, Validation PCC: 0.9461, Test PCC: 0.6825
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 10.9933, Validation Loss: 10.4684, Test Loss: 18.2067
Training RMSE: 10.60478535742222, Validation RMSE: 10.0399, Test RMSE: 17.5806
Training PCC: 0.941322293498541, Validation PCC: 0.9498, Test PCC: 0.6765
Checkpoint saved for epoch 10
Total training time: 1764.73 seconds
loading best model from TeacherModel_RMSELoss_test_subject_1_wl100_ol75_mamba
Test Loss: 18.2067, Test PCC: 0.6765, Test RMSE: 17.5806
Running training with subject_2 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /c

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_2. Resharding...
Processing subjects: ['subject_2'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_2/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_2/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 36.5108, Validation Loss: 24.2796, Test Loss: 27.8977
Training RMSE: 35.15029430560404, Validation RMSE: 23.9569, Test RMSE: 25.3377
Training PCC: 0.37900537349965585, Validation PCC: 0.6265, Test PCC: 0.4140
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.7693, Validation Loss: 17.9583, Test Loss: 23.8413
Training RMSE: 20.466806016478323, Validation RMSE: 17.6803, Test RMSE: 22.0547
Training PCC: 0.6785990539931434, Validation PCC: 0.7718, Test PCC: 0.5185
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.3855, Validation Loss: 16.2125, Test Loss: 24.4393
Training RMSE: 17.06065666768244, Validation RMSE: 15.8623, Test RMSE: 22.9177
Training PCC: 0.7780299694317474, Validation PCC: 0.8199, Test PCC: 0.5472
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.5624, Validation Loss: 14.1261, Test Loss: 23.8462
Training RMSE: 15.247446224206463, Validation RMSE: 13.8726, Test RMSE: 22.3876
Training PCC: 0.8389081114612131, Validation PCC: 0.8773, Test PCC: 0.5978
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.0233, Validation Loss: 13.1989, Test Loss: 23.4047
Training RMSE: 13.67719966339301, Validation RMSE: 12.9285, Test RMSE: 21.6160
Training PCC: 0.8841271104815767, Validation PCC: 0.9028, Test PCC: 0.6418
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 12.7395, Validation Loss: 11.9918, Test Loss: 23.4851
Training RMSE: 12.402649012725282, Validation RMSE: 11.6529, Test RMSE: 21.9364
Training PCC: 0.9095919486672948, Validation PCC: 0.9248, Test PCC: 0.6599
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 11.9042, Validation Loss: 10.4919, Test Loss: 24.4952
Training RMSE: 11.566342037414907, Validation RMSE: 10.2298, Test RMSE: 22.6938
Training PCC: 0.9243120980880047, Validation PCC: 0.9380, Test PCC: 0.6741
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 11.2124, Validation Loss: 9.7665, Test Loss: 23.6241
Training RMSE: 10.873547328774585, Validation RMSE: 9.5023, Test RMSE: 21.9611
Training PCC: 0.9341796762898209, Validation PCC: 0.9472, Test PCC: 0.6801
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 10.7195, Validation Loss: 10.0536, Test Loss: 23.6565
Training RMSE: 10.37965397883496, Validation RMSE: 9.7254, Test RMSE: 22.1035
Training PCC: 0.9411417693634151, Validation PCC: 0.9512, Test PCC: 0.6718
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 10.2476, Validation Loss: 9.2566, Test Loss: 24.1497
Training RMSE: 9.92397108189511, Validation RMSE: 8.9692, Test RMSE: 22.4303
Training PCC: 0.9463721696376659, Validation PCC: 0.9565, Test PCC: 0.6771
Checkpoint saved for epoch 10
Total training time: 1758.99 seconds
loading best model from TeacherModel_RMSELoss_test_subject_2_wl100_ol75_mamba
Test Loss: 24.1497, Test PCC: 0.6771, Test RMSE: 22.4303
Running training with subject_3 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /con

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_3. Resharding...
Processing subjects: ['subject_3'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_3/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_3/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 41.3573, Validation Loss: 23.7384, Test Loss: 27.7281
Training RMSE: 39.914721597877254, Validation RMSE: 23.3820, Test RMSE: 25.1980
Training PCC: 0.34527013501580006, Validation PCC: 0.6267, Test PCC: 0.4661
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.8090, Validation Loss: 17.7918, Test Loss: 25.6876
Training RMSE: 20.59178033512941, Validation RMSE: 17.6307, Test RMSE: 23.7121
Training PCC: 0.6486863899459036, Validation PCC: 0.7361, Test PCC: 0.5048
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 16.9916, Validation Loss: 15.8719, Test Loss: 25.1676
Training RMSE: 16.816320251097647, Validation RMSE: 15.7557, Test RMSE: 23.0341
Training PCC: 0.764607759101223, Validation PCC: 0.8100, Test PCC: 0.5935
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.1666, Validation Loss: 14.5033, Test Loss: 24.0744
Training RMSE: 15.005542095984586, Validation RMSE: 14.3933, Test RMSE: 21.7341
Training PCC: 0.8318059945851753, Validation PCC: 0.8712, Test PCC: 0.6850
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 13.7817, Validation Loss: 12.9012, Test Loss: 21.9299
Training RMSE: 13.576373662231498, Validation RMSE: 12.6992, Test RMSE: 19.4947
Training PCC: 0.8817722584084601, Validation PCC: 0.9046, Test PCC: 0.7207
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 12.7561, Validation Loss: 11.5290, Test Loss: 20.9603
Training RMSE: 12.519222905182138, Validation RMSE: 11.3202, Test RMSE: 18.7059
Training PCC: 0.9049997241029749, Validation PCC: 0.9233, Test PCC: 0.7560
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 11.9479, Validation Loss: 11.0225, Test Loss: 20.3764
Training RMSE: 11.713436683500873, Validation RMSE: 10.7680, Test RMSE: 17.9475
Training PCC: 0.919875821553509, Validation PCC: 0.9337, Test PCC: 0.7636
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 11.4066, Validation Loss: 10.7277, Test Loss: 19.9518
Training RMSE: 11.163588314384453, Validation RMSE: 10.4062, Test RMSE: 17.3721
Training PCC: 0.9282956431120596, Validation PCC: 0.9401, Test PCC: 0.7651
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 10.9194, Validation Loss: 10.7069, Test Loss: 20.8721
Training RMSE: 10.659124462900259, Validation RMSE: 10.3936, Test RMSE: 18.4938
Training PCC: 0.9355730591822639, Validation PCC: 0.9465, Test PCC: 0.7692
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 10.5025, Validation Loss: 9.5055, Test Loss: 19.6651
Training RMSE: 10.248262727671628, Validation RMSE: 9.2350, Test RMSE: 17.2528
Training PCC: 0.9415944507058741, Validation PCC: 0.9518, Test PCC: 0.7695
Checkpoint saved for epoch 10
Total training time: 1774.25 seconds
loading best model from TeacherModel_RMSELoss_test_subject_3_wl100_ol75_mamba
Test Loss: 19.6651, Test PCC: 0.7695, Test RMSE: 17.2528
Running training with subject_4 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /c

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_4. Resharding...
Processing subjects: ['subject_4'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_4/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_4/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 37.2781, Validation Loss: 23.5994, Test Loss: 26.0519
Training RMSE: 35.87008844713458, Validation RMSE: 23.1256, Test RMSE: 24.6600
Training PCC: 0.3335055492238825, Validation PCC: 0.6296, Test PCC: 0.4466
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 21.7448, Validation Loss: 19.6621, Test Loss: 119251.3249
Training RMSE: 21.442147798682765, Validation RMSE: 19.4248, Test RMSE: 105632.9299
Training PCC: 0.6418850848600456, Validation PCC: 0.6640, Test PCC: 0.3383
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.3570, Validation Loss: 17.2552, Test Loss: 2070.9537
Training RMSE: 18.12969445570541, Validation RMSE: 16.8952, Test RMSE: 1693.4194
Training PCC: 0.7105732950746487, Validation PCC: 0.7699, Test PCC: 0.4292
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.8248, Validation Loss: 14.2477, Test Loss: 1060.4218
Training RMSE: 15.603568746209469, Validation RMSE: 13.9860, Test RMSE: 1037.8387
Training PCC: 0.8089957872632599, Validation PCC: 0.8555, Test PCC: 0.5093
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.1319, Validation Loss: 13.2933, Test Loss: 28.2003
Training RMSE: 13.868335748093969, Validation RMSE: 12.9192, Test RMSE: 26.8165
Training PCC: 0.8748661057769613, Validation PCC: 0.9004, Test PCC: 0.5785
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 12.9346, Validation Loss: 12.4321, Test Loss: 1752.4656
Training RMSE: 12.621262261483329, Validation RMSE: 11.9918, Test RMSE: 1734.1876
Training PCC: 0.9054267752921316, Validation PCC: 0.9209, Test PCC: 0.5623
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.1561, Validation Loss: 11.7288, Test Loss: 28.6627
Training RMSE: 11.811952757300816, Validation RMSE: 11.2657, Test RMSE: 27.4816
Training PCC: 0.9208842189329243, Validation PCC: 0.9304, Test PCC: 0.6031
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 11.5830, Validation Loss: 11.2009, Test Loss: 536.0985
Training RMSE: 11.223640251686634, Validation RMSE: 10.7273, Test RMSE: 528.4322
Training PCC: 0.930699378136932, Validation PCC: 0.9391, Test PCC: 0.5894
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.1986, Validation Loss: 10.6613, Test Loss: 70.0784
Training RMSE: 10.81569081497142, Validation RMSE: 10.2447, Test RMSE: 66.0349
Training PCC: 0.937582423702661, Validation PCC: 0.9442, Test PCC: 0.6067
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 10.7014, Validation Loss: 11.6996, Test Loss: 19434.1449
Training RMSE: 10.314670423049328, Validation RMSE: 11.0261, Test RMSE: 18622.3878
Training PCC: 0.9438169964717815, Validation PCC: 0.9481, Test PCC: 0.5991
Checkpoint saved for epoch 10
Total training time: 1775.43 seconds
loading best model from TeacherModel_RMSELoss_test_subject_4_wl100_ol75_mamba
Test Loss: 70.0784, Test PCC: 0.6067, Test RMSE: 66.0349
Running training with subject_5 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train
Dataset folder crea

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_5. Resharding...
Processing subjects: ['subject_5'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_5/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_5/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 35.9007, Validation Loss: 23.7639, Test Loss: 26.0190
Training RMSE: 34.5768100371908, Validation RMSE: 23.3835, Test RMSE: 24.2499
Training PCC: 0.3657803180592083, Validation PCC: 0.6371, Test PCC: 0.4926
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.6782, Validation Loss: 18.2753, Test Loss: 22.0380
Training RMSE: 20.36182037430035, Validation RMSE: 17.9643, Test RMSE: 21.3300
Training PCC: 0.6633778953684846, Validation PCC: 0.7429, Test PCC: 0.5672
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.6617, Validation Loss: 15.8635, Test Loss: 19.3877
Training RMSE: 17.378436716536523, Validation RMSE: 15.6076, Test RMSE: 18.9555
Training PCC: 0.7554545667180963, Validation PCC: 0.7968, Test PCC: 0.6006
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.9795, Validation Loss: 14.2787, Test Loss: 20.6842
Training RMSE: 15.72757270535302, Validation RMSE: 14.0493, Test RMSE: 20.1537
Training PCC: 0.810573500807506, Validation PCC: 0.8558, Test PCC: 0.5767
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.4741, Validation Loss: 13.7504, Test Loss: 21.6459
Training RMSE: 14.183569701531482, Validation RMSE: 13.3803, Test RMSE: 20.8628
Training PCC: 0.8683490957214733, Validation PCC: 0.8948, Test PCC: 0.5863
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.4330, Validation Loss: 11.7823, Test Loss: 22.7802
Training RMSE: 13.078044191901064, Validation RMSE: 11.4978, Test RMSE: 21.4569
Training PCC: 0.9004240320026811, Validation PCC: 0.9217, Test PCC: 0.5801
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.6184, Validation Loss: 11.0874, Test Loss: 24.0063
Training RMSE: 12.241910310758671, Validation RMSE: 10.7783, Test RMSE: 22.4747
Training PCC: 0.9164322587232508, Validation PCC: 0.9334, Test PCC: 0.5782
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.0328, Validation Loss: 10.6109, Test Loss: 24.9892
Training RMSE: 11.633176622595153, Validation RMSE: 10.3097, Test RMSE: 23.6764
Training PCC: 0.9271290881772322, Validation PCC: 0.9411, Test PCC: 0.5748
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.5591, Validation Loss: 10.0929, Test Loss: 26.3191
Training RMSE: 11.150876333786725, Validation RMSE: 9.7910, Test RMSE: 24.7725
Training PCC: 0.9343223598145282, Validation PCC: 0.9462, Test PCC: 0.5850
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.2573, Validation Loss: 9.6293, Test Loss: 25.7454
Training RMSE: 10.831537110621072, Validation RMSE: 9.3284, Test RMSE: 24.4754
Training PCC: 0.9389790250840652, Validation PCC: 0.9503, Test PCC: 0.5411
Checkpoint saved for epoch 10
Total training time: 1765.76 seconds
loading best model from TeacherModel_RMSELoss_test_subject_5_wl100_ol75_mamba
Test Loss: 25.7454, Test PCC: 0.5411, Test RMSE: 24.4754
Running training with subject_6 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train
Dataset folder created:  /c

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_6. Resharding...
Processing subjects: ['subject_6'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_6/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_6/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 36.3469, Validation Loss: 23.8554, Test Loss: 22.5092
Training RMSE: 34.96504276599146, Validation RMSE: 23.4220, Test RMSE: 21.3744
Training PCC: 0.3649080187037472, Validation PCC: 0.6392, Test PCC: 0.5082
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 22.0513, Validation Loss: 19.3064, Test Loss: 20.8435
Training RMSE: 21.772768455255896, Validation RMSE: 19.0903, Test RMSE: 19.7377
Training PCC: 0.6415650654416122, Validation PCC: 0.6745, Test PCC: 0.5112
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.3563, Validation Loss: 16.7407, Test Loss: 19.0574
Training RMSE: 18.074778853423847, Validation RMSE: 16.4574, Test RMSE: 17.8761
Training PCC: 0.735379060873595, Validation PCC: 0.7879, Test PCC: 0.5550
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.5137, Validation Loss: 15.0776, Test Loss: 17.5863
Training RMSE: 16.23658907806873, Validation RMSE: 14.7939, Test RMSE: 16.5223
Training PCC: 0.8002772826632748, Validation PCC: 0.8359, Test PCC: 0.6090
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 15.4143, Validation Loss: 14.0738, Test Loss: 16.7679
Training RMSE: 15.10720952828992, Validation RMSE: 13.6970, Test RMSE: 15.9645
Training PCC: 0.8513391994474663, Validation PCC: 0.8856, Test PCC: 0.6713
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.5924, Validation Loss: 12.2636, Test Loss: 16.4847
Training RMSE: 13.29952526779521, Validation RMSE: 11.9446, Test RMSE: 15.3658
Training PCC: 0.8917426685937923, Validation PCC: 0.9111, Test PCC: 0.6971
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 13.2694, Validation Loss: 11.8515, Test Loss: 16.1831
Training RMSE: 12.943993622275272, Validation RMSE: 11.4888, Test RMSE: 15.0510
Training PCC: 0.9076356526485178, Validation PCC: 0.9247, Test PCC: 0.7057
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.1723, Validation Loss: 11.1884, Test Loss: 16.6070
Training RMSE: 11.83859685689965, Validation RMSE: 10.8632, Test RMSE: 15.4863
Training PCC: 0.9213670732185878, Validation PCC: 0.9342, Test PCC: 0.7108
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.5862, Validation Loss: 10.6655, Test Loss: 15.5586
Training RMSE: 11.26400050521571, Validation RMSE: 10.3231, Test RMSE: 14.4639
Training PCC: 0.9296531322271312, Validation PCC: 0.9400, Test PCC: 0.7159
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.1807, Validation Loss: 10.2880, Test Loss: 15.6820
Training RMSE: 10.840308410828833, Validation RMSE: 9.9223, Test RMSE: 14.7018
Training PCC: 0.935828254560651, Validation PCC: 0.9465, Test PCC: 0.7243
Checkpoint saved for epoch 10
Total training time: 1776.54 seconds
loading best model from TeacherModel_RMSELoss_test_subject_6_wl100_ol75_mamba
Test Loss: 15.6820, Test PCC: 0.7243, Test RMSE: 14.7018
Running training with subject_7 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train
Dataset folder created:  /c

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_7. Resharding...
Processing subjects: ['subject_7'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_7/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_7/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 38.2305, Validation Loss: 23.4376, Test Loss: 25.1945
Training RMSE: 36.82432382684872, Validation RMSE: 23.0433, Test RMSE: 24.3794
Training PCC: 0.3781980684746975, Validation PCC: 0.6237, Test PCC: 0.3148
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.9915, Validation Loss: 18.8563, Test Loss: 20.2668
Training RMSE: 20.709973384904124, Validation RMSE: 18.5051, Test RMSE: 19.7250
Training PCC: 0.6795785698428295, Validation PCC: 0.7492, Test PCC: 0.4700
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.0681, Validation Loss: 17.3898, Test Loss: 18.9632
Training RMSE: 17.74472762566541, Validation RMSE: 17.0345, Test RMSE: 18.3587
Training PCC: 0.7769519953592076, Validation PCC: 0.8112, Test PCC: 0.5253
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.6092, Validation Loss: 15.8566, Test Loss: 19.3235
Training RMSE: 16.277811039252423, Validation RMSE: 15.4356, Test RMSE: 18.5363
Training PCC: 0.8198968084009016, Validation PCC: 0.8424, Test PCC: 0.5482
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 15.4164, Validation Loss: 14.6204, Test Loss: 17.7110
Training RMSE: 15.092611446112322, Validation RMSE: 14.3020, Test RMSE: 17.0998
Training PCC: 0.8532208792736705, Validation PCC: 0.8733, Test PCC: 0.5868
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 14.1949, Validation Loss: 13.1240, Test Loss: 17.7248
Training RMSE: 13.870961805084258, Validation RMSE: 12.7232, Test RMSE: 16.7912
Training PCC: 0.884037214296764, Validation PCC: 0.9036, Test PCC: 0.6007
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 13.5222, Validation Loss: 12.1121, Test Loss: 17.5322
Training RMSE: 13.122951092381458, Validation RMSE: 11.6737, Test RMSE: 16.4654
Training PCC: 0.9051616105528124, Validation PCC: 0.9227, Test PCC: 0.6539
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.9457, Validation Loss: 11.6276, Test Loss: 17.7332
Training RMSE: 12.57362302945907, Validation RMSE: 11.1631, Test RMSE: 16.6227
Training PCC: 0.9169066280760129, Validation PCC: 0.9311, Test PCC: 0.6553
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 12.0707, Validation Loss: 11.3228, Test Loss: 16.7923
Training RMSE: 11.684170325856071, Validation RMSE: 10.8445, Test RMSE: 15.6901
Training PCC: 0.9267793334527038, Validation PCC: 0.9381, Test PCC: 0.6752
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.4934, Validation Loss: 11.1984, Test Loss: 16.9378
Training RMSE: 11.098828448421807, Validation RMSE: 10.7285, Test RMSE: 15.8326
Training PCC: 0.9345019262882772, Validation PCC: 0.9425, Test PCC: 0.6829
Checkpoint saved for epoch 10
Total training time: 1770.70 seconds
loading best model from TeacherModel_RMSELoss_test_subject_7_wl100_ol75_mamba
Test Loss: 16.9378, Test PCC: 0.6829, Test RMSE: 15.8326
Running training with subject_8 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train
Dataset folder created:  

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_8. Resharding...
Processing subjects: ['subject_8'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_8/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_8/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 38.3002, Validation Loss: 24.4932, Test Loss: 23.4383
Training RMSE: 36.7922957329609, Validation RMSE: 23.9629, Test RMSE: 22.6670
Training PCC: 0.3447861949254809, Validation PCC: 0.6459, Test PCC: 0.4042
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 21.7640, Validation Loss: 19.2191, Test Loss: 17.3520
Training RMSE: 21.4709741871071, Validation RMSE: 18.9484, Test RMSE: 16.8105
Training PCC: 0.6492804315219481, Validation PCC: 0.7282, Test PCC: 0.5199
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.3719, Validation Loss: 17.0840, Test Loss: 16.7247
Training RMSE: 18.060593396545684, Validation RMSE: 16.7946, Test RMSE: 16.0020
Training PCC: 0.7630002473987944, Validation PCC: 0.8042, Test PCC: 0.5495
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.3697, Validation Loss: 14.7132, Test Loss: 15.7186
Training RMSE: 16.073564962641203, Validation RMSE: 14.3469, Test RMSE: 14.7831
Training PCC: 0.8249314108704802, Validation PCC: 0.8686, Test PCC: 0.6442
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.7334, Validation Loss: 13.6554, Test Loss: 13.4279
Training RMSE: 14.371683783536726, Validation RMSE: 13.3251, Test RMSE: 12.8828
Training PCC: 0.8785343199056356, Validation PCC: 0.8996, Test PCC: 0.7618
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.5970, Validation Loss: 12.7649, Test Loss: 14.6437
Training RMSE: 13.210864071759858, Validation RMSE: 12.3048, Test RMSE: 14.0035
Training PCC: 0.9029424995226462, Validation PCC: 0.9207, Test PCC: 0.7763
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.7348, Validation Loss: 28.9040, Test Loss: 13.7736
Training RMSE: 12.350990829241033, Validation RMSE: 27.0229, Test RMSE: 13.3744
Training PCC: 0.9178569305165444, Validation PCC: 0.9013, Test PCC: 0.7948
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 18.3793, Validation Loss: 11.0239, Test Loss: 14.8206
Training RMSE: 17.54565760274659, Validation RMSE: 10.5937, Test RMSE: 14.4475
Training PCC: 0.9225943734221077, Validation PCC: 0.9381, Test PCC: 0.7983
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.8929, Validation Loss: 11.0694, Test Loss: 15.2136
Training RMSE: 11.462954012754693, Validation RMSE: 10.6862, Test RMSE: 14.8544
Training PCC: 0.9321452121557443, Validation PCC: 0.9427, Test PCC: 0.8085
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.2021, Validation Loss: 10.2940, Test Loss: 15.0975
Training RMSE: 10.816972369715645, Validation RMSE: 9.8482, Test RMSE: 14.6636
Training PCC: 0.939564756526659, Validation PCC: 0.9480, Test PCC: 0.8042
Checkpoint saved for epoch 10
Total training time: 1745.64 seconds
loading best model from TeacherModel_RMSELoss_test_subject_8_wl100_ol75_mamba
Test Loss: 15.0975, Test PCC: 0.8042, Test RMSE: 14.6636
Running training with subject_9 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train
Dataset folder created:  /c

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_9. Resharding...
Processing subjects: ['subject_9'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_9/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_9/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 37.0022, Validation Loss: 23.6106, Test Loss: 20.9282
Training RMSE: 35.53211277638994, Validation RMSE: 23.2547, Test RMSE: 19.0710
Training PCC: 0.38550528877691265, Validation PCC: 0.6411, Test PCC: 0.3553
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 21.3579, Validation Loss: 18.8097, Test Loss: 15.7007
Training RMSE: 21.052778344949, Validation RMSE: 18.5855, Test RMSE: 14.8070
Training PCC: 0.6702140022882913, Validation PCC: 0.7457, Test PCC: 0.6027
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.9417, Validation Loss: 16.4803, Test Loss: 15.2120
Training RMSE: 17.634696815361533, Validation RMSE: 16.2344, Test RMSE: 14.6137
Training PCC: 0.7747354112202699, Validation PCC: 0.8204, Test PCC: 0.6329
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.9773, Validation Loss: 14.4835, Test Loss: 14.0298
Training RMSE: 15.668191597276838, Validation RMSE: 14.2050, Test RMSE: 13.4265
Training PCC: 0.8431043437200069, Validation PCC: 0.8773, Test PCC: 0.6780
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.5126, Validation Loss: 13.2385, Test Loss: 13.4939
Training RMSE: 14.156882297507785, Validation RMSE: 12.9617, Test RMSE: 12.7644
Training PCC: 0.8844930600988902, Validation PCC: 0.9037, Test PCC: 0.7354
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.5959, Validation Loss: 12.3856, Test Loss: 11.9358
Training RMSE: 13.216862272175248, Validation RMSE: 12.0874, Test RMSE: 11.3715
Training PCC: 0.9038268428466628, Validation PCC: 0.9182, Test PCC: 0.7706
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 13.0262, Validation Loss: 12.0712, Test Loss: 11.9335
Training RMSE: 12.606823767446096, Validation RMSE: 11.7262, Test RMSE: 11.4170
Training PCC: 0.9148332155217646, Validation PCC: 0.9251, Test PCC: 0.7983
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.2970, Validation Loss: 11.2300, Test Loss: 12.4491
Training RMSE: 11.892214483333143, Validation RMSE: 10.8787, Test RMSE: 11.9589
Training PCC: 0.9260125187325862, Validation PCC: 0.9367, Test PCC: 0.7685
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.9091, Validation Loss: 21.2090, Test Loss: 12.0497
Training RMSE: 11.503253189455839, Validation RMSE: 20.5038, Test RMSE: 11.5295
Training PCC: 0.9308062593930241, Validation PCC: 0.9069, Test PCC: 0.7720
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.4542, Validation Loss: 10.8218, Test Loss: 11.4845
Training RMSE: 11.037696458290107, Validation RMSE: 10.4935, Test RMSE: 11.0836
Training PCC: 0.9380676631626171, Validation PCC: 0.9433, Test PCC: 0.7775
Checkpoint saved for epoch 10
Total training time: 1737.93 seconds
loading best model from TeacherModel_RMSELoss_test_subject_9_wl100_ol75_mamba
Test Loss: 11.4845, Test PCC: 0.7775, Test RMSE: 11.0836
Running training with subject_10 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train
Dataset folder created:  /

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_10. Resharding...
Processing subjects: ['subject_10'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_10/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_10/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 38.1629, Validation Loss: 23.3561, Test Loss: 23.1034
Training RMSE: 36.77563712824955, Validation RMSE: 22.9881, Test RMSE: 22.5940
Training PCC: 0.3775764628739143, Validation PCC: 0.6441, Test PCC: 0.4391
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.7791, Validation Loss: 17.9284, Test Loss: 19.8667
Training RMSE: 20.518541806537083, Validation RMSE: 17.6636, Test RMSE: 19.3892
Training PCC: 0.671223266063397, Validation PCC: 0.7613, Test PCC: 0.5590
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.5690, Validation Loss: 16.1683, Test Loss: 20.2902
Training RMSE: 17.301850677998736, Validation RMSE: 15.8757, Test RMSE: 19.7986
Training PCC: 0.773543385190549, Validation PCC: 0.8058, Test PCC: 0.5592
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.0248, Validation Loss: 15.0821, Test Loss: 19.7856
Training RMSE: 15.75048734351636, Validation RMSE: 14.7503, Test RMSE: 18.9721
Training PCC: 0.8180943068135832, Validation PCC: 0.8588, Test PCC: 0.5526
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.6532, Validation Loss: 13.2457, Test Loss: 19.1735
Training RMSE: 14.358716570249038, Validation RMSE: 12.8872, Test RMSE: 18.8994
Training PCC: 0.867714465282592, Validation PCC: 0.8990, Test PCC: 0.5677
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.4709, Validation Loss: 12.2090, Test Loss: 18.2245
Training RMSE: 13.12710910385699, Validation RMSE: 11.7852, Test RMSE: 17.9267
Training PCC: 0.9007031288557226, Validation PCC: 0.9227, Test PCC: 0.5875
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.6264, Validation Loss: 11.7389, Test Loss: 17.8300
Training RMSE: 12.245844853400726, Validation RMSE: 11.3514, Test RMSE: 17.4936
Training PCC: 0.9190557598239293, Validation PCC: 0.9334, Test PCC: 0.6087
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 11.9508, Validation Loss: 11.0055, Test Loss: 18.8176
Training RMSE: 11.55119631253997, Validation RMSE: 10.5395, Test RMSE: 18.5578
Training PCC: 0.9297220508860384, Validation PCC: 0.9420, Test PCC: 0.5962
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.5416, Validation Loss: 10.8722, Test Loss: 18.5727
Training RMSE: 11.12908682055088, Validation RMSE: 10.3422, Test RMSE: 18.2318
Training PCC: 0.9361245231123565, Validation PCC: 0.9469, Test PCC: 0.5919
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.1718, Validation Loss: 10.3920, Test Loss: 18.1186
Training RMSE: 10.759744566883763, Validation RMSE: 9.8960, Test RMSE: 17.7869
Training PCC: 0.9406839763033531, Validation PCC: 0.9513, Test PCC: 0.6050
Checkpoint saved for epoch 10
Total training time: 1761.35 seconds
loading best model from TeacherModel_RMSELoss_test_subject_10_wl100_ol75_mamba
Test Loss: 18.1186, Test PCC: 0.6050, Test RMSE: 17.7869
Running training with subject_11 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train
Dataset folder created:  /

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_11. Resharding...
Processing subjects: ['subject_11'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_11/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_11/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 41.6985, Validation Loss: 23.7713, Test Loss: 26.6493
Training RMSE: 40.3505991087954, Validation RMSE: 23.3771, Test RMSE: 25.5200
Training PCC: 0.3405850658480973, Validation PCC: 0.6137, Test PCC: 0.4389
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 21.0178, Validation Loss: 18.6135, Test Loss: 21.3872
Training RMSE: 20.78194489726734, Validation RMSE: 18.3758, Test RMSE: 20.4264
Training PCC: 0.6515237003169999, Validation PCC: 0.7177, Test PCC: 0.5037
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.0387, Validation Loss: 16.8405, Test Loss: 20.1719
Training RMSE: 17.803047343443925, Validation RMSE: 16.5565, Test RMSE: 18.9414
Training PCC: 0.7521864942805818, Validation PCC: 0.7960, Test PCC: 0.5526
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.1821, Validation Loss: 15.0875, Test Loss: 24.2603
Training RMSE: 15.921692505749396, Validation RMSE: 14.8141, Test RMSE: 23.3259
Training PCC: 0.8135246118082122, Validation PCC: 0.8450, Test PCC: 0.5767
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.7695, Validation Loss: 13.2109, Test Loss: 16.0496
Training RMSE: 14.472786896068044, Validation RMSE: 12.9181, Test RMSE: 14.9071
Training PCC: 0.8615479531629822, Validation PCC: 0.8898, Test PCC: 0.6758
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.4393, Validation Loss: 12.3974, Test Loss: 15.5061
Training RMSE: 13.102501815112383, Validation RMSE: 12.0519, Test RMSE: 14.6291
Training PCC: 0.8972765177333869, Validation PCC: 0.9166, Test PCC: 0.6952
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.6787, Validation Loss: 11.5307, Test Loss: 15.1975
Training RMSE: 12.303055182665744, Validation RMSE: 11.1842, Test RMSE: 14.2980
Training PCC: 0.914799651661886, Validation PCC: 0.9286, Test PCC: 0.7146
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.0792, Validation Loss: 11.3478, Test Loss: 14.9634
Training RMSE: 11.687877641471479, Validation RMSE: 10.9043, Test RMSE: 14.1033
Training PCC: 0.9248244833921931, Validation PCC: 0.9360, Test PCC: 0.7156
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.6925, Validation Loss: 10.5935, Test Loss: 15.4781
Training RMSE: 11.289360927545518, Validation RMSE: 10.2233, Test RMSE: 14.6116
Training PCC: 0.9316413932525832, Validation PCC: 0.9439, Test PCC: 0.7141
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.2280, Validation Loss: 9.9319, Test Loss: 15.9111
Training RMSE: 10.8329483870013, Validation RMSE: 9.5706, Test RMSE: 14.9396
Training PCC: 0.9368492846673712, Validation PCC: 0.9488, Test PCC: 0.7073
Checkpoint saved for epoch 10
Total training time: 1774.58 seconds
loading best model from TeacherModel_RMSELoss_test_subject_11_wl100_ol75_mamba
Test Loss: 15.9111, Test PCC: 0.7073, Test RMSE: 14.9396
Running training with subject_12 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train
Dataset folder created:  /con

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_12. Resharding...
Processing subjects: ['subject_12'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_12/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_12/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 37.4387, Validation Loss: 23.6546, Test Loss: 22.1963
Training RMSE: 36.05895134169298, Validation RMSE: 23.2837, Test RMSE: 20.7451
Training PCC: 0.3570841837427982, Validation PCC: 0.6172, Test PCC: 0.4665
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 21.2885, Validation Loss: 18.6920, Test Loss: 16.4469
Training RMSE: 21.00434717779682, Validation RMSE: 18.3518, Test RMSE: 15.7157
Training PCC: 0.665580429757594, Validation PCC: 0.7472, Test PCC: 0.5402
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 18.1134, Validation Loss: 16.8424, Test Loss: 16.0081
Training RMSE: 17.80548354513516, Validation RMSE: 16.5104, Test RMSE: 15.2695
Training PCC: 0.7609851980111867, Validation PCC: 0.7917, Test PCC: 0.5726
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 16.6801, Validation Loss: 23.4467, Test Loss: 14.7281
Training RMSE: 16.39438933251942, Validation RMSE: 22.4648, Test RMSE: 14.0556
Training PCC: 0.7958016531633009, Validation PCC: 0.7916, Test PCC: 0.6414
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 15.6556, Validation Loss: 14.3835, Test Loss: 14.2602
Training RMSE: 15.359756831005079, Validation RMSE: 14.0457, Test RMSE: 13.5856
Training PCC: 0.8427188542918591, Validation PCC: 0.8830, Test PCC: 0.6991
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 14.0465, Validation Loss: 12.5810, Test Loss: 14.1634
Training RMSE: 13.705952070232359, Validation RMSE: 12.1709, Test RMSE: 13.5714
Training PCC: 0.8904320886068161, Validation PCC: 0.9153, Test PCC: 0.7109
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 13.1237, Validation Loss: 11.8616, Test Loss: 13.5939
Training RMSE: 12.740114729640771, Validation RMSE: 11.4256, Test RMSE: 13.0947
Training PCC: 0.9116229249385963, Validation PCC: 0.9291, Test PCC: 0.6922
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.4012, Validation Loss: 11.6026, Test Loss: 14.5187
Training RMSE: 12.009852790212067, Validation RMSE: 11.1270, Test RMSE: 13.8432
Training PCC: 0.9242092787870725, Validation PCC: 0.9365, Test PCC: 0.7002
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.9729, Validation Loss: 10.9589, Test Loss: 13.9916
Training RMSE: 11.568894886398148, Validation RMSE: 10.5097, Test RMSE: 13.4280
Training PCC: 0.9306647798178286, Validation PCC: 0.9441, Test PCC: 0.6946
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.5562, Validation Loss: 10.7060, Test Loss: 15.0307
Training RMSE: 11.153470206412322, Validation RMSE: 10.1762, Test RMSE: 14.3435
Training PCC: 0.9365207238542497, Validation PCC: 0.9478, Test PCC: 0.7187
Checkpoint saved for epoch 10
Total training time: 1752.01 seconds
loading best model from TeacherModel_RMSELoss_test_subject_12_wl100_ol75_mamba
Test Loss: 15.0307, Test PCC: 0.7187, Test RMSE: 14.3435
Running training with subject_13 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_mamba
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train
Dataset folder created:  

Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_13. Resharding...
Processing subjects: ['subject_13'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_13/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_13/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_mamba
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 35.7650, Validation Loss: 23.1349, Test Loss: 30.0539
Training RMSE: 34.420508072571764, Validation RMSE: 22.7806, Test RMSE: 28.9029
Training PCC: 0.36608809840419587, Validation PCC: 0.6268, Test PCC: 0.4053
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 20.5301, Validation Loss: 17.8745, Test Loss: 27.8296
Training RMSE: 20.265936674159246, Validation RMSE: 17.5257, Test RMSE: 25.9397
Training PCC: 0.6799787383265296, Validation PCC: 0.7639, Test PCC: 0.4882
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 17.6696, Validation Loss: 16.2733, Test Loss: 26.3458
Training RMSE: 17.36173055216418, Validation RMSE: 15.9490, Test RMSE: 24.8622
Training PCC: 0.7834008410026888, Validation PCC: 0.8224, Test PCC: 0.5500
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 15.9677, Validation Loss: 14.9277, Test Loss: 27.2655
Training RMSE: 15.669712273451793, Validation RMSE: 14.5468, Test RMSE: 25.3109
Training PCC: 0.8358888805288655, Validation PCC: 0.8657, Test PCC: 0.5942
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 14.7368, Validation Loss: 13.2854, Test Loss: 27.3239
Training RMSE: 14.360522323303021, Validation RMSE: 12.8759, Test RMSE: 25.1061
Training PCC: 0.8758726129638021, Validation PCC: 0.9014, Test PCC: 0.6372
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 13.3475, Validation Loss: 12.3222, Test Loss: 25.9519
Training RMSE: 12.970876497423701, Validation RMSE: 11.8974, Test RMSE: 23.6860
Training PCC: 0.9047894520293912, Validation PCC: 0.9192, Test PCC: 0.6772
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 12.6080, Validation Loss: 11.4862, Test Loss: 23.8163
Training RMSE: 12.202577706075196, Validation RMSE: 11.0016, Test RMSE: 21.8434
Training PCC: 0.9185415947948027, Validation PCC: 0.9315, Test PCC: 0.6877
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 12.0163, Validation Loss: 11.1613, Test Loss: 23.2696
Training RMSE: 11.58988218053438, Validation RMSE: 10.6934, Test RMSE: 21.2000
Training PCC: 0.9287241634208591, Validation PCC: 0.9377, Test PCC: 0.7114
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 11.5894, Validation Loss: 11.2205, Test Loss: 22.0443
Training RMSE: 11.148955587412921, Validation RMSE: 10.6020, Test RMSE: 20.1999
Training PCC: 0.9351430822571332, Validation PCC: 0.9435, Test PCC: 0.7123
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 11.1658, Validation Loss: 10.3496, Test Loss: 19.5847
Training RMSE: 10.726892519220462, Validation RMSE: 9.8808, Test RMSE: 17.9465
Training PCC: 0.9402756988530436, Validation PCC: 0.9472, Test PCC: 0.7374
Checkpoint saved for epoch 10
Total training time: 1748.69 seconds
loading best model from TeacherModel_RMSELoss_test_subject_13_wl100_ol75_mamba
Test Loss: 19.5847, Test PCC: 0.7374, Test RMSE: 17.9465


In [189]:

average_best_rmse = np.mean(best_rmse_per_subject)
average_best_pcc = np.mean(best_pcc_per_subject)
print(f"Average of best RMSEs across all subjects: {average_best_rmse:.4f}")
print(f"Average of best PCCs across all subjects: {average_best_pcc:.4f}")
print(best_rmse_per_subject)
print(best_pcc_per_subject)

# subjects = [f'Subject {i+1}' for i in range(len(best_rmse_per_subject))]

# print(best_rmse_per_subject)
# # Plot a bar chart with subject labels on the x-axis
# plt.figure(figsize=(10, 6))
# plt.bar(subjects, best_rmse_per_subject, color='blue', edgecolor='black')
# plt.title('Best RMSEs for Each Subject')
# plt.xlabel('Subjects')
# plt.ylabel('Best RMSE')
# plt.xticks(rotation=45, ha='right')
# plt.grid(True, axis='y')
# plt.tight_layout()
# plt.show()

Average of best RMSEs across all subjects: 20.6979
Average of best PCCs across all subjects: 0.6945
[17.580608719962914, 22.43028281550042, 17.252830430435836, 66.034909504385, 24.47543561459504, 14.701816393190448, 15.832611321906532, 14.663643206914633, 11.083609932166619, 17.786915787786402, 14.939624728052642, 14.343490246365574, 17.94648080859403]
[0.6765217353521571, 0.6770613222830497, 0.769537046543718, 0.6067476795295964, 0.5410608806742143, 0.7242940039982152, 0.6829414836821375, 0.8041962350472401, 0.7774609055625277, 0.6050390339712423, 0.7073031036573799, 0.7186605077326216, 0.7373964854966407]


In [190]:
import os
import zipfile
from datetime import datetime

notebook_name = 'regression_benchmark_mamba'

# Create a timestamped folder name based on the notebook name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"{notebook_name}_checkpoints_{timestamp}"

# Make sure the folder exists
os.makedirs(folder_name, exist_ok=True)

checkpoint_dir = '.'

# Zip all checkpoint files and save in the new folder
zip_filename = f"{folder_name}.zip"
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # List files only in the current directory (no subfolders)
    for file in os.listdir(checkpoint_dir):
        if "TeacherModel" in str(file):
          file_path = os.path.join(checkpoint_dir, file)
          zipf.write(file_path, os.path.relpath(file_path, checkpoint_dir))
          print(f"Checkpoint {file} has been added to the zip file.")
print(f"All checkpoints have been zipped and saved as {zip_filename}.")




Checkpoint TeacherModel_RMSELoss_test_subject_11_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_3_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_13_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_12_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_7_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_9_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_5_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_6_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_10_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_8_wl100_ol75_mamba has been added to the zip file.
Checkpoint TeacherModel_RM

In [191]:
# Download the zip file to your local machine
from google.colab import files
files.download(zip_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>