In [1]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")



Mounted at /content/MyDrive


In [2]:
# @title Initialize Config

import torch
import numpy
class Config:
    def __init__(self, **kwargs):
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

#set seeds
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [3]:
class DataSharder:
    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]
                session_keys = list(subject_data.keys())  # Sessions for this subject

                for session_id in session_keys:
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id,sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        # Path to the CSV log file
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # Prepare CSV log headers (ensure the columns are 'file_name' and 'file_path')
        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Determine the total data length based on the minimum length across the data sources
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Adjust the starting point for windows based on total data length
            start = 2000 if total_data_length > 4000 else 0

            # Ensure that each window across imu_data, emg_data, and joint_data has the same shape before concatenation
            for i in range(start, total_data_length - window_size + 1, step_size):
                imu_window = imu_data[:, i:i + window_size]
                emg_window = emg_data[:, i:i + window_size]
                joint_window = joint_data[:, i:i + window_size]

                # Check if the window sizes are valid
                if imu_window.shape[1] == window_size and emg_window.shape[1] == window_size and joint_window.shape[1] == window_size:
                    # Convert windowed data to pandas DataFrame



                    imu_df = pd.DataFrame(imu_window.T, columns=imu_columns)
                    emg_df = pd.DataFrame(emg_window.T, columns=emg_columns)
                    joint_df = pd.DataFrame(joint_window.T, columns=joint_columns)



                    # Concatenate the data along the column axis
                    combined_df = pd.concat([imu_df, emg_df, joint_df], axis=1)

                    # Save the combined windowed data as a CSV file
                    file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                    file_path = os.path.join(dataset_folder, file_name)
                    combined_df.to_csv(file_path, index=False)

                    # Log the file name and path in the CSV (in the correct columns)
                    writer.writerow([file_name, file_path])
                else:
                    print(f"Skipping window {i} due to mismatched window sizes.")

    def _extract_channel_data(self, session_data, channels):
      extracted_data = []
      new_column_names = []  # Initialize here

      if isinstance(session_data, h5py.Dataset):
          if session_data.dtype.names:
              # Compound dataset
              column_names = session_data.dtype.names
              for channel in channels:
                  if channel in column_names:
                      channel_data = session_data[channel][:]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)  # Populate here
                  else:
                      print(f"Channel {channel} not found in compound dataset.")
          else:
              # Simple dataset
              column_names = list(session_data.attrs.get('column_names', []))
              assert len(column_names) > 0, "column_names not found in dataset attributes"
              for channel in channels:
                  if channel in column_names:
                      col_idx = column_names.index(channel)
                      channel_data = session_data[:, col_idx]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)
                  else:
                      print(f"Channel {channel} not found in session data.")

      return np.array(extracted_data), new_column_names


In [4]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Convert the list of subjects to a string that is path-safe
        subjects_str = "_".join(map(str, subjects)).replace('subject', '').replace('__', '_')

        # Use dataset_train_name or dataset_test_name based on split
        if split == 'train':
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
        else:
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"

        self.dataset_name = dataset_name

        # Define the root directory based on dataset name
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Ensure sharded data exists, if not, reshard
        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        info_path = os.path.join(self.root_dir, f"{split}_info.csv")
        self.data = pd.read_csv(info_path)

    def ensure_resharded(self, subjects, dataset_name):
        if not os.path.exists(self.root_dir):
            print(f"Sharded data not found at {self.root_dir}. Resharding...")
            data_sharder = DataSharder(self.config,self.split)
            # Pass dynamic parameters to sharder
            data_sharder.load_data(subjects, window_length=self.window_length, window_overlap=self.window_overlap, dataset_name=self.dataset_name)
        else:
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir,self.split, self.data.iloc[idx, 0])

        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        return torch.tensor(data, dtype=torch.float32)

class ImuJointPairSubjectDataset(ImuJointPairDataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Create a mapping from subject strings (e.g., 'subject_1') to class indices
        self.subject_mapping = {subject: i for i, subject in enumerate(sorted(subjects))}

    def __getitem__(self, idx):
        # Retrieve the original data from the parent class
        imu_data_acc, imu_data_gyr, joint_data, emg_data = super().__getitem__(idx)

        # Get the filename from the data index
        filename = self.data.iloc[idx, 0]

        # Extract the subject ID from the filename
        filename_base = os.path.basename(filename)
        filename_without_ext = os.path.splitext(filename_base)[0]
        parts = filename_without_ext.split('_')

        # Construct subject string in the format 'subject_x'
        try:
            subject_index = parts.index('subject')
            subject_str = f"subject_{parts[subject_index + 1]}"  # Create string like 'subject_1'
        except ValueError:
            raise ValueError(f"'subject' not found in filename: {filename}")
        except IndexError:
            raise ValueError(f"Subject ID not found after 'subject' in filename: {filename}")

        # Map subject_str to class index
        if subject_str not in self.subject_mapping:
            raise ValueError(f"Subject ID {subject_str} not found in training set.")

        mapped_class = self.subject_mapping[subject_str]

        # Return class index instead of one-hot encoding
        return imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class

class ImuJointPairSubjectNormalizedDataset(ImuJointPairSubjectDataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Compute normalization statistics per subject for IMU and EMG data
        self.normalization_stats = {subject: {'imu_acc': [], 'imu_gyr': [], 'emg': []} for subject in subjects}

        # Loop through the dataset once to gather data for each subject
        for idx in range(len(self.data)):
            filename = self.data.iloc[idx]['file_name']
            subject_str = next((subject for subject in subjects if subject in filename), None)
            if subject_str:
                imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class = super().__getitem__(idx)
                self.normalization_stats[subject_str]['imu_acc'].append(imu_data_acc)
                self.normalization_stats[subject_str]['imu_gyr'].append(imu_data_gyr)
                self.normalization_stats[subject_str]['emg'].append(emg_data)

        # Compute mean and std for each subject
        for subject, data in self.normalization_stats.items():
            if data['imu_acc']:
                imu_acc_data = torch.stack(data['imu_acc'])
                imu_gyr_data = torch.stack(data['imu_gyr'])
                emg_data = torch.stack(data['emg'])

                self.normalization_stats[subject] = {
                    'imu_acc_mean': imu_acc_data.mean(dim=0),
                    'imu_acc_std': imu_acc_data.std(dim=0),
                    'imu_gyr_mean': imu_gyr_data.mean(dim=0),
                    'imu_gyr_std': imu_gyr_data.std(dim=0),
                    'emg_mean': emg_data.mean(dim=0),
                    'emg_std': emg_data.std(dim=0)
                }

    def __getitem__(self, idx):
        # Retrieve the original data from the parent class
        imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class = super().__getitem__(idx)

        # Get the filename from the data index
        filename = self.data.iloc[idx]['file_name']

        # Extract the subject ID from the filename
        subject_str = next((subject for subject in self.normalization_stats.keys() if subject in filename), None)
        if not subject_str:
            raise ValueError(f"Normalization stats not found for subject in filename: {filename}")

        # Apply normalization for each subject separately
        stats = self.normalization_stats[subject_str]
        imu_data_acc = (imu_data_acc - stats['imu_acc_mean']) / (stats['imu_acc_std'] + 1e-8)
        imu_data_gyr = (imu_data_gyr - stats['imu_gyr_mean']) / (stats['imu_gyr_std'] + 1e-8)
        emg_data = (emg_data - stats['emg_mean']) / (stats['emg_std'] + 1e-8)

        # Return normalized IMU data, joint data, EMG data, and class index
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

def create_normbysub_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairSubjectNormalizedDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairSubjectNormalizedDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader




In [5]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_1, self).__init__()
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.lstm_2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, (h_1, _) = self.lstm_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, (h_2, _) = self.lstm_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)

class Encoder_2(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_2, self).__init__()
        self.gru_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.gru_2 = nn.GRU(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, h_1 = self.gru_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, h_2 = self.gru_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)


class GatingModule(nn.Module):
    def __init__(self, input_size):
        super(GatingModule, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(2*input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # Apply gating mechanism
        gate_output = self.gate(torch.cat((input1,input2),dim=-1))

        # Scale the inputs based on the gate output
        gated_input1 = input1 * gate_output
        gated_input2 = input2 * (1 - gate_output)

        # Combine the gated inputs
        output = gated_input1 + gated_input2
        return output
#variable w needs to be checked for correct value, stand-in value used
class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100):
        super(teacher, self).__init__()

        self.w=w
        self.encoder_1_acc=Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr=Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg=Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc=Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr=Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg=Encoder_2(input_emg, drop_prob)

        self.BN_acc= nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr= nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg= nn.BatchNorm1d(input_emg, affine=False)


        self.fc = nn.Linear(2*3*128+128,3)
        self.dropout=nn.Dropout(p=.05)

        self.gate_1=GatingModule(128)
        self.gate_2=GatingModule(128)
        self.gate_3=GatingModule(128)

        self.fc_kd = nn.Linear(3*128, 2*128)

               # Define the gating network
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention=nn.MultiheadAttention(3*128,4,batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128*3, 3*128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2*3*128+128, 2*3*128+128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)


    def forward(self, x_acc, x_gyr, x_emg):

        x_acc_1=x_acc.view(x_acc.size(0)*x_acc.size(1),x_acc.size(-1))
        x_gyr_1=x_gyr.view(x_gyr.size(0)*x_gyr.size(1),x_gyr.size(-1))
        x_emg_1=x_emg.view(x_emg.size(0)*x_emg.size(1),x_emg.size(-1))

        x_acc_1=self.BN_acc(x_acc_1)
        x_gyr_1=self.BN_gyr(x_gyr_1)
        x_emg_1=self.BN_emg(x_emg_1)

        x_acc_2=x_acc_1.view(-1, self.w, x_acc_1.size(-1))
        x_gyr_2=x_gyr_1.view(-1, self.w, x_gyr_1.size(-1))
        x_emg_2=x_emg_1.view(-1, self.w, x_emg_1.size(-1))

        # Pass through Encoder 1 for each modality and capture hidden states
        x_acc_1, (h_acc_1, _) = self.encoder_1_acc(x_acc_2)
        x_gyr_1, (h_gyr_1, _) = self.encoder_1_gyr(x_gyr_2)
        x_emg_1, (h_emg_1, _) = self.encoder_1_emg(x_emg_2)

        # Pass through Encoder 2 for each modality and capture hidden states
        x_acc_2, (h_acc_2, _) = self.encoder_2_acc(x_acc_2)
        x_gyr_2, (h_gyr_2, _) = self.encoder_2_gyr(x_gyr_2)
        x_emg_2, (h_emg_2, _) = self.encoder_2_emg(x_emg_2)

        # x_acc=torch.cat((x_acc_1,x_acc_2),dim=-1)
        # x_gyr=torch.cat((x_gyr_1,x_gyr_2),dim=-1)
        # x_emg=torch.cat((x_emg_1,x_emg_2),dim=-1)

        x_acc=self.gate_1(x_acc_1,x_acc_2)
        x_gyr=self.gate_2(x_gyr_1,x_gyr_2)
        x_emg=self.gate_3(x_emg_1,x_emg_2)

        x=torch.cat((x_acc,x_gyr,x_emg),dim=-1)
        x_kd=self.fc_kd(x)


        out_1, attn_output_weights=self.attention(x,x,x)

        gating_weights = self.gating_net(x)
        out_2=gating_weights*x

        weights_1 = self.weighted_feat(x[:,:,0:128])
        weights_2 = self.weighted_feat(x[:,:,128:2*128])
        weights_3 = self.weighted_feat(x[:,:,2*128:3*128])
        x_1=weights_1*x[:,:,0:128]
        x_2=weights_2*x[:,:,128:2*128]
        x_3=weights_3*x[:,:,2*128:3*128]
        out_3=x_1+x_2+x_3

        out=torch.cat((out_1,out_2,out_3),dim=-1)

        gating_weights_1 = self.gating_net_1(out)
        out=gating_weights_1*out

        out=self.fc(out)

        #print(out.shape)
        return out, x_kd, (h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)




In [6]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss

#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7

In [7]:





# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(loader):
            output= model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())

            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())
            else:
                loss = criterion(output, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename, loss_function, optimizer=None, l1_lambda=None, train_from_last_epoch=False, curriculum_loader=None):
    model.to(device)
    criterion = loss_function

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    train_losses = []
    val_losses = []
    test_losses = []

    train_pccs = []
    val_pccs = []
    test_pccs = []

    train_rmses = []
    val_rmses = []
    test_rmses = []

    train_pccs_channelwise = []
    val_pccs_channelwise = []
    test_pccs_channelwise = []

    train_rmses_channelwise = []
    val_rmses_channelwise = []
    test_rmses_channelwise = []

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        # Scan for the latest saved checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch number
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(os.path.join(checkpoint_path, latest_checkpoint))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            last_epoch = checkpoint['epoch']  # Continue from the next epoch

            # Load the history from checkpoint
            train_losses = checkpoint['history']['train_losses']
            val_losses = checkpoint['history']['val_losses']
            test_losses = checkpoint['history']['test_losses']
            train_pccs = checkpoint['history']['train_pccs']
            val_pccs = checkpoint['history']['val_pccs']
            test_pccs = checkpoint['history']['test_pccs']
            train_rmses = checkpoint['history']['train_rmses']
            val_rmses = checkpoint['history']['val_rmses']
            test_rmses = checkpoint['history']['test_rmses']
            train_pccs_channelwise = checkpoint['history']['train_pccs_channelwise']
            val_pccs_channelwise = checkpoint['history']['val_pccs_channelwise']
            test_pccs_channelwise = checkpoint['history']['test_pccs_channelwise']
            train_rmses_channelwise = checkpoint['history']['train_rmses_channelwise']
            val_rmses_channelwise = checkpoint['history']['val_rmses_channelwise']
            test_rmses_channelwise = checkpoint['history']['test_rmses_channelwise']
            if 'curriculum_schedule' in checkpoint:
                curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']  # Load saved curriculum schedule
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(last_epoch, epochs):
        epoch_start_time = time.time()
        model.train()

        if curriculum_loader:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()
        # Track metrics per channel
        epoch_train_loss = np.zeros(len(config.channels_joints))
        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        # Use epoch starting from `epoch + 1` since we want to reflect actual starting epoch correctly
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            # Ensure inputs are properly sent to device and are of correct type
            output = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())


            # Check if output is a tuple, take the first element if true
            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())

            else:
                loss = criterion(output, target.to(device).float())

            # Compute loss



            # Apply L1 regularization if specified
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss = loss + l1_lambda * l1_norm
            else:
                total_loss = loss

            # Backpropagate the gradients for total_loss
            total_loss.backward()
            optimizer.step()

            # Detach tensors and move to CPU to prevent issues with gradient computation
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)

            # Accumulate loss, pcc, and rmse without modifying in-place
            epoch_train_loss += loss.detach().cpu().numpy()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_loss)
        train_pccs.append(np.mean(avg_train_pcc))  # Overall average PCC
        train_rmses.append(np.mean(avg_train_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        train_pccs_channelwise.append(avg_train_pcc)  # Per channel
        train_rmses_channelwise.append(avg_train_rmse)  # Per channel

        # Evaluate on validation set every epoch
        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(device, model, val_loader, criterion)
        val_losses.append(avg_val_loss)
        val_pccs.append(np.mean(avg_val_pcc))  # Overall average PCC
        val_rmses.append(np.mean(avg_val_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        val_pccs_channelwise.append(avg_val_pcc)  # Per channel
        val_rmses_channelwise.append(avg_val_rmse)  # Per channel

        # Evaluate on test set and checkpoint every epoch
        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(device, model, test_loader, criterion)
        test_losses.append(avg_test_loss)
        test_pccs.append(np.mean(avg_test_pcc))  # Overall average PCC
        test_rmses.append(np.mean(avg_test_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        test_pccs_channelwise.append(avg_test_pcc)  # Per channel
        test_rmses_channelwise.append(avg_test_rmse)  # Per channel

        print(f"Epoch: {epoch + 1}, Training Loss: {np.mean(avg_train_loss):.4f}, Validation Loss: {np.mean(avg_val_loss):.4f}, Test Loss: {np.mean(avg_test_loss):.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse)}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc)}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Save checkpoint, including curriculum schedule
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

        # Save checkpoint with the curriculum schedule
        history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'test_losses': test_losses,
            'train_pccs': train_pccs,
            'val_pccs': val_pccs,
            'test_pccs': test_pccs,
            'train_rmses': train_rmses,
            'val_rmses': val_rmses,
            'test_rmses': test_rmses,
            'train_pccs_channelwise': train_pccs_channelwise,
            'val_pccs_channelwise': val_pccs_channelwise,
            'test_pccs_channelwise': test_pccs_channelwise,
            'train_rmses_channelwise': train_rmses_channelwise,
            'val_rmses_channelwise': val_rmses_channelwise,
            'test_rmses_channelwise': test_rmses_channelwise
        }

        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"{checkpoint_path}/{filename}_epoch_{epoch + 1}.pth",
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            test_loss=avg_test_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=history,  # Save history in the checkpoint
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader else None # Save curriculum schedule
        )

        # Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    print(f"loading best model from {filename}")
    model.load_state_dict(torch.load(filename))
    model.eval()
    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses






In [8]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100):
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w)

    if base_weights_path:
        # Load the initial weights from the base model
        model.load_state_dict(torch.load(base_weights_path))

    return model




In [10]:
import matplotlib.pyplot as plt
import numpy as np
import os
import h5py
from tqdm.notebook import tqdm
import pandas as pd
import csv

all_subjects= [f"subject_{x}" for x in range(1,14)]
input_acc, input_gyr, input_emg = 18,18,3
batch_size = 64

# Placeholder for storing best RMSEs
best_rmse_per_subject = []
best_pcc_per_subject = []

train_flag = True

for test_subject in all_subjects:



    print(f"Running training with {test_subject} as the test subject.")

    # Set up the training subjects (all except the test subject)
    train_subjects = [subject for subject in all_subjects if subject != test_subject]

    model_name = f'TeacherModel_RMSELoss_test_{test_subject}_wl{100}_ol{75}_nbs'
    print(f"Model: {model_name}")

    # Load the model configuration and data loaders
    model_config = {
        'model': create_teacher_model(input_acc, input_gyr, input_emg, w=100),
        'loss': RMSELoss(),
        'loaders': create_normbysub_data_loaders(
            config=config,
            train_subjects=train_subjects,
            test_subjects=[test_subject],
            window_length=100,
            window_overlap=75,
            batch_size=batch_size
        ),
        'epochs': 10,
        'use_curriculum': False
    }

    model = model_config['model']
    loss_function = model_config['loss']
    epochs = model_config.get("epochs", 10)
    device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    learn_rate = model_config.get("learn_rate", 0.001)
    use_curriculum = model_config.get("use_curriculum", False)

    optimizer = model_config.get("optimizer", None)
    l1_lambda = model_config.get("l1_lambda", None)

    print(f"Running model: {model_name}")

    # Unpack the static loaders tuple (train_loader, val_loader, test_loader)
    train_loader, val_loader, test_loader = model_config['loaders']
    if train_flag:
    # Train the model and save only the best based on validation loss
      model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses = train_teacher(
          device=device,
          train_loader=train_loader,
          val_loader=val_loader,
          test_loader=test_loader,
          learn_rate=learn_rate,
          epochs=epochs,
          model=model,
          filename=model_name,
          loss_function=loss_function,
          optimizer=optimizer,
          l1_lambda=l1_lambda,
          train_from_last_epoch=False
      )
    else:
      #load filename as model
      model.load_state_dict(torch.load(f"{model_name}"))
      model.to(device)
      model.eval()

     #run model on test set and record result
    test_loss, test_pcc, test_rmse = evaluate_model(device, model, test_loader, loss_function)
    print(f"Test Loss: {test_loss:.4f}, Test PCC: {np.mean(test_pcc):.4f}, Test RMSE: {np.mean(test_rmse):.4f}")
    best_rmse_per_subject.append(np.mean(test_rmse))
    best_pcc_per_subject.append(np.mean(test_pcc))


# Compute the average of the best RMSEs across all subjects



Running training with subject_1 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_nbs
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Running model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.3837, Validation Loss: 11.4677, Test Loss: 19.6647
Training RMSE: 17.923453913794624, Validation RMSE: 11.1536, Test RMSE: 18.4427
Training PCC: 0.7795172480551508, Validation PCC: 0.9348, Test PCC: 0.7276
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.5535, Validation Loss: 9.6994, Test Loss: 18.7943
Training RMSE: 10.162407036099838, Validation RMSE: 9.4379, Test RMSE: 17.7213
Training PCC: 0.9487944373877856, Validation PCC: 0.9585, Test PCC: 0.7202
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.1339, Validation Loss: 8.1025, Test Loss: 17.6800
Training RMSE: 8.753677084783773, Validation RMSE: 7.7328, Test RMSE: 16.8822
Training PCC: 0.9626632280357367, Validation PCC: 0.9691, Test PCC: 0.7198
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.3693, Validation Loss: 7.8998, Test Loss: 18.5053
Training RMSE: 8.013367149321917, Validation RMSE: 7.5491, Test RMSE: 17.4003
Training PCC: 0.9691295733218949, Validation PCC: 0.9705, Test PCC: 0.6962
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.8663, Validation Loss: 7.2238, Test Loss: 18.2681
Training RMSE: 7.550067661005422, Validation RMSE: 6.8916, Test RMSE: 17.1778
Training PCC: 0.9724744160538696, Validation PCC: 0.9761, Test PCC: 0.7357
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.6078, Validation Loss: 7.9314, Test Loss: 18.4032
Training RMSE: 7.300032620522347, Validation RMSE: 7.4789, Test RMSE: 17.1338
Training PCC: 0.9741477494387233, Validation PCC: 0.9752, Test PCC: 0.6896
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.1507, Validation Loss: 6.5500, Test Loss: 17.5532
Training RMSE: 6.8792543629987515, Validation RMSE: 6.2929, Test RMSE: 16.6427
Training PCC: 0.9769603349707957, Validation PCC: 0.9800, Test PCC: 0.7398
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.8009, Validation Loss: 6.5808, Test Loss: 19.0875
Training RMSE: 6.555846831363518, Validation RMSE: 6.2569, Test RMSE: 18.0657
Training PCC: 0.9791145572331906, Validation PCC: 0.9805, Test PCC: 0.7111
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.4090, Validation Loss: 5.9227, Test Loss: 17.7505
Training RMSE: 6.19981415653326, Validation RMSE: 5.7385, Test RMSE: 16.9196
Training PCC: 0.9809327086968352, Validation PCC: 0.9823, Test PCC: 0.7338
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1628, Validation Loss: 6.2399, Test Loss: 17.7246
Training RMSE: 5.976288616110426, Validation RMSE: 5.9892, Test RMSE: 16.6325
Training PCC: 0.9820000358039706, Validation PCC: 0.9821, Test PCC: 0.7256
Checkpoint saved for epoch 10
Total training time: 1198.16 seconds
loading best model from TeacherModel_RMSELoss_test_subject_1_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 17.7505, Test PCC: 0.7338, Test RMSE: 16.9196
Running training with subject_2 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_2. Resharding...
Processing subjects: ['subject_2'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_2/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_2/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.1840, Validation Loss: 10.1188, Test Loss: 22.8702
Training RMSE: 16.801580617582896, Validation RMSE: 9.7865, Test RMSE: 20.9691
Training PCC: 0.7932094775891602, Validation PCC: 0.9507, Test PCC: 0.6652
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.1493, Validation Loss: 7.9504, Test Loss: 20.8881
Training RMSE: 8.848384091039984, Validation RMSE: 7.6531, Test RMSE: 19.2781
Training PCC: 0.95958041273005, Validation PCC: 0.9731, Test PCC: 0.6891
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 7.7379, Validation Loss: 7.1508, Test Loss: 21.4151
Training RMSE: 7.423291139970949, Validation RMSE: 6.7470, Test RMSE: 19.8458
Training PCC: 0.9720422810193533, Validation PCC: 0.9777, Test PCC: 0.6849
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.1622, Validation Loss: 6.4580, Test Loss: 21.8066
Training RMSE: 6.854311111981307, Validation RMSE: 6.0785, Test RMSE: 20.2106
Training PCC: 0.9768110764378052, Validation PCC: 0.9817, Test PCC: 0.6992
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 6.6053, Validation Loss: 6.1411, Test Loss: 21.6047
Training RMSE: 6.31286633886942, Validation RMSE: 5.6671, Test RMSE: 20.0994
Training PCC: 0.9799659804893704, Validation PCC: 0.9835, Test PCC: 0.7008
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.1764, Validation Loss: 5.7359, Test Loss: 22.1485
Training RMSE: 5.914845664569033, Validation RMSE: 5.3688, Test RMSE: 20.5985
Training PCC: 0.9820889696751268, Validation PCC: 0.9860, Test PCC: 0.6855
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 5.9132, Validation Loss: 5.7468, Test Loss: 21.7536
Training RMSE: 5.678296060581516, Validation RMSE: 5.4147, Test RMSE: 20.1526
Training PCC: 0.9834814681930326, Validation PCC: 0.9858, Test PCC: 0.6870
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 5.6623, Validation Loss: 5.5181, Test Loss: 21.5547
Training RMSE: 5.4310058047616385, Validation RMSE: 5.2161, Test RMSE: 19.9207
Training PCC: 0.9846942238819661, Validation PCC: 0.9867, Test PCC: 0.7032
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 5.5564, Validation Loss: 5.2828, Test Loss: 21.7202
Training RMSE: 5.331639875968297, Validation RMSE: 4.9321, Test RMSE: 20.1051
Training PCC: 0.9853066003585994, Validation PCC: 0.9873, Test PCC: 0.6838
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.3389, Validation Loss: 4.9161, Test Loss: 21.4450
Training RMSE: 5.135647154920469, Validation RMSE: 4.6183, Test RMSE: 19.8857
Training PCC: 0.9862588935299659, Validation PCC: 0.9882, Test PCC: 0.6990
Checkpoint saved for epoch 10
Total training time: 1204.27 seconds
loading best model from TeacherModel_RMSELoss_test_subject_2_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 21.4450, Test PCC: 0.6990, Test RMSE: 19.8857
Running training with subject_3 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_3. Resharding...
Processing subjects: ['subject_3'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_3/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_3/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.6841, Validation Loss: 10.6311, Test Loss: 22.6564
Training RMSE: 17.290915740699305, Validation RMSE: 10.3268, Test RMSE: 20.4041
Training PCC: 0.7843785398606516, Validation PCC: 0.9428, Test PCC: 0.7349
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.8017, Validation Loss: 8.9436, Test Loss: 23.5558
Training RMSE: 9.526494042175573, Validation RMSE: 8.6555, Test RMSE: 21.0430
Training PCC: 0.9530052092478288, Validation PCC: 0.9620, Test PCC: 0.7458
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.4524, Validation Loss: 7.4818, Test Loss: 21.8911
Training RMSE: 8.186364677378803, Validation RMSE: 7.2841, Test RMSE: 19.7117
Training PCC: 0.9662800962707635, Validation PCC: 0.9701, Test PCC: 0.7445
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.7151, Validation Loss: 7.0895, Test Loss: 22.1037
Training RMSE: 7.466399034833519, Validation RMSE: 6.8753, Test RMSE: 19.8031
Training PCC: 0.9716938377326961, Validation PCC: 0.9745, Test PCC: 0.7608
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.3510, Validation Loss: 6.7237, Test Loss: 22.3736
Training RMSE: 7.116561163974002, Validation RMSE: 6.5448, Test RMSE: 19.9019
Training PCC: 0.9742851196643806, Validation PCC: 0.9763, Test PCC: 0.7496
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.8067, Validation Loss: 6.7469, Test Loss: 23.5358
Training RMSE: 6.603069258414632, Validation RMSE: 6.5360, Test RMSE: 21.1340
Training PCC: 0.9776426407749047, Validation PCC: 0.9774, Test PCC: 0.7465
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.6104, Validation Loss: 6.2432, Test Loss: 22.5423
Training RMSE: 6.417054516028583, Validation RMSE: 6.0861, Test RMSE: 20.2090
Training PCC: 0.9792961228157598, Validation PCC: 0.9781, Test PCC: 0.7462
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.2725, Validation Loss: 6.0996, Test Loss: 21.9548
Training RMSE: 6.103024017762363, Validation RMSE: 5.9342, Test RMSE: 19.8051
Training PCC: 0.9809160988648368, Validation PCC: 0.9807, Test PCC: 0.7740
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 5.9481, Validation Loss: 5.7539, Test Loss: 22.2968
Training RMSE: 5.789765610685194, Validation RMSE: 5.6148, Test RMSE: 20.0595
Training PCC: 0.9827133433348005, Validation PCC: 0.9821, Test PCC: 0.7606
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.7125, Validation Loss: 5.3180, Test Loss: 22.0883
Training RMSE: 5.566206937640662, Validation RMSE: 5.1930, Test RMSE: 19.9035
Training PCC: 0.9839502338991423, Validation PCC: 0.9844, Test PCC: 0.7605
Checkpoint saved for epoch 10
Total training time: 1219.32 seconds
loading best model from TeacherModel_RMSELoss_test_subject_3_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 22.0883, Test PCC: 0.7605, Test RMSE: 19.9035
Running training with subject_4 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_4. Resharding...
Processing subjects: ['subject_4'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_4/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_4/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.1537, Validation Loss: 10.8093, Test Loss: 16.6561
Training RMSE: 17.67273755467266, Validation RMSE: 10.4117, Test RMSE: 15.9315
Training PCC: 0.785772250046068, Validation PCC: 0.9416, Test PCC: 0.6911
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.4054, Validation Loss: 8.7172, Test Loss: 16.7972
Training RMSE: 9.966683606002917, Validation RMSE: 8.2910, Test RMSE: 16.1151
Training PCC: 0.9512907412412988, Validation PCC: 0.9622, Test PCC: 0.7228
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.9721, Validation Loss: 7.9476, Test Loss: 16.1502
Training RMSE: 8.554499271812789, Validation RMSE: 7.6123, Test RMSE: 15.3306
Training PCC: 0.9647997961429254, Validation PCC: 0.9694, Test PCC: 0.7368
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.2213, Validation Loss: 7.2484, Test Loss: 15.8886
Training RMSE: 7.838912253715212, Validation RMSE: 6.8747, Test RMSE: 15.2872
Training PCC: 0.9708695204666848, Validation PCC: 0.9746, Test PCC: 0.7338
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.7116, Validation Loss: 6.8388, Test Loss: 16.3667
Training RMSE: 7.358146236577166, Validation RMSE: 6.4666, Test RMSE: 15.5997
Training PCC: 0.9739986715259942, Validation PCC: 0.9778, Test PCC: 0.7538
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.2400, Validation Loss: 6.9122, Test Loss: 16.0253
Training RMSE: 6.9096636779446845, Validation RMSE: 6.5742, Test RMSE: 14.9775
Training PCC: 0.9771130802209527, Validation PCC: 0.9790, Test PCC: 0.7598
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.8385, Validation Loss: 6.6421, Test Loss: 15.2079
Training RMSE: 6.549366607578523, Validation RMSE: 6.2768, Test RMSE: 14.6699
Training PCC: 0.9794804980125754, Validation PCC: 0.9797, Test PCC: 0.7723
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.6387, Validation Loss: 6.2287, Test Loss: 14.9521
Training RMSE: 6.366862543739917, Validation RMSE: 5.9014, Test RMSE: 14.3120
Training PCC: 0.9802085232731582, Validation PCC: 0.9819, Test PCC: 0.7603
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.4188, Validation Loss: 6.2916, Test Loss: 14.4766
Training RMSE: 6.150504468048749, Validation RMSE: 6.0113, Test RMSE: 13.8103
Training PCC: 0.9817158929584219, Validation PCC: 0.9818, Test PCC: 0.7757
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1665, Validation Loss: 5.9451, Test Loss: 15.1340
Training RMSE: 5.928943914740093, Validation RMSE: 5.6916, Test RMSE: 14.4446
Training PCC: 0.9827669710878654, Validation PCC: 0.9841, Test PCC: 0.7931
Checkpoint saved for epoch 10
Total training time: 1220.17 seconds
loading best model from TeacherModel_RMSELoss_test_subject_4_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 15.1340, Test PCC: 0.7931, Test RMSE: 14.4446
Running training with subject_5 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_5. Resharding...
Processing subjects: ['subject_5'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_5/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_5/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.8338, Validation Loss: 10.8144, Test Loss: 21.6340
Training RMSE: 17.357356354852456, Validation RMSE: 10.4815, Test RMSE: 20.3588
Training PCC: 0.795947452613729, Validation PCC: 0.9401, Test PCC: 0.6604
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.4254, Validation Loss: 9.3328, Test Loss: 21.8931
Training RMSE: 10.01056041387, Validation RMSE: 9.0316, Test RMSE: 20.5369
Training PCC: 0.9499392864290579, Validation PCC: 0.9611, Test PCC: 0.6904
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.1837, Validation Loss: 8.2747, Test Loss: 23.3267
Training RMSE: 8.78773599638244, Validation RMSE: 7.9021, Test RMSE: 21.9924
Training PCC: 0.9621142316436034, Validation PCC: 0.9697, Test PCC: 0.6818
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.4024, Validation Loss: 7.3600, Test Loss: 21.5057
Training RMSE: 8.014077089613002, Validation RMSE: 7.0430, Test RMSE: 20.5211
Training PCC: 0.9688202945310834, Validation PCC: 0.9740, Test PCC: 0.7002
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.8523, Validation Loss: 6.7754, Test Loss: 22.7567
Training RMSE: 7.499486443221144, Validation RMSE: 6.5251, Test RMSE: 21.5339
Training PCC: 0.9725047717394047, Validation PCC: 0.9771, Test PCC: 0.6919
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.4708, Validation Loss: 6.7676, Test Loss: 22.5177
Training RMSE: 7.151358647934646, Validation RMSE: 6.4954, Test RMSE: 21.1305
Training PCC: 0.9745719006796447, Validation PCC: 0.9782, Test PCC: 0.6945
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.0739, Validation Loss: 6.4681, Test Loss: 21.3843
Training RMSE: 6.77808033921789, Validation RMSE: 6.1914, Test RMSE: 20.4249
Training PCC: 0.9771956341963794, Validation PCC: 0.9799, Test PCC: 0.7238
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.8134, Validation Loss: 5.9539, Test Loss: 23.1898
Training RMSE: 6.551992025093444, Validation RMSE: 5.7335, Test RMSE: 22.0950
Training PCC: 0.9784194686005206, Validation PCC: 0.9822, Test PCC: 0.6683
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.6182, Validation Loss: 6.1761, Test Loss: 23.4080
Training RMSE: 6.358350757186687, Validation RMSE: 5.9371, Test RMSE: 22.5071
Training PCC: 0.9797661720834566, Validation PCC: 0.9820, Test PCC: 0.6859
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.2597, Validation Loss: 6.0129, Test Loss: 21.7334
Training RMSE: 6.028833291338611, Validation RMSE: 5.7075, Test RMSE: 21.0498
Training PCC: 0.981406308309798, Validation PCC: 0.9842, Test PCC: 0.6916
Checkpoint saved for epoch 10
Total training time: 1219.43 seconds
loading best model from TeacherModel_RMSELoss_test_subject_5_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 23.1898, Test PCC: 0.6683, Test RMSE: 22.0950
Running training with subject_6 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_6. Resharding...
Processing subjects: ['subject_6'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_6/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_6/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.0838, Validation Loss: 10.3271, Test Loss: 15.4432
Training RMSE: 17.67329964405153, Validation RMSE: 10.0385, Test RMSE: 14.4438
Training PCC: 0.7812889461655318, Validation PCC: 0.9438, Test PCC: 0.7761
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.5781, Validation Loss: 8.4279, Test Loss: 14.4611
Training RMSE: 10.199906885139342, Validation RMSE: 8.1506, Test RMSE: 13.5613
Training PCC: 0.9470395690358537, Validation PCC: 0.9653, Test PCC: 0.7895
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.0997, Validation Loss: 8.0882, Test Loss: 14.7140
Training RMSE: 8.728824497238408, Validation RMSE: 7.6783, Test RMSE: 13.6939
Training PCC: 0.9616721459757341, Validation PCC: 0.9713, Test PCC: 0.7954
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.4595, Validation Loss: 7.6767, Test Loss: 12.9661
Training RMSE: 8.105072735770932, Validation RMSE: 7.2496, Test RMSE: 12.2096
Training PCC: 0.9678826396617305, Validation PCC: 0.9742, Test PCC: 0.7998
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.9288, Validation Loss: 6.9784, Test Loss: 14.3000
Training RMSE: 7.592511240059768, Validation RMSE: 6.6102, Test RMSE: 13.3884
Training PCC: 0.9715586962527323, Validation PCC: 0.9784, Test PCC: 0.8101
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.4195, Validation Loss: 6.7868, Test Loss: 14.4779
Training RMSE: 7.114121187993182, Validation RMSE: 6.5262, Test RMSE: 13.3717
Training PCC: 0.9748184283160285, Validation PCC: 0.9791, Test PCC: 0.7950
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.3674, Validation Loss: 6.7020, Test Loss: 13.3634
Training RMSE: 7.080453058083852, Validation RMSE: 6.3151, Test RMSE: 12.4962
Training PCC: 0.9750770476912485, Validation PCC: 0.9810, Test PCC: 0.8018
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.8560, Validation Loss: 6.2898, Test Loss: 14.1394
Training RMSE: 6.610079710076495, Validation RMSE: 5.9458, Test RMSE: 13.1116
Training PCC: 0.9781188614432286, Validation PCC: 0.9816, Test PCC: 0.7923
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.4516, Validation Loss: 6.2584, Test Loss: 13.5006
Training RMSE: 6.227177020495501, Validation RMSE: 5.9503, Test RMSE: 12.5790
Training PCC: 0.9802115854452046, Validation PCC: 0.9825, Test PCC: 0.8103
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.2260, Validation Loss: 5.8552, Test Loss: 13.7492
Training RMSE: 6.019450961574307, Validation RMSE: 5.5851, Test RMSE: 12.5044
Training PCC: 0.9812100518721282, Validation PCC: 0.9838, Test PCC: 0.8156
Checkpoint saved for epoch 10
Total training time: 1219.43 seconds
loading best model from TeacherModel_RMSELoss_test_subject_6_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 13.7492, Test PCC: 0.8156, Test RMSE: 12.5044
Running training with subject_7 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_7. Resharding...
Processing subjects: ['subject_7'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_7/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_7/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.7286, Validation Loss: 11.2449, Test Loss: 17.7669
Training RMSE: 17.255300210258827, Validation RMSE: 10.6954, Test RMSE: 16.3246
Training PCC: 0.7916980670889481, Validation PCC: 0.9430, Test PCC: 0.7719
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.5597, Validation Loss: 9.8413, Test Loss: 16.7084
Training RMSE: 10.14804539060205, Validation RMSE: 9.3777, Test RMSE: 15.4922
Training PCC: 0.9481710164621585, Validation PCC: 0.9592, Test PCC: 0.7982
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.1940, Validation Loss: 8.2022, Test Loss: 16.6291
Training RMSE: 8.795251429565553, Validation RMSE: 7.8627, Test RMSE: 15.5598
Training PCC: 0.9611194348429913, Validation PCC: 0.9687, Test PCC: 0.7864
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5063, Validation Loss: 7.3824, Test Loss: 15.2846
Training RMSE: 8.127605839473445, Validation RMSE: 7.0258, Test RMSE: 14.3638
Training PCC: 0.9674652753461421, Validation PCC: 0.9751, Test PCC: 0.8023
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.8782, Validation Loss: 7.1293, Test Loss: 15.4829
Training RMSE: 7.532352921197084, Validation RMSE: 6.7679, Test RMSE: 14.6097
Training PCC: 0.9715651999574076, Validation PCC: 0.9767, Test PCC: 0.7851
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.5303, Validation Loss: 6.8132, Test Loss: 15.4966
Training RMSE: 7.218414695524589, Validation RMSE: 6.4725, Test RMSE: 14.6362
Training PCC: 0.9741864712442124, Validation PCC: 0.9785, Test PCC: 0.7699
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.2452, Validation Loss: 6.2548, Test Loss: 15.9175
Training RMSE: 6.953105025902026, Validation RMSE: 5.9772, Test RMSE: 15.0623
Training PCC: 0.9759884743287652, Validation PCC: 0.9808, Test PCC: 0.7763
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.7992, Validation Loss: 6.2371, Test Loss: 14.7339
Training RMSE: 6.5498907946474185, Validation RMSE: 6.0255, Test RMSE: 13.8298
Training PCC: 0.9783409757502411, Validation PCC: 0.9815, Test PCC: 0.7891
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5036, Validation Loss: 5.9766, Test Loss: 15.7414
Training RMSE: 6.272605590704011, Validation RMSE: 5.7065, Test RMSE: 14.6723
Training PCC: 0.9800011554130266, Validation PCC: 0.9826, Test PCC: 0.7893
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.2529, Validation Loss: 5.9685, Test Loss: 15.9977
Training RMSE: 6.0432999877910305, Validation RMSE: 5.7353, Test RMSE: 14.7836
Training PCC: 0.9813810105879632, Validation PCC: 0.9834, Test PCC: 0.7925
Checkpoint saved for epoch 10
Total training time: 1208.25 seconds
loading best model from TeacherModel_RMSELoss_test_subject_7_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 15.9977, Test PCC: 0.7925, Test RMSE: 14.7836
Running training with subject_8 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_8. Resharding...
Processing subjects: ['subject_8'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_8/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_8/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.0725, Validation Loss: 11.3041, Test Loss: 13.1543
Training RMSE: 17.588849153460526, Validation RMSE: 10.8690, Test RMSE: 12.1174
Training PCC: 0.7907705582662001, Validation PCC: 0.9410, Test PCC: 0.8224
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.6330, Validation Loss: 9.9474, Test Loss: 13.7876
Training RMSE: 10.215224464734396, Validation RMSE: 9.3776, Test RMSE: 12.9431
Training PCC: 0.9487282716877431, Validation PCC: 0.9600, Test PCC: 0.8357
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.2920, Validation Loss: 8.5021, Test Loss: 13.3465
Training RMSE: 8.877032818833017, Validation RMSE: 8.0403, Test RMSE: 12.6407
Training PCC: 0.9620554452965834, Validation PCC: 0.9691, Test PCC: 0.8387
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5112, Validation Loss: 8.9782, Test Loss: 14.1780
Training RMSE: 8.136490797608849, Validation RMSE: 8.5836, Test RMSE: 13.5460
Training PCC: 0.9681490839334774, Validation PCC: 0.9725, Test PCC: 0.8289
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.9880, Validation Loss: 7.6180, Test Loss: 13.9844
Training RMSE: 7.62924477724525, Validation RMSE: 7.1934, Test RMSE: 13.3134
Training PCC: 0.9720296186054996, Validation PCC: 0.9764, Test PCC: 0.8228
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.5677, Validation Loss: 6.9313, Test Loss: 14.0527
Training RMSE: 7.230314601485323, Validation RMSE: 6.6161, Test RMSE: 13.4226
Training PCC: 0.9750256279738964, Validation PCC: 0.9790, Test PCC: 0.8004
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.2049, Validation Loss: 6.6832, Test Loss: 12.9937
Training RMSE: 6.893785413687791, Validation RMSE: 6.3679, Test RMSE: 12.2880
Training PCC: 0.97720407283383, Validation PCC: 0.9803, Test PCC: 0.8366
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.8286, Validation Loss: 7.0321, Test Loss: 14.3928
Training RMSE: 6.557434469219146, Validation RMSE: 6.6965, Test RMSE: 13.6018
Training PCC: 0.9790147896954121, Validation PCC: 0.9817, Test PCC: 0.8257
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5541, Validation Loss: 6.5795, Test Loss: 14.1621
Training RMSE: 6.300254892042982, Validation RMSE: 6.3200, Test RMSE: 13.3708
Training PCC: 0.9806005095484623, Validation PCC: 0.9822, Test PCC: 0.8242
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.3808, Validation Loss: 6.0079, Test Loss: 13.1628
Training RMSE: 6.150580497291998, Validation RMSE: 5.7648, Test RMSE: 12.3819
Training PCC: 0.9812579410613577, Validation PCC: 0.9839, Test PCC: 0.8217
Checkpoint saved for epoch 10
Total training time: 1203.97 seconds
loading best model from TeacherModel_RMSELoss_test_subject_8_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 13.1628, Test PCC: 0.8217, Test RMSE: 12.3819
Running training with subject_9 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_9. Resharding...
Processing subjects: ['subject_9'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_9/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_9/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.1646, Validation Loss: 10.9959, Test Loss: 15.4355
Training RMSE: 17.684431041643872, Validation RMSE: 10.6484, Test RMSE: 14.9658
Training PCC: 0.7930272456889383, Validation PCC: 0.9394, Test PCC: 0.7043
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.5874, Validation Loss: 8.8304, Test Loss: 14.0743
Training RMSE: 10.176530063152313, Validation RMSE: 8.4936, Test RMSE: 13.5951
Training PCC: 0.9495181853937035, Validation PCC: 0.9627, Test PCC: 0.7468
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.3409, Validation Loss: 8.5284, Test Loss: 13.3047
Training RMSE: 8.95678939034299, Validation RMSE: 8.1609, Test RMSE: 12.5972
Training PCC: 0.9624474895899043, Validation PCC: 0.9671, Test PCC: 0.7687
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5633, Validation Loss: 8.0893, Test Loss: 13.6905
Training RMSE: 8.185432993784184, Validation RMSE: 7.7608, Test RMSE: 12.9320
Training PCC: 0.9680218986775463, Validation PCC: 0.9726, Test PCC: 0.7610
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.9254, Validation Loss: 7.6637, Test Loss: 12.7625
Training RMSE: 7.601224972949765, Validation RMSE: 7.3296, Test RMSE: 12.3676
Training PCC: 0.9721846057549987, Validation PCC: 0.9736, Test PCC: 0.7573
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.6368, Validation Loss: 7.2645, Test Loss: 12.8392
Training RMSE: 7.321398336955202, Validation RMSE: 6.8103, Test RMSE: 12.3859
Training PCC: 0.9737678269072614, Validation PCC: 0.9772, Test PCC: 0.7524
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.2361, Validation Loss: 6.8166, Test Loss: 12.6619
Training RMSE: 6.948029769629966, Validation RMSE: 6.4887, Test RMSE: 12.1353
Training PCC: 0.976955426083448, Validation PCC: 0.9793, Test PCC: 0.7748
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 7.1395, Validation Loss: 6.7176, Test Loss: 13.0824
Training RMSE: 6.870806780045595, Validation RMSE: 6.4358, Test RMSE: 12.4215
Training PCC: 0.9769414352182676, Validation PCC: 0.9783, Test PCC: 0.7622
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.7931, Validation Loss: 6.6354, Test Loss: 13.2425
Training RMSE: 6.53877958776505, Validation RMSE: 6.2704, Test RMSE: 12.7306
Training PCC: 0.9795360807616779, Validation PCC: 0.9821, Test PCC: 0.7758
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.3258, Validation Loss: 6.4683, Test Loss: 13.8934
Training RMSE: 6.110457265522421, Validation RMSE: 6.1342, Test RMSE: 13.1804
Training PCC: 0.9818259391544032, Validation PCC: 0.9822, Test PCC: 0.7471
Checkpoint saved for epoch 10
Total training time: 1200.82 seconds
loading best model from TeacherModel_RMSELoss_test_subject_9_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 13.8934, Test PCC: 0.7471, Test RMSE: 13.1804
Running training with subject_10 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_10. Resharding...
Processing subjects: ['subject_10'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_10/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_10/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.1128, Validation Loss: 11.2789, Test Loss: 14.7266
Training RMSE: 17.626015535699644, Validation RMSE: 10.9036, Test RMSE: 13.9086
Training PCC: 0.7906772973012172, Validation PCC: 0.9427, Test PCC: 0.7771
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.4735, Validation Loss: 8.9943, Test Loss: 15.5707
Training RMSE: 10.059330721696218, Validation RMSE: 8.5822, Test RMSE: 14.3808
Training PCC: 0.9501394973712153, Validation PCC: 0.9650, Test PCC: 0.7430
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.2088, Validation Loss: 8.1474, Test Loss: 15.2187
Training RMSE: 8.813030270056997, Validation RMSE: 7.6978, Test RMSE: 14.4621
Training PCC: 0.9624911719871466, Validation PCC: 0.9711, Test PCC: 0.7367
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5111, Validation Loss: 7.7002, Test Loss: 16.0560
Training RMSE: 8.141218680676404, Validation RMSE: 7.2613, Test RMSE: 15.0281
Training PCC: 0.9682723486960126, Validation PCC: 0.9746, Test PCC: 0.7729
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 8.0108, Validation Loss: 7.7813, Test Loss: 15.7780
Training RMSE: 7.6545336171378935, Validation RMSE: 7.4254, Test RMSE: 14.6016
Training PCC: 0.9715214574577152, Validation PCC: 0.9764, Test PCC: 0.7710
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.5927, Validation Loss: 6.9318, Test Loss: 16.4361
Training RMSE: 7.269204784457277, Validation RMSE: 6.5770, Test RMSE: 15.5500
Training PCC: 0.9743271854051446, Validation PCC: 0.9791, Test PCC: 0.7439
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.2917, Validation Loss: 6.7867, Test Loss: 16.7921
Training RMSE: 6.990203679334827, Validation RMSE: 6.5159, Test RMSE: 15.8141
Training PCC: 0.9764020711887035, Validation PCC: 0.9798, Test PCC: 0.7389
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.9512, Validation Loss: 6.3241, Test Loss: 16.2471
Training RMSE: 6.6839085095297035, Validation RMSE: 6.0622, Test RMSE: 15.0677
Training PCC: 0.9779608769578437, Validation PCC: 0.9819, Test PCC: 0.7442
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5366, Validation Loss: 5.9275, Test Loss: 15.6165
Training RMSE: 6.315844179653539, Validation RMSE: 5.6962, Test RMSE: 14.6404
Training PCC: 0.9801705116386903, Validation PCC: 0.9837, Test PCC: 0.7594
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.2536, Validation Loss: 5.9647, Test Loss: 16.1088
Training RMSE: 6.053063072809359, Validation RMSE: 5.7461, Test RMSE: 15.1042
Training PCC: 0.9816097218902256, Validation PCC: 0.9828, Test PCC: 0.7407
Checkpoint saved for epoch 10
Total training time: 1209.95 seconds
loading best model from TeacherModel_RMSELoss_test_subject_10_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 15.6165, Test PCC: 0.7594, Test RMSE: 14.6404
Running training with subject_11 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_11. Resharding...
Processing subjects: ['subject_11'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_11/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_11/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.8332, Validation Loss: 11.7552, Test Loss: 13.0317
Training RMSE: 17.355173749167744, Validation RMSE: 11.2974, Test RMSE: 12.5615
Training PCC: 0.7910202190635677, Validation PCC: 0.9329, Test PCC: 0.7552
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.5999, Validation Loss: 9.2961, Test Loss: 14.5268
Training RMSE: 10.17019217624897, Validation RMSE: 8.8550, Test RMSE: 13.5568
Training PCC: 0.9475956686438044, Validation PCC: 0.9576, Test PCC: 0.7801
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.1499, Validation Loss: 8.6214, Test Loss: 14.0467
Training RMSE: 8.755210925893085, Validation RMSE: 8.2023, Test RMSE: 13.3179
Training PCC: 0.9619909059387685, Validation PCC: 0.9655, Test PCC: 0.7766
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.4933, Validation Loss: 7.9287, Test Loss: 14.2980
Training RMSE: 8.125058543875936, Validation RMSE: 7.5507, Test RMSE: 13.3793
Training PCC: 0.9677576302623957, Validation PCC: 0.9704, Test PCC: 0.7624
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.9769, Validation Loss: 7.6928, Test Loss: 13.5858
Training RMSE: 7.6298062803784035, Validation RMSE: 7.2608, Test RMSE: 12.8457
Training PCC: 0.9710120865269637, Validation PCC: 0.9739, Test PCC: 0.7527
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.5220, Validation Loss: 7.3035, Test Loss: 14.4124
Training RMSE: 7.198909044992632, Validation RMSE: 6.9291, Test RMSE: 13.4888
Training PCC: 0.9740993450014205, Validation PCC: 0.9748, Test PCC: 0.7949
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.1969, Validation Loss: 7.1185, Test Loss: 13.2865
Training RMSE: 6.896622096135364, Validation RMSE: 6.8386, Test RMSE: 12.5033
Training PCC: 0.9761349954909041, Validation PCC: 0.9743, Test PCC: 0.7985
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.9531, Validation Loss: 6.7114, Test Loss: 13.5279
Training RMSE: 6.683941884980938, Validation RMSE: 6.4977, Test RMSE: 12.5805
Training PCC: 0.9774104054183965, Validation PCC: 0.9787, Test PCC: 0.8019
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5673, Validation Loss: 6.6010, Test Loss: 12.8955
Training RMSE: 6.329130231849547, Validation RMSE: 6.2620, Test RMSE: 12.0855
Training PCC: 0.9794135954242278, Validation PCC: 0.9803, Test PCC: 0.8055
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1775, Validation Loss: 6.1394, Test Loss: 13.7118
Training RMSE: 5.9800070613864955, Validation RMSE: 5.9102, Test RMSE: 12.7725
Training PCC: 0.9817624768415228, Validation PCC: 0.9806, Test PCC: 0.8018
Checkpoint saved for epoch 10
Total training time: 1203.36 seconds
loading best model from TeacherModel_RMSELoss_test_subject_11_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 13.7118, Test PCC: 0.8018, Test RMSE: 12.7725
Running training with subject_12 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_12. Resharding...
Processing subjects: ['subject_12'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_12/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_12/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.7395, Validation Loss: 11.0964, Test Loss: 12.5204
Training RMSE: 18.23733099931624, Validation RMSE: 10.7633, Test RMSE: 12.1408
Training PCC: 0.7766924944679635, Validation PCC: 0.9414, Test PCC: 0.7484
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.7499, Validation Loss: 9.0721, Test Loss: 12.9734
Training RMSE: 10.313013070482548, Validation RMSE: 8.6773, Test RMSE: 12.4912
Training PCC: 0.9479935667540428, Validation PCC: 0.9630, Test PCC: 0.7081
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.4472, Validation Loss: 7.9058, Test Loss: 12.5613
Training RMSE: 9.025429150922513, Validation RMSE: 7.5770, Test RMSE: 12.0074
Training PCC: 0.9610111581690571, Validation PCC: 0.9717, Test PCC: 0.7370
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5870, Validation Loss: 7.5299, Test Loss: 13.6333
Training RMSE: 8.1982946449179, Validation RMSE: 7.1420, Test RMSE: 12.9254
Training PCC: 0.9680891324478109, Validation PCC: 0.9756, Test PCC: 0.7725
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 8.0202, Validation Loss: 6.9704, Test Loss: 11.9671
Training RMSE: 7.662986198091896, Validation RMSE: 6.6884, Test RMSE: 11.4740
Training PCC: 0.9718880194749362, Validation PCC: 0.9773, Test PCC: 0.7943
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.6666, Validation Loss: 6.5347, Test Loss: 12.4877
Training RMSE: 7.33779827123735, Validation RMSE: 6.1877, Test RMSE: 11.8651
Training PCC: 0.9743981031974546, Validation PCC: 0.9812, Test PCC: 0.7773
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.2754, Validation Loss: 6.5934, Test Loss: 11.9326
Training RMSE: 6.978141870682801, Validation RMSE: 6.3117, Test RMSE: 11.5351
Training PCC: 0.9762698802733495, Validation PCC: 0.9806, Test PCC: 0.7527
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.8890, Validation Loss: 5.9060, Test Loss: 12.0966
Training RMSE: 6.612795064362083, Validation RMSE: 5.6581, Test RMSE: 11.5460
Training PCC: 0.9788060005214891, Validation PCC: 0.9840, Test PCC: 0.7767
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5655, Validation Loss: 5.9356, Test Loss: 11.9715
Training RMSE: 6.31864827338273, Validation RMSE: 5.6741, Test RMSE: 11.5151
Training PCC: 0.9804049700037699, Validation PCC: 0.9836, Test PCC: 0.7889
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.2799, Validation Loss: 5.5507, Test Loss: 12.1898
Training RMSE: 6.053573657342088, Validation RMSE: 5.2820, Test RMSE: 11.5790
Training PCC: 0.9820155380239847, Validation PCC: 0.9860, Test PCC: 0.7906
Checkpoint saved for epoch 10
Total training time: 1211.58 seconds
loading best model from TeacherModel_RMSELoss_test_subject_12_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 12.1898, Test PCC: 0.7906, Test RMSE: 11.5790
Running training with subject_13 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_nbs
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_13. Resharding...
Processing subjects: ['subject_13'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_13/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_13/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_nbs
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 18.0149, Validation Loss: 12.6978, Test Loss: 20.3414
Training RMSE: 17.540719718467898, Validation RMSE: 11.9016, Test RMSE: 18.0952
Training PCC: 0.7845362711929349, Validation PCC: 0.9348, Test PCC: 0.7276
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 10.6523, Validation Loss: 9.1070, Test Loss: 17.9898
Training RMSE: 10.24279171373786, Validation RMSE: 8.6834, Test RMSE: 16.3978
Training PCC: 0.9463324473608538, Validation PCC: 0.9620, Test PCC: 0.7451
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 9.0569, Validation Loss: 8.0154, Test Loss: 16.9066
Training RMSE: 8.65307690554518, Validation RMSE: 7.6338, Test RMSE: 15.5319
Training PCC: 0.9615866956832285, Validation PCC: 0.9694, Test PCC: 0.7377
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.5092, Validation Loss: 8.1044, Test Loss: 16.7942
Training RMSE: 8.125935962529686, Validation RMSE: 7.7733, Test RMSE: 15.5626
Training PCC: 0.9674686543989965, Validation PCC: 0.9705, Test PCC: 0.7416
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.8736, Validation Loss: 7.2553, Test Loss: 17.2052
Training RMSE: 7.509485402000628, Validation RMSE: 6.9432, Test RMSE: 15.7063
Training PCC: 0.9717091052969179, Validation PCC: 0.9757, Test PCC: 0.7514
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.5625, Validation Loss: 7.0452, Test Loss: 16.9700
Training RMSE: 7.240835145237, Validation RMSE: 6.7456, Test RMSE: 15.7841
Training PCC: 0.973602815922054, Validation PCC: 0.9764, Test PCC: 0.7458
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 7.1523, Validation Loss: 6.4192, Test Loss: 16.2018
Training RMSE: 6.85240993291382, Validation RMSE: 6.1455, Test RMSE: 14.8099
Training PCC: 0.9764860868140351, Validation PCC: 0.9807, Test PCC: 0.7687
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.7872, Validation Loss: 6.3651, Test Loss: 16.3844
Training RMSE: 6.517724390194668, Validation RMSE: 6.0777, Test RMSE: 15.0794
Training PCC: 0.9782826805387727, Validation PCC: 0.9817, Test PCC: 0.7545
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.5429, Validation Loss: 6.2098, Test Loss: 15.6395
Training RMSE: 6.294749651497942, Validation RMSE: 5.9578, Test RMSE: 14.0543
Training PCC: 0.9795216979545351, Validation PCC: 0.9817, Test PCC: 0.7735
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.3070, Validation Loss: 5.9378, Test Loss: 15.3641
Training RMSE: 6.072289077247063, Validation RMSE: 5.7168, Test RMSE: 14.1287
Training PCC: 0.9809508175113438, Validation PCC: 0.9830, Test PCC: 0.7664
Checkpoint saved for epoch 10
Total training time: 1205.11 seconds
loading best model from TeacherModel_RMSELoss_test_subject_13_wl100_ol75_nbs


  model.load_state_dict(torch.load(filename))


Test Loss: 15.3641, Test PCC: 0.7664, Test RMSE: 14.1287


In [11]:

average_best_rmse = np.mean(best_rmse_per_subject)
average_best_pcc = np.mean(best_pcc_per_subject)
print(f"Average of best RMSEs across all subjects: {average_best_rmse:.4f}")
print(f"Average of best PCCs across all subjects: {average_best_pcc:.4f}")
print(best_rmse_per_subject)
print(best_pcc_per_subject)

# subjects = [f'Subject {i+1}' for i in range(len(best_rmse_per_subject))]

# print(best_rmse_per_subject)
# # Plot a bar chart with subject labels on the x-axis
# plt.figure(figsize=(10, 6))
# plt.bar(subjects, best_rmse_per_subject, color='blue', edgecolor='black')
# plt.title('Best RMSEs for Each Subject')
# plt.xlabel('Subjects')
# plt.ylabel('Best RMSE')
# plt.xticks(rotation=45, ha='right')
# plt.grid(True, axis='y')
# plt.tight_layout()
# plt.show()

Average of best RMSEs across all subjects: 15.3246
Average of best PCCs across all subjects: 0.7654
[16.919614732265472, 19.885711590449016, 19.903529405593872, 14.444613258043924, 22.095018724600475, 12.504369894663492, 14.783570726712545, 12.38186796506246, 13.180358548959097, 14.640406886736551, 12.772493461767832, 11.57896512746811, 14.128747244675955]
[0.7337576336799977, 0.6990383187474953, 0.7604733728134505, 0.793063701447651, 0.6682668441038441, 0.8156133774056302, 0.7924584902594888, 0.8217004172583385, 0.7470915954404665, 0.7594021283850347, 0.8018022254598905, 0.7905668618099767, 0.7664172683285053]


In [12]:
import os
import zipfile
from datetime import datetime

notebook_name = 'regression_benchmark_normalizebysubject'

# Create a timestamped folder name based on the notebook name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"{notebook_name}_checkpoints_{timestamp}"

# Make sure the folder exists
os.makedirs(folder_name, exist_ok=True)

checkpoint_dir = '.'

# Zip all checkpoint files and save in the new folder
zip_filename = f"{folder_name}.zip"
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # List files only in the current directory (no subfolders)
    for file in os.listdir(checkpoint_dir):
        if "TeacherModel" in str(file):
          file_path = os.path.join(checkpoint_dir, file)
          zipf.write(file_path, os.path.relpath(file_path, checkpoint_dir))
          print(f"Checkpoint {file} has been added to the zip file.")
print(f"All checkpoints have been zipped and saved as {zip_filename}.")




Checkpoint TeacherModel_RMSELoss_test_subject_3_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_6_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_13_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_5_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_2_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_8_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_4_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_10_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_11_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_1_wl100_ol75_nbs has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_1

In [13]:
#copy zip file into google drive
import shutil

destination_path = '/content/MyDrive/MyDrive/models'

shutil.copy(zip_filename, destination_path)

'/content/MyDrive/MyDrive/models/regression_benchmark_normalizebysubject_checkpoints_20241024_114113.zip'