In [1]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [2]:
# @title Initialize Config

import torch
import numpy
class Config:
    def __init__(self, **kwargs):
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

#set seeds
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [3]:
class DataSharder:
    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]
                session_keys = list(subject_data.keys())  # Sessions for this subject

                for session_id in session_keys:
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id,sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        # Path to the CSV log file
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # Prepare CSV log headers (ensure the columns are 'file_name' and 'file_path')
        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Determine the total data length based on the minimum length across the data sources
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Adjust the starting point for windows based on total data length
            start = 2000 if total_data_length > 4000 else 0

            # Ensure that each window across imu_data, emg_data, and joint_data has the same shape before concatenation
            for i in range(start, total_data_length - window_size + 1, step_size):
                imu_window = imu_data[:, i:i + window_size]
                emg_window = emg_data[:, i:i + window_size]
                joint_window = joint_data[:, i:i + window_size]

                # Check if the window sizes are valid
                if imu_window.shape[1] == window_size and emg_window.shape[1] == window_size and joint_window.shape[1] == window_size:
                    # Convert windowed data to pandas DataFrame



                    imu_df = pd.DataFrame(imu_window.T, columns=imu_columns)
                    emg_df = pd.DataFrame(emg_window.T, columns=emg_columns)
                    joint_df = pd.DataFrame(joint_window.T, columns=joint_columns)



                    # Concatenate the data along the column axis
                    combined_df = pd.concat([imu_df, emg_df, joint_df], axis=1)

                    # Save the combined windowed data as a CSV file
                    file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                    file_path = os.path.join(dataset_folder, file_name)
                    combined_df.to_csv(file_path, index=False)

                    # Log the file name and path in the CSV (in the correct columns)
                    writer.writerow([file_name, file_path])
                else:
                    print(f"Skipping window {i} due to mismatched window sizes.")

    def _extract_channel_data(self, session_data, channels):
      extracted_data = []
      new_column_names = []  # Initialize here

      if isinstance(session_data, h5py.Dataset):
          if session_data.dtype.names:
              # Compound dataset
              column_names = session_data.dtype.names
              for channel in channels:
                  if channel in column_names:
                      channel_data = session_data[channel][:]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)  # Populate here
                  else:
                      print(f"Channel {channel} not found in compound dataset.")
          else:
              # Simple dataset
              column_names = list(session_data.attrs.get('column_names', []))
              assert len(column_names) > 0, "column_names not found in dataset attributes"
              for channel in channels:
                  if channel in column_names:
                      col_idx = column_names.index(channel)
                      channel_data = session_data[:, col_idx]
                      channel_data = pd.to_numeric(channel_data, errors='coerce')
                      df = pd.DataFrame(channel_data)
                      df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                      extracted_data.append(df_interpolated.to_numpy().flatten())
                      new_column_names.append(channel)
                  else:
                      print(f"Channel {channel} not found in session data.")

      return np.array(extracted_data), new_column_names


In [4]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Convert the list of subjects to a string that is path-safe
        subjects_str = "_".join(map(str, subjects)).replace('subject', '').replace('__', '_')

        # Use dataset_train_name or dataset_test_name based on split
        if split == 'train':
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
        else:
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"

        self.dataset_name = dataset_name

        # Define the root directory based on dataset name
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Ensure sharded data exists, if not, reshard
        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        info_path = os.path.join(self.root_dir, f"{split}_info.csv")
        self.data = pd.read_csv(info_path)

    def ensure_resharded(self, subjects, dataset_name):
        if not os.path.exists(self.root_dir):
            print(f"Sharded data not found at {self.root_dir}. Resharding...")
            data_sharder = DataSharder(self.config,self.split)
            # Pass dynamic parameters to sharder
            data_sharder.load_data(subjects, window_length=self.window_length, window_overlap=self.window_overlap, dataset_name=self.dataset_name)
        else:
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir,self.split, self.data.iloc[idx, 0])

        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        return torch.tensor(data, dtype=torch.float32)

def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader



In [5]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_1, self).__init__()
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.lstm_2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, (h_1, _) = self.lstm_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, (h_2, _) = self.lstm_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)

class Encoder_2(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_2, self).__init__()
        self.gru_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.gru_2 = nn.GRU(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, h_1 = self.gru_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, h_2 = self.gru_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)


class GatingModule(nn.Module):
    def __init__(self, input_size):
        super(GatingModule, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(2*input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # Apply gating mechanism
        gate_output = self.gate(torch.cat((input1,input2),dim=-1))

        # Scale the inputs based on the gate output
        gated_input1 = input1 * gate_output
        gated_input2 = input2 * (1 - gate_output)

        # Combine the gated inputs
        output = gated_input1 + gated_input2
        return output
#variable w needs to be checked for correct value, stand-in value used
class FeatureModulationLayer(nn.Module):
    def __init__(self, num_features):
        super(FeatureModulationLayer, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        modulation_weights = self.gate(x.mean(dim=1))  # Compute gate per feature
        return x * modulation_weights.unsqueeze(1)

class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, num_subjects=12, drop_prob=0.25, w=100):
        super(teacher, self).__init__()

        self.w = w
        self.encoder_1_acc = Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr = Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg = Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc = Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr = Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg = Encoder_2(input_emg, drop_prob)

        # Gating and modulation layers
        self.gate_1 = GatingModule(128)
        self.gate_2 = GatingModule(128)
        self.gate_3 = GatingModule(128)

        self.modulation_acc = FeatureModulationLayer(128)
        self.modulation_gyr = FeatureModulationLayer(128)
        self.modulation_emg = FeatureModulationLayer(128)

        # Combine the features after modulation
        self.fc_kd = nn.Linear(3 * 128, 256)  # Updated dimension after concatenation
        self.attention = nn.MultiheadAttention(3 * 128, 4, batch_first=True)

        # Final output layer after attention
        self.fc_final = nn.Linear(3 * 128, 3)

    def forward(self, x_acc, x_gyr, x_emg):
        # Reshape inputs for LSTM layers
        x_acc = x_acc.view(-1, self.w, x_acc.size(-1))
        x_gyr = x_gyr.view(-1, self.w, x_gyr.size(-1))
        x_emg = x_emg.view(-1, self.w, x_emg.size(-1))

        # Encoding
        x_acc_1, _ = self.encoder_1_acc(x_acc)
        x_gyr_1, _ = self.encoder_1_gyr(x_gyr)
        x_emg_1, _ = self.encoder_1_emg(x_emg)

        x_acc_2, _ = self.encoder_2_acc(x_acc)
        x_gyr_2, _ = self.encoder_2_gyr(x_gyr)
        x_emg_2, _ = self.encoder_2_emg(x_emg)

        # Gating and modulation
        x_acc = self.modulation_acc(self.gate_1(x_acc_1, x_acc_2))
        x_gyr = self.modulation_gyr(self.gate_2(x_gyr_1, x_gyr_2))
        x_emg = self.modulation_emg(self.gate_3(x_emg_1, x_emg_2))

        # Concatenate features
        combined_features = torch.cat((x_acc, x_gyr, x_emg), dim=-1)

        # Apply linear transformation for knowledge distillation
        x_kd = self.fc_kd(combined_features)

        # Attention on combined features
        attn_out, _ = self.attention(combined_features, combined_features, combined_features)

        # Final output
        final_out = self.fc_final(attn_out)

        return final_out, x_kd, (x_acc, x_gyr, x_emg)




In [6]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss

#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7

In [7]:





# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(loader):
            output= model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())

            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())
            else:
                loss = criterion(output, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename, loss_function, optimizer=None, l1_lambda=None, train_from_last_epoch=False, curriculum_loader=None):
    model.to(device)
    criterion = loss_function

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    train_losses = []
    val_losses = []
    test_losses = []

    train_pccs = []
    val_pccs = []
    test_pccs = []

    train_rmses = []
    val_rmses = []
    test_rmses = []

    train_pccs_channelwise = []
    val_pccs_channelwise = []
    test_pccs_channelwise = []

    train_rmses_channelwise = []
    val_rmses_channelwise = []
    test_rmses_channelwise = []

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        # Scan for the latest saved checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch number
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(os.path.join(checkpoint_path, latest_checkpoint))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            last_epoch = checkpoint['epoch']  # Continue from the next epoch

            # Load the history from checkpoint
            train_losses = checkpoint['history']['train_losses']
            val_losses = checkpoint['history']['val_losses']
            test_losses = checkpoint['history']['test_losses']
            train_pccs = checkpoint['history']['train_pccs']
            val_pccs = checkpoint['history']['val_pccs']
            test_pccs = checkpoint['history']['test_pccs']
            train_rmses = checkpoint['history']['train_rmses']
            val_rmses = checkpoint['history']['val_rmses']
            test_rmses = checkpoint['history']['test_rmses']
            train_pccs_channelwise = checkpoint['history']['train_pccs_channelwise']
            val_pccs_channelwise = checkpoint['history']['val_pccs_channelwise']
            test_pccs_channelwise = checkpoint['history']['test_pccs_channelwise']
            train_rmses_channelwise = checkpoint['history']['train_rmses_channelwise']
            val_rmses_channelwise = checkpoint['history']['val_rmses_channelwise']
            test_rmses_channelwise = checkpoint['history']['test_rmses_channelwise']
            if 'curriculum_schedule' in checkpoint:
                curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']  # Load saved curriculum schedule
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(last_epoch, epochs):
        epoch_start_time = time.time()
        model.train()

        if curriculum_loader:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()
        # Track metrics per channel
        epoch_train_loss = np.zeros(len(config.channels_joints))
        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        # Use epoch starting from `epoch + 1` since we want to reflect actual starting epoch correctly
        for i, (data_acc, data_gyr, target, data_EMG) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            # Ensure inputs are properly sent to device and are of correct type
            output = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())


            # Check if output is a tuple, take the first element if true
            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())

            else:
                loss = criterion(output, target.to(device).float())

            # Compute loss



            # Apply L1 regularization if specified
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss = loss + l1_lambda * l1_norm
            else:
                total_loss = loss

            # Backpropagate the gradients for total_loss
            total_loss.backward()
            optimizer.step()

            # Detach tensors and move to CPU to prevent issues with gradient computation
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)

            # Accumulate loss, pcc, and rmse without modifying in-place
            epoch_train_loss += loss.detach().cpu().numpy()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_loss)
        train_pccs.append(np.mean(avg_train_pcc))  # Overall average PCC
        train_rmses.append(np.mean(avg_train_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        train_pccs_channelwise.append(avg_train_pcc)  # Per channel
        train_rmses_channelwise.append(avg_train_rmse)  # Per channel

        # Evaluate on validation set every epoch
        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(device, model, val_loader, criterion)
        val_losses.append(avg_val_loss)
        val_pccs.append(np.mean(avg_val_pcc))  # Overall average PCC
        val_rmses.append(np.mean(avg_val_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        val_pccs_channelwise.append(avg_val_pcc)  # Per channel
        val_rmses_channelwise.append(avg_val_rmse)  # Per channel

        # Evaluate on test set and checkpoint every epoch
        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(device, model, test_loader, criterion)
        test_losses.append(avg_test_loss)
        test_pccs.append(np.mean(avg_test_pcc))  # Overall average PCC
        test_rmses.append(np.mean(avg_test_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        test_pccs_channelwise.append(avg_test_pcc)  # Per channel
        test_rmses_channelwise.append(avg_test_rmse)  # Per channel

        print(f"Epoch: {epoch + 1}, Training Loss: {np.mean(avg_train_loss):.4f}, Validation Loss: {np.mean(avg_val_loss):.4f}, Test Loss: {np.mean(avg_test_loss):.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse)}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc)}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Save checkpoint, including curriculum schedule
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

        # Save checkpoint with the curriculum schedule
        history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'test_losses': test_losses,
            'train_pccs': train_pccs,
            'val_pccs': val_pccs,
            'test_pccs': test_pccs,
            'train_rmses': train_rmses,
            'val_rmses': val_rmses,
            'test_rmses': test_rmses,
            'train_pccs_channelwise': train_pccs_channelwise,
            'val_pccs_channelwise': val_pccs_channelwise,
            'test_pccs_channelwise': test_pccs_channelwise,
            'train_rmses_channelwise': train_rmses_channelwise,
            'val_rmses_channelwise': val_rmses_channelwise,
            'test_rmses_channelwise': test_rmses_channelwise
        }

        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"{checkpoint_path}/{filename}_epoch_{epoch + 1}.pth",
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            test_loss=avg_test_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=history,  # Save history in the checkpoint
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader else None # Save curriculum schedule
        )

        # Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    print(f"loading best model from {filename}")
    model.load_state_dict(torch.load(filename))
    model.eval()
    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses






In [8]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100):
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w)

    if base_weights_path:
        # Load the initial weights from the base model
        model.load_state_dict(torch.load(base_weights_path))

    return model




In [9]:
import matplotlib.pyplot as plt
import numpy as np
import os
import h5py
from tqdm.notebook import tqdm
import pandas as pd
import csv

all_subjects= [f"subject_{x}" for x in range(1,14)]
input_acc, input_gyr, input_emg = 18,18,3
batch_size = 64

# Placeholder for storing best RMSEs
best_rmse_per_subject = []
best_pcc_per_subject = []

train_flag = True

for test_subject in all_subjects:



    print(f"Running training with {test_subject} as the test subject.")

    # Set up the training subjects (all except the test subject)
    train_subjects = [subject for subject in all_subjects if subject != test_subject]

    model_name = f'TeacherModel_RMSELoss_test_{test_subject}_wl{100}_ol{75}_fml'
    print(f"Model: {model_name}")

    # Load the model configuration and data loaders
    model_config = {
        'model': create_teacher_model(input_acc, input_gyr, input_emg, w=100),
        'loss': RMSELoss(),
        'loaders': create_base_data_loaders(
            config=config,
            train_subjects=train_subjects,
            test_subjects=[test_subject],
            window_length=100,
            window_overlap=75,
            batch_size=batch_size
        ),
        'epochs': 10,
        'use_curriculum': False
    }

    model = model_config['model']
    loss_function = model_config['loss']
    epochs = model_config.get("epochs", 10)
    device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    learn_rate = model_config.get("learn_rate", 0.001)
    use_curriculum = model_config.get("use_curriculum", False)

    optimizer = model_config.get("optimizer", None)
    l1_lambda = model_config.get("l1_lambda", None)

    print(f"Running model: {model_name}")

    # Unpack the static loaders tuple (train_loader, val_loader, test_loader)
    train_loader, val_loader, test_loader = model_config['loaders']
    if train_flag:
    # Train the model and save only the best based on validation loss
      model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses = train_teacher(
          device=device,
          train_loader=train_loader,
          val_loader=val_loader,
          test_loader=test_loader,
          learn_rate=learn_rate,
          epochs=epochs,
          model=model,
          filename=model_name,
          loss_function=loss_function,
          optimizer=optimizer,
          l1_lambda=l1_lambda,
          train_from_last_epoch=False
      )
    else:
      #load filename as model
      model.load_state_dict(torch.load(f"{model_name}"))
      model.to(device)
      model.eval()

     #run model on test set and record result
    test_loss, test_pcc, test_rmse = evaluate_model(device, model, test_loader, loss_function)
    print(f"Test Loss: {test_loss:.4f}, Test PCC: {np.mean(test_pcc):.4f}, Test RMSE: {np.mean(test_rmse):.4f}")
    best_rmse_per_subject.append(np.mean(test_rmse))
    best_pcc_per_subject.append(np.mean(test_pcc))


# Compute the average of the best RMSEs across all subjects



Running training with subject_1 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_fml
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Running model: TeacherModel_RMSELoss_test_subject_1_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.7218, Validation Loss: 9.9827, Test Loss: 18.8491
Training RMSE: 16.201962782580797, Validation RMSE: 9.5688, Test RMSE: 17.7982
Training PCC: 0.7928811520174484, Validation PCC: 0.9540, Test PCC: 0.7010
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.2856, Validation Loss: 8.8162, Test Loss: 17.4984
Training RMSE: 8.933117323573091, Validation RMSE: 8.4102, Test RMSE: 16.4867
Training PCC: 0.9592637410272215, Validation PCC: 0.9654, Test PCC: 0.7183
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.3691, Validation Loss: 8.6640, Test Loss: 16.4681
Training RMSE: 8.031567322733935, Validation RMSE: 8.2373, Test RMSE: 15.3573
Training PCC: 0.9676091305820392, Validation PCC: 0.9679, Test PCC: 0.7290
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.8340, Validation Loss: 7.7031, Test Loss: 16.5533
Training RMSE: 7.524432252792529, Validation RMSE: 7.3477, Test RMSE: 15.8314
Training PCC: 0.9716469639997474, Validation PCC: 0.9745, Test PCC: 0.7007
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.3499, Validation Loss: 7.2482, Test Loss: 15.3033
Training RMSE: 7.084087616320658, Validation RMSE: 6.9424, Test RMSE: 14.7029
Training PCC: 0.9748119621682144, Validation PCC: 0.9767, Test PCC: 0.7323
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.9012, Validation Loss: 7.0933, Test Loss: 16.6621
Training RMSE: 6.668188396457746, Validation RMSE: 6.7967, Test RMSE: 15.8759
Training PCC: 0.9773009057961164, Validation PCC: 0.9776, Test PCC: 0.7244
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.6821, Validation Loss: 6.9665, Test Loss: 16.2835
Training RMSE: 6.463762757245919, Validation RMSE: 6.6494, Test RMSE: 15.6275
Training PCC: 0.9786077749509118, Validation PCC: 0.9778, Test PCC: 0.7307
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.4121, Validation Loss: 6.8356, Test Loss: 15.9518
Training RMSE: 6.212079316225743, Validation RMSE: 6.5263, Test RMSE: 15.1648
Training PCC: 0.9802438951932224, Validation PCC: 0.9802, Test PCC: 0.7293
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1180, Validation Loss: 6.3450, Test Loss: 15.9996
Training RMSE: 5.932381156509197, Validation RMSE: 6.0792, Test RMSE: 15.2688
Training PCC: 0.9818642583480633, Validation PCC: 0.9816, Test PCC: 0.7371
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.9416, Validation Loss: 6.4194, Test Loss: 16.3606
Training RMSE: 5.764863700555613, Validation RMSE: 6.2452, Test RMSE: 15.7759
Training PCC: 0.9829534616659926, Validation PCC: 0.9825, Test PCC: 0.7134
Checkpoint saved for epoch 10
Total training time: 1180.98 seconds
loading best model from TeacherModel_RMSELoss_test_subject_1_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 15.9996, Test PCC: 0.7371, Test RMSE: 15.2688
Running training with subject_2 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_3_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_2. Resharding...
Processing subjects: ['subject_2'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_2/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_2/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_2_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 15.6154, Validation Loss: 8.5250, Test Loss: 21.8146
Training RMSE: 15.15053095614038, Validation RMSE: 8.1537, Test RMSE: 20.1676
Training PCC: 0.8091639911078893, Validation PCC: 0.9644, Test PCC: 0.6829
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 8.0207, Validation Loss: 7.5088, Test Loss: 20.8491
Training RMSE: 7.720260450995066, Validation RMSE: 7.2052, Test RMSE: 19.2540
Training PCC: 0.9678737892835304, Validation PCC: 0.9746, Test PCC: 0.6830
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 7.1869, Validation Loss: 6.9653, Test Loss: 21.3971
Training RMSE: 6.886152643982957, Validation RMSE: 6.6004, Test RMSE: 19.7339
Training PCC: 0.9749234040069373, Validation PCC: 0.9789, Test PCC: 0.6818
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 6.5744, Validation Loss: 6.2704, Test Loss: 21.5589
Training RMSE: 6.297179480393727, Validation RMSE: 5.9549, Test RMSE: 19.8471
Training PCC: 0.9790330809406192, Validation PCC: 0.9816, Test PCC: 0.6925
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 6.1801, Validation Loss: 5.8972, Test Loss: 21.4124
Training RMSE: 5.9271550239101645, Validation RMSE: 5.5941, Test RMSE: 19.5859
Training PCC: 0.9814392788484142, Validation PCC: 0.9825, Test PCC: 0.6882
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 5.8643, Validation Loss: 5.9736, Test Loss: 21.7822
Training RMSE: 5.629773455422099, Validation RMSE: 5.7603, Test RMSE: 20.0028
Training PCC: 0.9832911540870727, Validation PCC: 0.9837, Test PCC: 0.6887
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 5.6253, Validation Loss: 5.6650, Test Loss: 21.3975
Training RMSE: 5.412740729930924, Validation RMSE: 5.4316, Test RMSE: 19.6393
Training PCC: 0.9845096442194204, Validation PCC: 0.9857, Test PCC: 0.6840
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 5.2731, Validation Loss: 5.4001, Test Loss: 21.2703
Training RMSE: 5.082542426702453, Validation RMSE: 5.2143, Test RMSE: 19.6438
Training PCC: 0.9863575278033102, Validation PCC: 0.9860, Test PCC: 0.6787
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 5.1575, Validation Loss: 4.8859, Test Loss: 21.2318
Training RMSE: 4.982816446602829, Validation RMSE: 4.7145, Test RMSE: 19.4935
Training PCC: 0.9867741377624127, Validation PCC: 0.9877, Test PCC: 0.6829
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 4.9000, Validation Loss: 5.0179, Test Loss: 21.6445
Training RMSE: 4.745800206331702, Validation RMSE: 4.7618, Test RMSE: 19.8712
Training PCC: 0.9879398913462918, Validation PCC: 0.9884, Test PCC: 0.6867
Checkpoint saved for epoch 10
Total training time: 1184.87 seconds
loading best model from TeacherModel_RMSELoss_test_subject_2_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 21.2318, Test PCC: 0.6829, Test RMSE: 19.4935
Running training with subject_3 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_3. Resharding...
Processing subjects: ['subject_3'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_3/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_3/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_3_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.2169, Validation Loss: 9.4318, Test Loss: 21.5869
Training RMSE: 15.790685925057263, Validation RMSE: 9.1886, Test RMSE: 19.6871
Training PCC: 0.7976120253191664, Validation PCC: 0.9568, Test PCC: 0.7155
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 8.8519, Validation Loss: 8.3436, Test Loss: 22.5761
Training RMSE: 8.591369984111166, Validation RMSE: 8.0468, Test RMSE: 20.3758
Training PCC: 0.9611257544667781, Validation PCC: 0.9682, Test PCC: 0.7380
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 7.8853, Validation Loss: 7.6551, Test Loss: 21.6197
Training RMSE: 7.648738706014989, Validation RMSE: 7.4082, Test RMSE: 19.5503
Training PCC: 0.9694540720466457, Validation PCC: 0.9714, Test PCC: 0.6945
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.3071, Validation Loss: 7.0159, Test Loss: 21.3592
Training RMSE: 7.093117466302421, Validation RMSE: 6.7660, Test RMSE: 19.1130
Training PCC: 0.9737438657499005, Validation PCC: 0.9770, Test PCC: 0.7363
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 6.7921, Validation Loss: 6.9467, Test Loss: 22.6250
Training RMSE: 6.605108082536759, Validation RMSE: 6.7101, Test RMSE: 20.1737
Training PCC: 0.9771062112013821, Validation PCC: 0.9784, Test PCC: 0.7269
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.4332, Validation Loss: 6.3814, Test Loss: 21.3631
Training RMSE: 6.263396725906589, Validation RMSE: 6.2014, Test RMSE: 19.2944
Training PCC: 0.9794042958120199, Validation PCC: 0.9803, Test PCC: 0.7427
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.0953, Validation Loss: 6.2850, Test Loss: 21.0206
Training RMSE: 5.941999207667219, Validation RMSE: 6.0728, Test RMSE: 19.0098
Training PCC: 0.9813123414301025, Validation PCC: 0.9812, Test PCC: 0.7446
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 5.8807, Validation Loss: 6.5411, Test Loss: 21.1683
Training RMSE: 5.738241014684118, Validation RMSE: 6.3146, Test RMSE: 19.0336
Training PCC: 0.9825771081102331, Validation PCC: 0.9810, Test PCC: 0.7598
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 5.7469, Validation Loss: 6.1895, Test Loss: 21.5334
Training RMSE: 5.612358721775737, Validation RMSE: 6.0768, Test RMSE: 19.4836
Training PCC: 0.9831470000654635, Validation PCC: 0.9831, Test PCC: 0.7423
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.4260, Validation Loss: 5.9236, Test Loss: 20.5467
Training RMSE: 5.308974165257399, Validation RMSE: 5.7299, Test RMSE: 18.5261
Training PCC: 0.9848907201071198, Validation PCC: 0.9833, Test PCC: 0.7476
Checkpoint saved for epoch 10
Total training time: 1181.94 seconds
loading best model from TeacherModel_RMSELoss_test_subject_3_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 20.5467, Test PCC: 0.7476, Test RMSE: 18.5261
Running training with subject_4 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_4. Resharding...
Processing subjects: ['subject_4'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_4/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_4/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_4_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.4790, Validation Loss: 9.3908, Test Loss: 15.8018
Training RMSE: 15.929331517972955, Validation RMSE: 9.0334, Test RMSE: 15.0001
Training PCC: 0.7984039873818866, Validation PCC: 0.9569, Test PCC: 0.7451
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.2387, Validation Loss: 8.2884, Test Loss: 16.4552
Training RMSE: 8.822158677862324, Validation RMSE: 7.9671, Test RMSE: 15.6884
Training PCC: 0.960423366077206, Validation PCC: 0.9676, Test PCC: 0.7132
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.3045, Validation Loss: 7.6420, Test Loss: 15.8462
Training RMSE: 7.903227825048138, Validation RMSE: 7.2950, Test RMSE: 15.1939
Training PCC: 0.9688210274857344, Validation PCC: 0.9732, Test PCC: 0.7219
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.7493, Validation Loss: 7.1820, Test Loss: 16.3156
Training RMSE: 7.3932842834037125, Validation RMSE: 6.8934, Test RMSE: 15.6120
Training PCC: 0.9729781633668568, Validation PCC: 0.9752, Test PCC: 0.7234
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.3292, Validation Loss: 7.0645, Test Loss: 16.5374
Training RMSE: 7.009320995008058, Validation RMSE: 6.7861, Test RMSE: 15.8443
Training PCC: 0.9756309645438238, Validation PCC: 0.9770, Test PCC: 0.7293
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.9965, Validation Loss: 6.6323, Test Loss: 15.7238
Training RMSE: 6.695290629652297, Validation RMSE: 6.4173, Test RMSE: 14.9543
Training PCC: 0.9776385668191346, Validation PCC: 0.9801, Test PCC: 0.7329
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.6431, Validation Loss: 6.5273, Test Loss: 16.1064
Training RMSE: 6.379262191683025, Validation RMSE: 6.2704, Test RMSE: 15.2562
Training PCC: 0.9793630532544447, Validation PCC: 0.9807, Test PCC: 0.7570
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.4462, Validation Loss: 6.0580, Test Loss: 15.9866
Training RMSE: 6.203666939526403, Validation RMSE: 5.8394, Test RMSE: 15.1089
Training PCC: 0.98044749195904, Validation PCC: 0.9824, Test PCC: 0.7479
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1478, Validation Loss: 5.9160, Test Loss: 16.1535
Training RMSE: 5.929111411690591, Validation RMSE: 5.7271, Test RMSE: 15.3280
Training PCC: 0.9819776454186209, Validation PCC: 0.9827, Test PCC: 0.7375
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.9228, Validation Loss: 5.8459, Test Loss: 16.2152
Training RMSE: 5.723137536909237, Validation RMSE: 5.6345, Test RMSE: 15.3057
Training PCC: 0.9830017061542402, Validation PCC: 0.9833, Test PCC: 0.7498
Checkpoint saved for epoch 10
Total training time: 1175.44 seconds
loading best model from TeacherModel_RMSELoss_test_subject_4_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 16.2152, Test PCC: 0.7498, Test RMSE: 15.3057
Running training with subject_5 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_5. Resharding...
Processing subjects: ['subject_5'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_5/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_5/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_5_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.5241, Validation Loss: 9.7091, Test Loss: 20.2167
Training RMSE: 16.015748993215453, Validation RMSE: 9.3088, Test RMSE: 19.3505
Training PCC: 0.8057504453743078, Validation PCC: 0.9544, Test PCC: 0.7175
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.4732, Validation Loss: 8.3411, Test Loss: 19.5347
Training RMSE: 9.05853148848274, Validation RMSE: 8.0135, Test RMSE: 18.8674
Training PCC: 0.9582183407253403, Validation PCC: 0.9659, Test PCC: 0.6960
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.4314, Validation Loss: 8.1117, Test Loss: 19.7831
Training RMSE: 8.073233969471628, Validation RMSE: 7.7549, Test RMSE: 18.7523
Training PCC: 0.9671862047597957, Validation PCC: 0.9703, Test PCC: 0.7077
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.0088, Validation Loss: 7.4311, Test Loss: 20.3898
Training RMSE: 7.656112803110169, Validation RMSE: 7.1090, Test RMSE: 19.6621
Training PCC: 0.9709450541103419, Validation PCC: 0.9745, Test PCC: 0.6740
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.3615, Validation Loss: 7.1297, Test Loss: 20.2256
Training RMSE: 7.073981445985225, Validation RMSE: 6.8170, Test RMSE: 19.3996
Training PCC: 0.9745822933157307, Validation PCC: 0.9764, Test PCC: 0.6960
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.0806, Validation Loss: 7.2984, Test Loss: 21.6869
Training RMSE: 6.814887267004824, Validation RMSE: 6.9722, Test RMSE: 20.6003
Training PCC: 0.9762145098548104, Validation PCC: 0.9765, Test PCC: 0.6596
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7858, Validation Loss: 6.5156, Test Loss: 19.8352
Training RMSE: 6.540319967464327, Validation RMSE: 6.2469, Test RMSE: 18.9360
Training PCC: 0.9780280689280231, Validation PCC: 0.9790, Test PCC: 0.6917
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.3931, Validation Loss: 6.3874, Test Loss: 20.1532
Training RMSE: 6.187523051990039, Validation RMSE: 6.1364, Test RMSE: 19.1278
Training PCC: 0.98004651519138, Validation PCC: 0.9806, Test PCC: 0.6952
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.0795, Validation Loss: 6.1947, Test Loss: 20.2624
Training RMSE: 5.902674094860221, Validation RMSE: 5.9287, Test RMSE: 19.2112
Training PCC: 0.9816361713802454, Validation PCC: 0.9815, Test PCC: 0.6821
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.9865, Validation Loss: 6.1413, Test Loss: 20.4156
Training RMSE: 5.809235928377002, Validation RMSE: 5.8872, Test RMSE: 19.3390
Training PCC: 0.9824287983749221, Validation PCC: 0.9820, Test PCC: 0.6973
Checkpoint saved for epoch 10
Total training time: 1169.48 seconds
loading best model from TeacherModel_RMSELoss_test_subject_5_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 20.4156, Test PCC: 0.6973, Test RMSE: 19.3390
Running training with subject_6 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_6. Resharding...
Processing subjects: ['subject_6'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_6/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_6/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_6_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.4698, Validation Loss: 9.3669, Test Loss: 16.5699
Training RMSE: 15.988017686014253, Validation RMSE: 9.0850, Test RMSE: 15.2021
Training PCC: 0.7959279844772132, Validation PCC: 0.9532, Test PCC: 0.7447
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.2172, Validation Loss: 8.6030, Test Loss: 17.4844
Training RMSE: 8.866748390643577, Validation RMSE: 8.2945, Test RMSE: 15.5208
Training PCC: 0.9589027610438188, Validation PCC: 0.9654, Test PCC: 0.7650
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.2969, Validation Loss: 7.7143, Test Loss: 15.8052
Training RMSE: 7.955705422211469, Validation RMSE: 7.4366, Test RMSE: 14.5526
Training PCC: 0.9675543452943849, Validation PCC: 0.9695, Test PCC: 0.7650
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.8110, Validation Loss: 7.6384, Test Loss: 15.7449
Training RMSE: 7.488971527999009, Validation RMSE: 7.4033, Test RMSE: 14.7702
Training PCC: 0.9712246127070555, Validation PCC: 0.9717, Test PCC: 0.7667
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.3604, Validation Loss: 7.2460, Test Loss: 14.9858
Training RMSE: 7.094404624487326, Validation RMSE: 6.9817, Test RMSE: 13.8271
Training PCC: 0.9738473642686669, Validation PCC: 0.9744, Test PCC: 0.7883
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 6.9684, Validation Loss: 6.7710, Test Loss: 14.2660
Training RMSE: 6.715228725255019, Validation RMSE: 6.5211, Test RMSE: 13.0560
Training PCC: 0.9767217367224906, Validation PCC: 0.9768, Test PCC: 0.7864
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.5737, Validation Loss: 6.3626, Test Loss: 15.2775
Training RMSE: 6.35345791050089, Validation RMSE: 6.1689, Test RMSE: 14.0500
Training PCC: 0.9788308070598629, Validation PCC: 0.9786, Test PCC: 0.7706
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.4036, Validation Loss: 6.5866, Test Loss: 14.9227
Training RMSE: 6.204021476390886, Validation RMSE: 6.3282, Test RMSE: 13.7560
Training PCC: 0.9795206265640196, Validation PCC: 0.9780, Test PCC: 0.7919
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1675, Validation Loss: 5.9370, Test Loss: 15.1555
Training RMSE: 5.9821510903719, Validation RMSE: 5.7415, Test RMSE: 13.8934
Training PCC: 0.9809695537672178, Validation PCC: 0.9812, Test PCC: 0.7842
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.9193, Validation Loss: 6.1597, Test Loss: 14.9572
Training RMSE: 5.748135452832632, Validation RMSE: 5.9232, Test RMSE: 13.6095
Training PCC: 0.9824264355320008, Validation PCC: 0.9808, Test PCC: 0.8052
Checkpoint saved for epoch 10
Total training time: 1175.11 seconds
loading best model from TeacherModel_RMSELoss_test_subject_6_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 15.1555, Test PCC: 0.7842, Test RMSE: 13.8934
Running training with subject_7 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_7. Resharding...
Processing subjects: ['subject_7'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_7/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_7/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_7_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.5176, Validation Loss: 9.6071, Test Loss: 16.3051
Training RMSE: 15.99421360095342, Validation RMSE: 9.2463, Test RMSE: 15.1767
Training PCC: 0.8003045178043949, Validation PCC: 0.9546, Test PCC: 0.7032
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.4197, Validation Loss: 8.6693, Test Loss: 17.1012
Training RMSE: 9.022385031711766, Validation RMSE: 8.3680, Test RMSE: 15.9340
Training PCC: 0.9580514028878352, Validation PCC: 0.9653, Test PCC: 0.7521
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.5714, Validation Loss: 8.0845, Test Loss: 15.0449
Training RMSE: 8.199295963698285, Validation RMSE: 7.7703, Test RMSE: 13.9289
Training PCC: 0.9660667603604866, Validation PCC: 0.9708, Test PCC: 0.7456
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.9604, Validation Loss: 7.3166, Test Loss: 14.9160
Training RMSE: 7.618469832873925, Validation RMSE: 6.9754, Test RMSE: 14.0981
Training PCC: 0.9704379299139129, Validation PCC: 0.9749, Test PCC: 0.7490
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.5153, Validation Loss: 7.1331, Test Loss: 15.4850
Training RMSE: 7.20561057934916, Validation RMSE: 6.8405, Test RMSE: 14.5469
Training PCC: 0.9733186074622928, Validation PCC: 0.9760, Test PCC: 0.7504
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.1478, Validation Loss: 6.8405, Test Loss: 15.3962
Training RMSE: 6.875588424079787, Validation RMSE: 6.5605, Test RMSE: 14.5178
Training PCC: 0.9755757755137872, Validation PCC: 0.9778, Test PCC: 0.7627
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7592, Validation Loss: 6.5723, Test Loss: 15.4009
Training RMSE: 6.518903419981157, Validation RMSE: 6.3160, Test RMSE: 14.2848
Training PCC: 0.977591381662991, Validation PCC: 0.9795, Test PCC: 0.7657
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.4460, Validation Loss: 6.3004, Test Loss: 15.7624
Training RMSE: 6.231436367926559, Validation RMSE: 6.0619, Test RMSE: 14.6773
Training PCC: 0.9795048564489913, Validation PCC: 0.9809, Test PCC: 0.7680
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.2990, Validation Loss: 6.2545, Test Loss: 15.0942
Training RMSE: 6.091445062460938, Validation RMSE: 5.9991, Test RMSE: 14.2128
Training PCC: 0.9805966362370477, Validation PCC: 0.9815, Test PCC: 0.7600
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.0432, Validation Loss: 6.1576, Test Loss: 15.1301
Training RMSE: 5.850955836656617, Validation RMSE: 5.9492, Test RMSE: 14.2375
Training PCC: 0.9818596825140018, Validation PCC: 0.9816, Test PCC: 0.7853
Checkpoint saved for epoch 10
Total training time: 1162.52 seconds
loading best model from TeacherModel_RMSELoss_test_subject_7_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 15.1301, Test PCC: 0.7853, Test RMSE: 14.2375
Running training with subject_8 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_8. Resharding...
Processing subjects: ['subject_8'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_8/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_8/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_8_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.0114, Validation Loss: 9.8205, Test Loss: 13.6140
Training RMSE: 16.496233408528614, Validation RMSE: 9.4231, Test RMSE: 12.9938
Training PCC: 0.7896518018589518, Validation PCC: 0.9537, Test PCC: 0.7925
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.5617, Validation Loss: 8.3633, Test Loss: 13.0720
Training RMSE: 9.166731049859427, Validation RMSE: 8.0018, Test RMSE: 12.4666
Training PCC: 0.9582691696833335, Validation PCC: 0.9675, Test PCC: 0.7931
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.6979, Validation Loss: 8.1425, Test Loss: 14.4852
Training RMSE: 8.303037886212511, Validation RMSE: 7.7493, Test RMSE: 13.7080
Training PCC: 0.9667401016008489, Validation PCC: 0.9714, Test PCC: 0.8348
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.1066, Validation Loss: 7.9745, Test Loss: 11.8018
Training RMSE: 7.760810694558832, Validation RMSE: 7.6566, Test RMSE: 11.2671
Training PCC: 0.9707049262813889, Validation PCC: 0.9733, Test PCC: 0.8270
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.5938, Validation Loss: 7.2595, Test Loss: 13.2209
Training RMSE: 7.2803278650210155, Validation RMSE: 6.8956, Test RMSE: 12.4539
Training PCC: 0.9738947055181534, Validation PCC: 0.9764, Test PCC: 0.8242
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.1740, Validation Loss: 6.6935, Test Loss: 12.1284
Training RMSE: 6.901697740806797, Validation RMSE: 6.4487, Test RMSE: 11.5321
Training PCC: 0.9763202460461707, Validation PCC: 0.9792, Test PCC: 0.8238
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.8598, Validation Loss: 7.2167, Test Loss: 11.9190
Training RMSE: 6.610898676926527, Validation RMSE: 6.9350, Test RMSE: 11.4776
Training PCC: 0.9782767635875284, Validation PCC: 0.9769, Test PCC: 0.8091
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.6211, Validation Loss: 6.4002, Test Loss: 13.2231
Training RMSE: 6.399752365137503, Validation RMSE: 6.2053, Test RMSE: 12.7171
Training PCC: 0.97955272063727, Validation PCC: 0.9804, Test PCC: 0.7980
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.4195, Validation Loss: 6.3593, Test Loss: 12.2267
Training RMSE: 6.21255966996759, Validation RMSE: 6.1817, Test RMSE: 11.7713
Training PCC: 0.9805425204219785, Validation PCC: 0.9810, Test PCC: 0.8266
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1018, Validation Loss: 5.9476, Test Loss: 12.9356
Training RMSE: 5.923448920492234, Validation RMSE: 5.7974, Test RMSE: 12.3715
Training PCC: 0.982133288412904, Validation PCC: 0.9823, Test PCC: 0.8373
Checkpoint saved for epoch 10
Total training time: 1163.70 seconds
loading best model from TeacherModel_RMSELoss_test_subject_8_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 12.9356, Test PCC: 0.8373, Test RMSE: 12.3715
Running training with subject_9 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_9. Resharding...
Processing subjects: ['subject_9'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_9/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_9/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_9_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.9523, Validation Loss: 10.2958, Test Loss: 11.5547
Training RMSE: 16.383855434452617, Validation RMSE: 9.9038, Test RMSE: 11.3228
Training PCC: 0.7985062994532047, Validation PCC: 0.9527, Test PCC: 0.7956
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.6085, Validation Loss: 9.2007, Test Loss: 10.6169
Training RMSE: 9.201054698568049, Validation RMSE: 8.7717, Test RMSE: 10.4159
Training PCC: 0.9585105933547678, Validation PCC: 0.9629, Test PCC: 0.8098
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.6465, Validation Loss: 8.7874, Test Loss: 10.5918
Training RMSE: 8.26518580438645, Validation RMSE: 8.5625, Test RMSE: 10.4013
Training PCC: 0.9670553205465353, Validation PCC: 0.9657, Test PCC: 0.7881
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 8.0005, Validation Loss: 8.2459, Test Loss: 11.5988
Training RMSE: 7.669544302835697, Validation RMSE: 8.0003, Test RMSE: 11.3847
Training PCC: 0.9714532421969103, Validation PCC: 0.9713, Test PCC: 0.7951
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.4796, Validation Loss: 7.5729, Test Loss: 10.8480
Training RMSE: 7.1855975132647565, Validation RMSE: 7.2662, Test RMSE: 10.6781
Training PCC: 0.9747688205495141, Validation PCC: 0.9755, Test PCC: 0.8083
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.1413, Validation Loss: 6.9043, Test Loss: 10.5903
Training RMSE: 6.871482143557169, Validation RMSE: 6.6542, Test RMSE: 10.3765
Training PCC: 0.9768803510616736, Validation PCC: 0.9779, Test PCC: 0.8056
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7108, Validation Loss: 6.9343, Test Loss: 10.9642
Training RMSE: 6.483713909135602, Validation RMSE: 6.6799, Test RMSE: 10.7591
Training PCC: 0.9790927905780672, Validation PCC: 0.9784, Test PCC: 0.8096
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.5649, Validation Loss: 6.4643, Test Loss: 10.1902
Training RMSE: 6.35570917623799, Validation RMSE: 6.2609, Test RMSE: 10.0417
Training PCC: 0.9799509768312925, Validation PCC: 0.9806, Test PCC: 0.8276
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1871, Validation Loss: 6.6200, Test Loss: 10.5466
Training RMSE: 5.999169681614976, Validation RMSE: 6.4791, Test RMSE: 10.3573
Training PCC: 0.9819560771508932, Validation PCC: 0.9793, Test PCC: 0.8154
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.0422, Validation Loss: 6.1267, Test Loss: 10.3064
Training RMSE: 5.852329934515605, Validation RMSE: 5.9612, Test RMSE: 10.1672
Training PCC: 0.9829623664425803, Validation PCC: 0.9820, Test PCC: 0.8096
Checkpoint saved for epoch 10
Total training time: 1161.77 seconds
loading best model from TeacherModel_RMSELoss_test_subject_9_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 10.3064, Test PCC: 0.8096, Test RMSE: 10.1672
Running training with subject_10 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_10. Resharding...
Processing subjects: ['subject_10'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_10/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_10/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_10_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.9154, Validation Loss: 10.1289, Test Loss: 17.0498
Training RMSE: 16.37224202572815, Validation RMSE: 9.6298, Test RMSE: 16.1391
Training PCC: 0.7924815843297681, Validation PCC: 0.9507, Test PCC: 0.7299
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.5036, Validation Loss: 8.9547, Test Loss: 16.5167
Training RMSE: 9.113862212595901, Validation RMSE: 8.4881, Test RMSE: 15.6874
Training PCC: 0.957985145534452, Validation PCC: 0.9643, Test PCC: 0.7108
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.4765, Validation Loss: 8.3736, Test Loss: 15.8181
Training RMSE: 8.122452670481147, Validation RMSE: 7.9694, Test RMSE: 15.1198
Training PCC: 0.9667663955584999, Validation PCC: 0.9689, Test PCC: 0.6984
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.9123, Validation Loss: 7.6241, Test Loss: 16.7151
Training RMSE: 7.594345976666706, Validation RMSE: 7.2683, Test RMSE: 15.9597
Training PCC: 0.9710161813058007, Validation PCC: 0.9727, Test PCC: 0.7013
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.5396, Validation Loss: 7.5437, Test Loss: 15.8497
Training RMSE: 7.240975147582652, Validation RMSE: 7.2318, Test RMSE: 15.1278
Training PCC: 0.973957560910328, Validation PCC: 0.9742, Test PCC: 0.7175
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.0854, Validation Loss: 6.8710, Test Loss: 14.7559
Training RMSE: 6.81921277492027, Validation RMSE: 6.5844, Test RMSE: 14.0082
Training PCC: 0.9764727330141735, Validation PCC: 0.9782, Test PCC: 0.7016
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7791, Validation Loss: 6.6777, Test Loss: 16.1726
Training RMSE: 6.533263748738823, Validation RMSE: 6.4270, Test RMSE: 15.2966
Training PCC: 0.9782093056207307, Validation PCC: 0.9786, Test PCC: 0.7228
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.4471, Validation Loss: 6.4176, Test Loss: 14.7638
Training RMSE: 6.222024424773892, Validation RMSE: 6.1835, Test RMSE: 14.0172
Training PCC: 0.980328695000617, Validation PCC: 0.9805, Test PCC: 0.7223
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1960, Validation Loss: 6.8069, Test Loss: 15.0816
Training RMSE: 6.002286039232239, Validation RMSE: 6.4682, Test RMSE: 14.2010
Training PCC: 0.9813852971600844, Validation PCC: 0.9799, Test PCC: 0.7348
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 5.9725, Validation Loss: 6.0570, Test Loss: 15.0771
Training RMSE: 5.790928961057973, Validation RMSE: 5.8725, Test RMSE: 14.0247
Training PCC: 0.9827536567674414, Validation PCC: 0.9825, Test PCC: 0.7530
Checkpoint saved for epoch 10
Total training time: 1159.70 seconds
loading best model from TeacherModel_RMSELoss_test_subject_10_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 15.0771, Test PCC: 0.7530, Test RMSE: 14.0247
Running training with subject_11 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_11. Resharding...
Processing subjects: ['subject_11'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_11/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_11/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_11_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.7143, Validation Loss: 10.2998, Test Loss: 15.0374
Training RMSE: 16.195703878635314, Validation RMSE: 9.7758, Test RMSE: 14.2873
Training PCC: 0.789380206845721, Validation PCC: 0.9524, Test PCC: 0.7203
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.4899, Validation Loss: 8.6439, Test Loss: 15.4233
Training RMSE: 9.09328179824643, Validation RMSE: 8.3172, Test RMSE: 14.2514
Training PCC: 0.9576099000509459, Validation PCC: 0.9642, Test PCC: 0.7393
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.5123, Validation Loss: 8.1582, Test Loss: 15.5612
Training RMSE: 8.1560966110811, Validation RMSE: 7.8560, Test RMSE: 14.2408
Training PCC: 0.9664356023780928, Validation PCC: 0.9692, Test PCC: 0.7462
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.8664, Validation Loss: 7.4866, Test Loss: 15.3220
Training RMSE: 7.553111523632111, Validation RMSE: 7.1752, Test RMSE: 14.0714
Training PCC: 0.9707536133673531, Validation PCC: 0.9723, Test PCC: 0.7701
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.4204, Validation Loss: 7.1599, Test Loss: 14.1021
Training RMSE: 7.124857502739604, Validation RMSE: 6.8865, Test RMSE: 12.8887
Training PCC: 0.9738479796301931, Validation PCC: 0.9736, Test PCC: 0.7827
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.0510, Validation Loss: 6.5108, Test Loss: 15.1681
Training RMSE: 6.788761886397029, Validation RMSE: 6.2917, Test RMSE: 13.6667
Training PCC: 0.9761019154902892, Validation PCC: 0.9781, Test PCC: 0.7820
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7802, Validation Loss: 6.8352, Test Loss: 15.6324
Training RMSE: 6.536134589978349, Validation RMSE: 6.5640, Test RMSE: 14.2118
Training PCC: 0.9778100079934623, Validation PCC: 0.9776, Test PCC: 0.7660
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.6000, Validation Loss: 6.2987, Test Loss: 14.8394
Training RMSE: 6.376708813314515, Validation RMSE: 6.1705, Test RMSE: 13.8649
Training PCC: 0.9787496905054828, Validation PCC: 0.9803, Test PCC: 0.7737
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.1955, Validation Loss: 5.9707, Test Loss: 14.7844
Training RMSE: 6.003946015989878, Validation RMSE: 5.7891, Test RMSE: 13.6129
Training PCC: 0.9808835760269158, Validation PCC: 0.9815, Test PCC: 0.7603
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1472, Validation Loss: 5.8935, Test Loss: 15.5477
Training RMSE: 5.952125151709812, Validation RMSE: 5.7354, Test RMSE: 14.2055
Training PCC: 0.9814490877701391, Validation PCC: 0.9819, Test PCC: 0.7608
Checkpoint saved for epoch 10
Total training time: 1154.82 seconds
loading best model from TeacherModel_RMSELoss_test_subject_11_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 15.5477, Test PCC: 0.7608, Test RMSE: 14.2055
Running training with subject_12 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_12. Resharding...
Processing subjects: ['subject_12'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_12/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_12/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_12_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 17.0411, Validation Loss: 10.5998, Test Loss: 11.9497
Training RMSE: 16.508472348616376, Validation RMSE: 10.0966, Test RMSE: 11.5336
Training PCC: 0.795107975520514, Validation PCC: 0.9485, Test PCC: 0.7871
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.4843, Validation Loss: 8.9440, Test Loss: 11.3120
Training RMSE: 9.081825774375018, Validation RMSE: 8.4786, Test RMSE: 10.7539
Training PCC: 0.9586733513596576, Validation PCC: 0.9642, Test PCC: 0.8292
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.5909, Validation Loss: 8.3785, Test Loss: 11.9380
Training RMSE: 8.205917079274249, Validation RMSE: 7.9553, Test RMSE: 11.3614
Training PCC: 0.9674158505128413, Validation PCC: 0.9699, Test PCC: 0.8376
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.9648, Validation Loss: 8.2475, Test Loss: 11.1826
Training RMSE: 7.60682445279951, Validation RMSE: 7.8075, Test RMSE: 10.8450
Training PCC: 0.9717566513967267, Validation PCC: 0.9706, Test PCC: 0.8405
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.5124, Validation Loss: 7.3488, Test Loss: 11.2765
Training RMSE: 7.204352772332789, Validation RMSE: 7.0130, Test RMSE: 10.9143
Training PCC: 0.9744449195430994, Validation PCC: 0.9752, Test PCC: 0.8207
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.0890, Validation Loss: 7.2660, Test Loss: 11.3855
Training RMSE: 6.82409594650191, Validation RMSE: 6.9432, Test RMSE: 10.8953
Training PCC: 0.9767735881127404, Validation PCC: 0.9764, Test PCC: 0.8449
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.7619, Validation Loss: 7.1915, Test Loss: 10.9664
Training RMSE: 6.510791364239483, Validation RMSE: 6.8512, Test RMSE: 10.5280
Training PCC: 0.9789549596314022, Validation PCC: 0.9771, Test PCC: 0.8535
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.5378, Validation Loss: 6.6933, Test Loss: 10.7467
Training RMSE: 6.3065310471910765, Validation RMSE: 6.3947, Test RMSE: 10.3845
Training PCC: 0.9801803366636911, Validation PCC: 0.9789, Test PCC: 0.8586
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.4202, Validation Loss: 6.5172, Test Loss: 10.6294
Training RMSE: 6.205825415568623, Validation RMSE: 6.2454, Test RMSE: 10.1548
Training PCC: 0.980714002431864, Validation PCC: 0.9806, Test PCC: 0.8697
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.0925, Validation Loss: 6.0659, Test Loss: 10.3535
Training RMSE: 5.89520741453985, Validation RMSE: 5.8657, Test RMSE: 10.0001
Training PCC: 0.9824221156928785, Validation PCC: 0.9822, Test PCC: 0.8681
Checkpoint saved for epoch 10
Total training time: 1153.69 seconds
loading best model from TeacherModel_RMSELoss_test_subject_12_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 10.3535, Test PCC: 0.8681, Test RMSE: 10.0001
Running training with subject_13 as the test subject.
Model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_fml
Sharded data not found at /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12. Resharding...
Processing subjects: ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_1_2_3_4_5_6_7_8_9_10_11_12/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_13. Resharding...
Processing subjects: ['subject_13'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_13/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_13/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Running model: TeacherModel_RMSELoss_test_subject_13_wl100_ol75_fml
Starting from scratch.


Epoch 1/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 16.7414, Validation Loss: 9.5594, Test Loss: 16.5098
Training RMSE: 16.199328931366527, Validation RMSE: 9.2272, Test RMSE: 15.0789
Training PCC: 0.791093150364475, Validation PCC: 0.9531, Test PCC: 0.7376
Checkpoint saved for epoch 1


Epoch 2/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 9.5602, Validation Loss: 8.7685, Test Loss: 14.3847
Training RMSE: 9.155103964049642, Validation RMSE: 8.5136, Test RMSE: 13.2955
Training PCC: 0.9561906113871036, Validation PCC: 0.9632, Test PCC: 0.7865
Checkpoint saved for epoch 2


Epoch 3/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 8.6114, Validation Loss: 8.2213, Test Loss: 14.4294
Training RMSE: 8.233680050062938, Validation RMSE: 7.9087, Test RMSE: 13.4635
Training PCC: 0.9650457606167459, Validation PCC: 0.9673, Test PCC: 0.7894
Checkpoint saved for epoch 3


Epoch 4/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 7.9554, Validation Loss: 7.6418, Test Loss: 14.6543
Training RMSE: 7.615832449459448, Validation RMSE: 7.3238, Test RMSE: 13.5743
Training PCC: 0.9696387130649864, Validation PCC: 0.9724, Test PCC: 0.7998
Checkpoint saved for epoch 4


Epoch 5/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 7.5535, Validation Loss: 7.4231, Test Loss: 14.0742
Training RMSE: 7.251559171250197, Validation RMSE: 7.1631, Test RMSE: 12.9543
Training PCC: 0.9728532664095416, Validation PCC: 0.9733, Test PCC: 0.7846
Checkpoint saved for epoch 5


Epoch 6/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 7.1488, Validation Loss: 6.7587, Test Loss: 16.1569
Training RMSE: 6.878216783690259, Validation RMSE: 6.5390, Test RMSE: 14.2041
Training PCC: 0.9753788476172064, Validation PCC: 0.9775, Test PCC: 0.7835
Checkpoint saved for epoch 6


Epoch 7/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 6.8385, Validation Loss: 6.6831, Test Loss: 15.8371
Training RMSE: 6.598877098986773, Validation RMSE: 6.4749, Test RMSE: 14.0717
Training PCC: 0.9770830264918186, Validation PCC: 0.9770, Test PCC: 0.7799
Checkpoint saved for epoch 7


Epoch 8/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 6.6101, Validation Loss: 6.8228, Test Loss: 15.9375
Training RMSE: 6.382576210954325, Validation RMSE: 6.5932, Test RMSE: 14.3519
Training PCC: 0.9784929966169466, Validation PCC: 0.9766, Test PCC: 0.7644
Checkpoint saved for epoch 8


Epoch 9/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 6.3020, Validation Loss: 6.3388, Test Loss: 16.0884
Training RMSE: 6.102156278321414, Validation RMSE: 6.1273, Test RMSE: 14.4149
Training PCC: 0.9800981077500096, Validation PCC: 0.9799, Test PCC: 0.7721
Checkpoint saved for epoch 9


Epoch 10/10 Training:   0%|          | 0/328 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 6.1532, Validation Loss: 5.7785, Test Loss: 14.5541
Training RMSE: 5.9564639923533775, Validation RMSE: 5.6466, Test RMSE: 13.1497
Training PCC: 0.9812079233270335, Validation PCC: 0.9825, Test PCC: 0.7880
Checkpoint saved for epoch 10
Total training time: 1157.22 seconds
loading best model from TeacherModel_RMSELoss_test_subject_13_wl100_ol75_fml


  model.load_state_dict(torch.load(filename))


Test Loss: 14.5541, Test PCC: 0.7880, Test RMSE: 13.1497


In [10]:

average_best_rmse = np.mean(best_rmse_per_subject)
average_best_pcc = np.mean(best_pcc_per_subject)
print(f"Average of best RMSEs across all subjects: {average_best_rmse:.4f}")
print(f"Average of best PCCs across all subjects: {average_best_pcc:.4f}")
print(best_rmse_per_subject)
print(best_pcc_per_subject)

# subjects = [f'Subject {i+1}' for i in range(len(best_rmse_per_subject))]

# print(best_rmse_per_subject)
# # Plot a bar chart with subject labels on the x-axis
# plt.figure(figsize=(10, 6))
# plt.bar(subjects, best_rmse_per_subject, color='blue', edgecolor='black')
# plt.title('Best RMSEs for Each Subject')
# plt.xlabel('Subjects')
# plt.ylabel('Best RMSE')
# plt.xticks(rotation=45, ha='right')
# plt.grid(True, axis='y')
# plt.tight_layout()
# plt.show()

Average of best RMSEs across all subjects: 14.6141
Average of best PCCs across all subjects: 0.7693
[15.26879326502482, 19.49350889523824, 18.526126980781555, 15.30569863319397, 19.33898953596751, 13.89343241850535, 14.237481335798899, 12.371473332246145, 10.16721353928248, 14.024703880151113, 14.205533742904663, 10.00011557340622, 13.149747431278229]
[0.7371331750114339, 0.6829203879501827, 0.7476106279910449, 0.7497561992819648, 0.6973268480051121, 0.7841520540630218, 0.785341108427598, 0.8372599261195875, 0.809618142916082, 0.7529697806545882, 0.7608259305369893, 0.8681421267872186, 0.788034691497192]


In [11]:
import os
import zipfile
from datetime import datetime

notebook_name = 'regression_benchmark_fml'

# Create a timestamped folder name based on the notebook name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"{notebook_name}_checkpoints_{timestamp}"

# Make sure the folder exists
os.makedirs(folder_name, exist_ok=True)

checkpoint_dir = '.'

# Zip all checkpoint files and save in the new folder
zip_filename = f"{folder_name}.zip"
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # List files only in the current directory (no subfolders)
    for file in os.listdir(checkpoint_dir):
        if "TeacherModel" in str(file):
          file_path = os.path.join(checkpoint_dir, file)
          zipf.write(file_path, os.path.relpath(file_path, checkpoint_dir))
          print(f"Checkpoint {file} has been added to the zip file.")
print(f"All checkpoints have been zipped and saved as {zip_filename}.")




Checkpoint TeacherModel_RMSELoss_test_subject_10_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_9_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_1_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_3_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_5_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_7_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_4_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_2_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_12_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_8_wl100_ol75_fml has been added to the zip file.
Checkpoint TeacherModel_RMSELoss_test_subject_11

In [12]:
# Download the zip file to your local machine
from google.colab import files
files.download(zip_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>