In [41]:
!pip install ncps
# Mount Google Drive at /content/MyDrive; the HDF5 path used later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
# "paper" is passed as the first positional argument of set_theme (seaborn's plotting context).
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [42]:
# @title Initialize Config

import torch
import numpy
class Config:
    """Plain attribute container for the experiment settings.

    Every recognised setting may be supplied as a keyword argument; settings
    that are not supplied fall back to the defaults listed in ``__init__``.
    Keyword arguments with unrecognised names are ignored.
    """

    def __init__(self, **kwargs):
        # The table is rebuilt on every call so that each instance receives
        # its own fresh list objects for the list-valued defaults.
        fallbacks = {
            'channels_imu_acc': [],
            'channels_imu_gyr': [],
            'channels_joints': [],
            'channels_emg': [],
            'seed': 42,
            'data_folder_name': 'default_data_folder_name',
            'dataset_root': 'default_dataset_root',
            'imu_transforms': [],
            'emg_transforms': [],
            'target_transforms': [],
            'input_format': 'csv',
        }
        for option, fallback in fallbacks.items():
            setattr(self, option, kwargs.get(option, fallback))


# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment configuration: HDF5 source file, output folder for the generated
# window files, and the channel names selected for each modality.
# `seed` is not passed here, so Config's default (42) is used.
config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

# Seed the PyTorch and NumPy global random number generators from the config.
# NOTE(review): Python's built-in `random` module is not seeded here.
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


# Copy the HDF5 file from the mounted Drive folder to /content and point the
# config at that copy. The copy is performed on every run of this cell.
import shutil
shutil.copy('/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data_final.h5', '/content/all_subjects_data_final.h5')
config.data_folder_name = '/content/all_subjects_data_final.h5'

In [43]:
import csv
import os

import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.notebook import tqdm

class DataSharder:
    """Cuts recordings stored in an HDF5 file into context/forecast CSV windows.

    The HDF5 file is walked as ``subject -> session -> speed``. For every leaf
    the configured IMU, EMG and joint channels are extracted, sliced into
    non-overlapping (context, forecast) pairs and written as one CSV file per
    pair. Every written file is also logged in ``<split>_info.csv``, which is
    placed next to the split folder.

    Args:
        config: object exposing ``data_folder_name`` (HDF5 path),
            ``dataset_root`` and the ``channels_*`` lists.
        split: name of the split; used as the sub-folder name and as the
            prefix of the index file.
    """

    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, forecast_horizon, dataset_name):
        """Store the window geometry and shard the given subjects.

        Args:
            subjects: iterable of top-level keys to read from the HDF5 file.
            window_length: number of samples in each context window.
            forecast_horizon: number of samples in each forecast window.
            dataset_name: folder name (below ``config.dataset_root``) that
                receives the generated files.
        """
        print(f"Processing subjects: {subjects} with window length: {window_length}, forecast horizon: {forecast_horizon}")

        self.window_length = window_length
        self.forecast_horizon = forecast_horizon

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _dataset_folder(self, dataset_name):
        """Return the output folder for this split.

        Only ``dataset_name`` is sanitised. Applying the replacements to the
        whole joined path would also rewrite ``config.dataset_root`` whenever
        that root happens to contain "subject" or "__".
        """
        safe_name = dataset_name.replace("subject", "").replace("__", "_")
        return os.path.join(self.config.dataset_root, safe_name, self.split)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        """Walk the HDF5 file and write the windows of every requested subject.

        Does nothing when the output folder already exists.
        """
        dataset_folder = self._dataset_folder(dataset_name)
        print("Dataset folder:", dataset_folder)

        if os.path.exists(dataset_folder):
            print("Dataset Exists, Skipping...")
            return

        imu_channels = self.config.channels_imu_acc + self.config.channels_imu_gyr

        # The folder is created only after the HDF5 file opened successfully,
        # so a failed open does not leave an empty folder behind that a later
        # run would treat as an existing dataset.
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created:", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                if subject_id not in h5_file:
                    print(f"Subject {subject_id} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_id]
                for session_id in subject_data.keys():
                    session_data_group = subject_data[session_id]
                    for session_speed in session_data_group.keys():
                        session_data = session_data_group[session_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, imu_channels)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into context-forecast pairs and save each pair
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_id, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        """Write one CSV per (context, forecast) pair and log it in the index file.

        Each ``*_data`` argument is indexed as ``[channel, sample]``. A
        recording for which any modality is empty or not two-dimensional is
        skipped with a message instead of failing on ``shape[1]``.
        """
        window_size = self.window_length
        forecast_horizon = self.forecast_horizon
        step_size = window_size  # For forecasting, no overlap in pairs

        for label, values in (("IMU", imu_data), ("EMG", emg_data), ("joint", joint_data)):
            if values.ndim != 2 or values.shape[0] == 0:
                print(f"Skipping {subject_key}/{session_id}/{session_speed}: no usable {label} channels.")
                return

        # Path to the CSV log file (one level above the split folder)
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # The header is written only when the log file is created by this call.
        write_header = not os.path.isfile(csv_file_path)

        # Usable length is limited by the shortest of the three modalities.
        total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])
        pair_length = window_size + forecast_horizon

        # Column names of the forecast part do not depend on the window.
        imu_forecast_columns = [f"{col}_forecast" for col in imu_columns]
        emg_forecast_columns = [f"{col}_forecast" for col in emg_columns]
        joint_forecast_columns = [f"{col}_forecast" for col in joint_columns]

        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)
            if write_header:
                writer.writerow(['file_name', 'file_path'])

            # The upper bound guarantees start + pair_length <= total_data_length,
            # so every slice below has its full length and needs no further check.
            for start in range(0, total_data_length - pair_length + 1, step_size):
                split_at = start + window_size
                stop = split_at + forecast_horizon

                # Context columns hold `window_size` rows and forecast columns
                # hold `forecast_horizon` rows; the column-wise concat aligns
                # them on the row index and pads the shorter part with NaN.
                combined_df = pd.concat(
                    [
                        pd.DataFrame(imu_data[:, start:split_at].T, columns=imu_columns),
                        pd.DataFrame(emg_data[:, start:split_at].T, columns=emg_columns),
                        pd.DataFrame(joint_data[:, start:split_at].T, columns=joint_columns),
                        pd.DataFrame(imu_data[:, split_at:stop].T, columns=imu_forecast_columns),
                        pd.DataFrame(emg_data[:, split_at:stop].T, columns=emg_forecast_columns),
                        pd.DataFrame(joint_data[:, split_at:stop].T, columns=joint_forecast_columns),
                    ],
                    axis=1,
                )

                # Save the combined data as a CSV file
                file_name = f"{subject_key}_{session_id}_{session_speed}_win_{start}_ws{window_size}_fh{forecast_horizon}.csv"
                file_path = os.path.join(dataset_folder, file_name)
                combined_df.to_csv(file_path, index=False)

                # Log the file name and path in the index file
                writer.writerow([file_name, file_path])

    @staticmethod
    def _clean_channel(raw_values):
        """Coerce one channel to numeric and linearly interpolate the gaps.

        Values that cannot be parsed become NaN and are then filled by linear
        interpolation in both directions. Returns a flat numpy array.
        """
        numeric = pd.to_numeric(raw_values, errors='coerce')
        filled = pd.DataFrame(numeric).interpolate(method='linear', axis=0, limit_direction='both')
        return filled.to_numpy().flatten()

    def _extract_channel_data(self, session_data, channels):
        """Read the requested channels from one HDF5 dataset.

        Compound datasets are addressed by field name. Plain datasets are
        addressed by column index, looked up in the ``column_names``
        attribute. Channels that are not present are reported and left out.

        Returns:
            ``(data, names)`` where ``data`` is ``np.array`` of the cleaned
            channels (one row per found channel) and ``names`` lists the
            channels that were found, in request order. When ``session_data``
            is not an ``h5py.Dataset`` both are empty.
        """
        extracted_data = []
        new_column_names = []

        if isinstance(session_data, h5py.Dataset):
            if session_data.dtype.names:
                # Compound dataset
                column_names = session_data.dtype.names
                for channel in channels:
                    if channel not in column_names:
                        print(f"Channel {channel} not found in compound dataset.")
                        continue
                    extracted_data.append(self._clean_channel(session_data[channel][:]))
                    new_column_names.append(channel)
            else:
                # Simple dataset
                column_names = list(session_data.attrs.get('column_names', []))
                assert len(column_names) > 0, "column_names not found in dataset attributes"
                for channel in channels:
                    if channel not in column_names:
                        print(f"Channel {channel} not found in session data.")
                        continue
                    col_idx = column_names.index(channel)
                    extracted_data.append(self._clean_channel(session_data[:, col_idx]))
                    new_column_names.append(channel)

        return np.array(extracted_data), new_column_names

class ForecastingDataset(Dataset):
    """Dataset over the per-window CSV files listed in ``<split>_info.csv``.

    The dataset folder name is derived from the window length, the forecast
    horizon, the split and the subject list. When that folder does not exist,
    a ``DataSharder`` is asked to create it. Each item is the tuple
    ``(x_acc, x_gyr, x_emg, target)`` of float32 tensors: the three inputs
    come from the context rows, the target from the ``*_forecast`` columns of
    the joint channels.
    """

    def __init__(self, config, subjects, window_length, forecast_horizon, split='train'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.forecast_horizon = forecast_horizon

        # Channel groups are copied from the config for direct access.
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Build a compact, path-safe tag from the subject list.
        subject_tag = "_".join(str(subject) for subject in subjects)
        subject_tag = subject_tag.replace('subject', '').replace('__', '_')
        dataset_name = f"dataset_wl{self.window_length}_fh{self.forecast_horizon}_{split}{subject_tag}"
        self.root_dir = os.path.join(self.config.dataset_root, dataset_name)

        # Create the window files first if they are not there yet.
        self.ensure_resharded(subjects, dataset_name)

        # The index file lists one window file per row.
        self.data = pd.read_csv(os.path.join(self.root_dir, f"{split}_info.csv"))

    def ensure_resharded(self, subjects, dataset_name):
        """Run the sharder unless ``self.root_dir`` already exists."""
        if os.path.exists(self.root_dir):
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")
            return

        print(f"Sharded data not found at {self.root_dir}. Resharding...")
        # The window geometry of this dataset is forwarded to the sharder.
        DataSharder(self.config, self.split).load_data(
            subjects,
            window_length=self.window_length,
            forecast_horizon=self.forecast_horizon,
            dataset_name=dataset_name,
        )

    def __len__(self):
        """Number of window files listed in the index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load window ``idx`` and return ``(x_acc, x_gyr, x_emg, target)``."""
        # The first index column holds the file name; the file itself lives in
        # the split sub-folder of the dataset root.
        window_file = self.data.iloc[idx, 0]
        file_path = os.path.join(self.root_dir, self.split, window_file)

        if self.config.input_format != "csv":
            raise ValueError(f"Unsupported input format: {self.config.input_format}")

        return self._extract_and_transform(pd.read_csv(file_path))

    def _extract_and_transform(self, combined_data):
        """Split one window frame into modalities and apply their transforms."""
        # All four arrays are extracted before any transform runs.
        raw_acc = self._extract_channels(combined_data, self.channels_imu_acc, is_context=True)
        raw_gyr = self._extract_channels(combined_data, self.channels_imu_gyr, is_context=True)
        raw_emg = self._extract_channels(combined_data, self.channels_emg, is_context=True)
        raw_target = self._extract_channels(combined_data, self.channels_joints, is_context=False)

        # Both IMU groups share the IMU transform list.
        return (
            self.apply_transforms(raw_acc, self.config.imu_transforms),
            self.apply_transforms(raw_gyr, self.config.imu_transforms),
            self.apply_transforms(raw_emg, self.config.emg_transforms),
            self.apply_transforms(raw_target, self.config.target_transforms),
        )

    def _extract_channels(self, combined_data, channels, is_context=True):
        """Return the leading rows of the requested columns as a numpy array.

        Context requests read ``window_length`` rows of the plain columns;
        forecast requests read ``forecast_horizon`` rows of the columns that
        carry the ``_forecast`` suffix.
        """
        row_count = self.window_length if is_context else self.forecast_horizon
        if is_context == False:
            # Forecast values are stored under suffixed column names.
            channels = [f"{name}_forecast" for name in channels]

        selected = combined_data[channels]
        return selected.values[:row_count]

    def apply_transforms(self, data, transforms):
        """Apply ``transforms`` in order and convert the result to float32."""
        transformed = data
        for step in transforms:
            transformed = step(transformed)
        return torch.tensor(transformed, dtype=torch.float32)



def create_forecasting_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=25,
    forecast_horizon=5,
    batch_size=64
):
    """Build the train, validation and test loaders for one subject split.

    The training subjects are loaded as a single dataset that is then split
    at random into 90 % training and 10 % validation items. The test subjects
    form their own dataset. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # Arguments shared by both dataset constructions.
    window_spec = dict(
        config=config,
        window_length=window_length,
        forecast_horizon=forecast_horizon,
    )

    # The training dataset is built before the test dataset.
    fit_pool = ForecastingDataset(subjects=train_subjects, split='train', **window_spec)
    held_out = ForecastingDataset(subjects=test_subjects, split='test', **window_spec)

    # 90/10 split; the validation part takes whatever the rounding leaves.
    n_fit = int(0.9 * len(fit_pool))
    n_val = len(fit_pool) - n_fit
    fit_part, val_part = random_split(fit_pool, [n_fit, n_val])

    train_loader = DataLoader(fit_part, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_part, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(held_out, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder_1(nn.Module):
    """Two stacked bidirectional LSTM layers, each followed by dropout.

    The first layer has 128 hidden units per direction and the second 64, so
    the returned sequence has 128 features per time step. ``forward`` returns
    ``(sequence, (h_first, h_second))`` where the ``h_*`` entries are the
    final hidden states of the two LSTM layers.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        recurrent_opts = dict(bidirectional=True, batch_first=True, dropout=0)
        self.lstm_1 = nn.LSTM(input_dim, 128, **recurrent_opts)
        self.lstm_2 = nn.LSTM(256, 64, **recurrent_opts)
        # `flatten` and `fc` are registered but are not called in `forward`.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        first_seq, first_state = self.lstm_1(x)
        second_seq, second_state = self.lstm_2(self.dropout_1(first_seq))
        # Only the hidden states are handed back; the cell states are dropped.
        return self.dropout_2(second_seq), (first_state[0], second_state[0])

class Encoder_2(nn.Module):
    """Two stacked bidirectional GRU layers, each followed by dropout.

    The first layer has 128 hidden units per direction and the second 64, so
    the returned sequence has 128 features per time step. ``forward`` returns
    ``(sequence, (h_first, h_second))`` where the ``h_*`` entries are the
    final hidden states of the two GRU layers.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        recurrent_opts = dict(bidirectional=True, batch_first=True, dropout=0)
        self.gru_1 = nn.GRU(input_dim, 128, **recurrent_opts)
        self.gru_2 = nn.GRU(256, 64, **recurrent_opts)
        # `flatten` and `fc` are registered but are not called in `forward`.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        first_seq, first_hidden = self.gru_1(x)
        second_seq, second_hidden = self.gru_2(self.dropout_1(first_seq))
        return self.dropout_2(second_seq), (first_hidden, second_hidden)


class GatingModule(nn.Module):
    """Blends two equally sized feature tensors with a learned sigmoid gate.

    The gate is computed from the concatenation of both inputs along the last
    dimension. The result is ``input1 * gate + input2 * (1 - gate)``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(2 * input_size, input_size),
            nn.Sigmoid(),
        )

    def forward(self, input1, input2):
        # One gate value in (0, 1) per feature.
        mix = self.gate(torch.cat((input1, input2), dim=-1))

        # Gate-weighted share of the first input plus the remaining share of
        # the second input.
        return input1 * mix + input2 * (1 - mix)

class teacher(nn.Module):
    """Multi-modal forecasting network over accelerometer, gyroscope and EMG input.

    Each modality is batch-normalised per channel and encoded twice (an LSTM
    encoder and a GRU encoder). The two encodings are blended by a learned
    gate and the three modalities are concatenated. The fused sequence is
    processed by three branches (self-attention, element-wise gating, and a
    weighted sum of the modality slices) whose outputs are concatenated,
    gated once more, average-pooled along time to ``forecast_horizon`` steps
    and mapped to the output channels.

    Args:
        input_acc: number of accelerometer channels.
        input_gyr: number of gyroscope channels.
        input_emg: number of EMG channels.
        drop_prob: dropout probability used inside the encoders.
        w: nominal context window length. It is stored on the module, but
            ``forward`` takes the sequence length from the input tensors.
        forecast_horizon: number of time steps in the prediction.
        output_dim: number of predicted channels per time step (default 3).
    """

    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100, forecast_horizon=5, output_dim=3):
        super(teacher, self).__init__()

        feat = 128            # features per time step produced by each encoder
        fused = 3 * feat      # three modalities concatenated
        head_in = 2 * fused + feat  # attention branch + gated branch + weighted-sum branch

        self.w = w
        self.encoder_1_acc = Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr = Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg = Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc = Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr = Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg = Encoder_2(input_emg, drop_prob)

        self.BN_acc = nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr = nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg = nn.BatchNorm1d(input_emg, affine=False)

        self.fc = nn.Linear(head_in, output_dim)
        # Registered but not called in `forward`.
        self.dropout = nn.Dropout(p=.05)

        self.gate_1 = GatingModule(feat)
        self.gate_2 = GatingModule(feat)
        self.gate_3 = GatingModule(feat)

        self.fc_kd = nn.Linear(fused, 2 * feat)

        # Scalar weight per time step, shared by the three modality slices.
        self.weighted_feat = nn.Sequential(
            nn.Linear(feat, 1),
            nn.Sigmoid())

        self.attention = nn.MultiheadAttention(fused, 4, batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(fused, fused), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(head_in, head_in), nn.Sigmoid())

        # Reduces the time axis to exactly `forecast_horizon` steps.
        self.pool = nn.AdaptiveAvgPool1d(forecast_horizon)

    @staticmethod
    def _normalise_over_time(x, batch_norm):
        """Apply ``batch_norm`` per channel to a ``(batch, time, channels)`` tensor.

        Batch and time are merged so that the statistics are computed per
        channel over both. The original shape is restored from the tensor
        itself, so the sequence length does not have to equal ``self.w``.
        """
        batch, steps, channels = x.shape
        flat = batch_norm(x.reshape(batch * steps, channels))
        return flat.reshape(batch, steps, channels)

    def forward(self, x_acc, x_gyr, x_emg):
        """Run the network on one batch.

        Each input is expected as ``(batch, time, channels)``.

        Returns:
            ``(out, x_kd, hidden)`` where ``out`` has shape
            ``(batch, forecast_horizon, output_dim)``, ``x_kd`` is the fused
            feature sequence projected by ``fc_kd``, and ``hidden`` is the
            tuple ``(h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)``
            holding the first-layer hidden state of each encoder.
        """
        acc_in = self._normalise_over_time(x_acc, self.BN_acc)
        gyr_in = self._normalise_over_time(x_gyr, self.BN_gyr)
        emg_in = self._normalise_over_time(x_emg, self.BN_emg)

        # Pass through Encoder 1 for each modality and capture hidden states
        x_acc_1, (h_acc_1, _) = self.encoder_1_acc(acc_in)
        x_gyr_1, (h_gyr_1, _) = self.encoder_1_gyr(gyr_in)
        x_emg_1, (h_emg_1, _) = self.encoder_1_emg(emg_in)

        # Pass through Encoder 2 for each modality and capture hidden states
        x_acc_2, (h_acc_2, _) = self.encoder_2_acc(acc_in)
        x_gyr_2, (h_gyr_2, _) = self.encoder_2_gyr(gyr_in)
        x_emg_2, (h_emg_2, _) = self.encoder_2_emg(emg_in)

        # Blend the two encodings of every modality.
        x_acc = self.gate_1(x_acc_1, x_acc_2)
        x_gyr = self.gate_2(x_gyr_1, x_gyr_2)
        x_emg = self.gate_3(x_emg_1, x_emg_2)

        x = torch.cat((x_acc, x_gyr, x_emg), dim=-1)
        x_kd = self.fc_kd(x)

        # Branch 1: self-attention over the fused sequence.
        out_1, attn_output_weights = self.attention(x, x, x)

        # Branch 2: element-wise gating of the fused features.
        gating_weights = self.gating_net(x)
        out_2 = gating_weights * x

        # Branch 3: weighted sum of the three modality slices.
        weights_1 = self.weighted_feat(x[:, :, 0:128])
        weights_2 = self.weighted_feat(x[:, :, 128:2*128])
        weights_3 = self.weighted_feat(x[:, :, 2*128:3*128])
        x_1 = weights_1 * x[:, :, 0:128]
        x_2 = weights_2 * x[:, :, 128:2*128]
        x_3 = weights_3 * x[:, :, 2*128:3*128]
        out_3 = x_1 + x_2 + x_3

        out = torch.cat((out_1, out_2, out_3), dim=-1)

        gating_weights_1 = self.gating_net_1(out)
        out = gating_weights_1 * out

        # Pool the time axis down to the forecast horizon.
        out = out.permute(0, 2, 1)  # (batch, channels, time) for 1D pooling
        out = self.pool(out)        # (batch, channels, forecast_horizon)
        out = out.permute(0, 2, 1)  # (batch, forecast_horizon, channels)

        # Final fully connected layer for forecasting
        out = self.fc(out)

        return out, x_kd, (h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)


In [45]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

def _forecast_rmse(model, criterion, batch, device):
    """Return the RMSE (as a tensor) of the model's forecast on one batch."""
    x_acc, x_gyr, x_emg, targets = batch
    x_acc, x_gyr, x_emg, targets = x_acc.to(device), x_gyr.to(device), x_emg.to(device), targets.to(device)
    outputs = model(x_acc, x_gyr, x_emg)[0]  # [0] to get forecast_output
    # Epsilon keeps the square root differentiable when the MSE is exactly 0.
    return torch.sqrt(criterion(outputs, targets) + 1e-8)


def _mean_of_batch_losses(losses, split_name):
    """Average per-batch losses; raise a clear error when there were no batches."""
    if not losses:
        raise ValueError(f"The {split_name} loader produced no batches.")
    return sum(losses) / len(losses)


def _mean_rmse_without_grad(model, criterion, loader, device, split_name):
    """Mean of the per-batch RMSE over ``loader`` with autograd disabled."""
    losses = []
    with torch.no_grad():
        for batch in loader:
            losses.append(_forecast_rmse(model, criterion, batch, device).item())
    return _mean_of_batch_losses(losses, split_name)


def train_forecasting_model(
    model,
    train_loader,
    val_loader,
    test_loader,  # Added test_loader
    num_epochs=50,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    subject_id=None  # For logging purposes
):
    """Train ``model`` with Adam on an RMSE objective and checkpoint the best epoch.

    After every epoch the model is evaluated on the validation and the test
    loader. Whenever the validation RMSE improves, the state dict is written
    to ``best_teacher_model_subject_{subject_id}.pth`` in the working
    directory. The test RMSE is only reported; it does not influence the
    checkpoint choice.

    All reported values are the mean of the per-batch RMSE, not the RMSE over
    the pooled samples.

    Args:
        model: module whose forward takes ``(x_acc, x_gyr, x_emg)`` and
            returns a tuple with the forecast as its first element.
        train_loader, val_loader, test_loader: iterables yielding
            ``(x_acc, x_gyr, x_emg, targets)`` batches.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device on which model and batches are placed.
        subject_id: used in progress output and in the checkpoint file name.

    Raises:
        ValueError: if one of the loaders yields no batches.
    """
    # Move model to the device
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')

    for epoch in tqdm(range(num_epochs), desc=f"Subject {subject_id}"):
        model.train()
        train_losses = []
        start_time = time.time()

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            optimizer.zero_grad()

            rmse_loss = _forecast_rmse(model, criterion, batch, device)
            train_losses.append(rmse_loss.item())

            # Backward pass and optimization
            rmse_loss.backward()
            optimizer.step()

        avg_train_loss = _mean_of_batch_losses(train_losses, "train")

        # Validation and per-epoch test evaluation both run in eval mode.
        model.eval()
        avg_val_loss = _mean_rmse_without_grad(model, criterion, val_loader, device, "validation")
        avg_test_loss = _mean_rmse_without_grad(model, criterion, test_loader, device, "test")

        # Check if this is the best model so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Save the best model
            torch.save(model.state_dict(), f'best_teacher_model_subject_{subject_id}.pth')

        end_time = time.time()
        epoch_time = end_time - start_time

        print(f"Subject {subject_id} | Epoch {epoch+1}/{num_epochs}, "
              f"Train RMSE: {avg_train_loss:.6f}, "
              f"Val RMSE: {avg_val_loss:.6f}, "
              f"Test RMSE: {avg_test_loss:.6f}, "
              f"Time: {epoch_time:.2f}s")


from scipy.stats import pearsonr
import numpy as np

def evaluate_model(model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Evaluate ``model`` on ``test_loader`` and time one single-sample forward pass.

    Args:
        model: module whose forward takes ``(x_acc, x_gyr, x_emg)`` and
            returns a tuple with the forecast as its first element.
        test_loader: iterable yielding ``(x_acc, x_gyr, x_emg, targets)``
            batches. It is iterated twice: once to fetch a sample for timing,
            once for the metrics.
        device: device on which model and batches are placed.

    Returns:
        dict with
        ``'rmse'``: mean of the per-batch RMSE,
        ``'pcc'``: Pearson correlation averaged over the output channels
        (a channel with zero variance in predictions or targets counts as 0),
        ``'inference_time_per_sample'``: wall-clock seconds of one forward
        pass on the first sample of the first batch.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model = model.to(device)
    model.eval()
    criterion = nn.MSELoss()

    # CUDA kernels are launched asynchronously, so the clock must only be read
    # after the device has finished its queued work.
    sync_cuda = torch.device(device).type == 'cuda' and torch.cuda.is_available()

    batch_rmses = []
    predictions = []
    references = []
    single_sample_seconds = 0.0

    with torch.no_grad():
        # Time one forward pass on the first sample of the first batch.
        for x_acc, x_gyr, x_emg, _ in test_loader:
            x_acc_single = x_acc[0:1].to(device)
            x_gyr_single = x_gyr[0:1].to(device)
            x_emg_single = x_emg[0:1].to(device)

            if sync_cuda:
                torch.cuda.synchronize()
            tic = time.perf_counter()
            model(x_acc_single, x_gyr_single, x_emg_single)
            if sync_cuda:
                torch.cuda.synchronize()
            single_sample_seconds = time.perf_counter() - tic
            break  # one example is enough for the timing

        # Metrics over the entire test set.
        for x_acc, x_gyr, x_emg, targets in test_loader:
            x_acc, x_gyr, x_emg, targets = x_acc.to(device), x_gyr.to(device), x_emg.to(device), targets.to(device)

            outputs = model(x_acc, x_gyr, x_emg)[0]  # [0] to get forecast_output
            rmse = torch.sqrt(criterion(outputs, targets) + 1e-8)
            batch_rmses.append(rmse.item())

            predictions.append(outputs.cpu().numpy())
            references.append(targets.cpu().numpy())

    if not batch_rmses:
        raise ValueError("test_loader produced no batches; nothing to evaluate.")

    avg_test_rmse = sum(batch_rmses) / len(batch_rmses)

    # Stack all batches, then merge every leading axis so that each row is one
    # predicted time step and each column one output channel.
    all_predictions = np.concatenate(predictions, axis=0)
    all_targets = np.concatenate(references, axis=0)
    pred_flat = all_predictions.reshape(-1, all_predictions.shape[-1])
    target_flat = all_targets.reshape(-1, all_targets.shape[-1])

    # Pearson correlation per output channel.
    pcc_list = []
    for channel in range(pred_flat.shape[1]):
        pred_col = pred_flat[:, channel]
        target_col = target_flat[:, channel]
        if np.std(pred_col) == 0 or np.std(target_col) == 0:
            # Correlation is undefined for a constant series; count it as 0.
            pcc = 0
        else:
            pcc, _ = pearsonr(pred_col, target_col)
        pcc_list.append(pcc)

    avg_pcc = np.mean(pcc_list)

    return {
        'rmse': avg_test_rmse,
        'pcc': avg_pcc,
        'inference_time_per_sample': single_sample_seconds  # For a single example
    }



In [46]:

def run_per_subject_training(
    subjects,
    config,
    window_length=100,
    forecast_horizon=10,
    batch_size=64,
    num_epochs=50,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu'
):
    """Leave-one-subject-out training and evaluation.

    Every subject is held out once: a fresh ``teacher`` is trained on the
    remaining subjects, the checkpoint with the best validation loss is
    reloaded, and the held-out subject is evaluated.

    Returns:
        list with one dict per held-out subject holding ``'subject'``,
        ``'rmse'``, ``'pcc'`` and ``'inference_time_per_sample'``.
    """
    reported_metrics = ('rmse', 'pcc', 'inference_time_per_sample')
    results = []

    for held_out in subjects:
        print(f"\nStarting training for test subject: {held_out}")

        # Everything except the held-out subject is used for training.
        remaining = [candidate for candidate in subjects if candidate != held_out]

        train_loader, val_loader, test_loader = create_forecasting_data_loaders(
            config=config,
            train_subjects=remaining,
            test_subjects=[held_out],
            window_length=window_length,
            forecast_horizon=forecast_horizon,
            batch_size=batch_size
        )

        # Channel counts per modality determine the model's input widths.
        # `output_size` is computed but not passed to `teacher` in this call.
        output_size = len(config.channels_joints)
        model = teacher(
            input_acc=len(config.channels_imu_acc),
            input_gyr=len(config.channels_imu_gyr),
            input_emg=len(config.channels_emg),
            drop_prob=0.25,
            w=window_length,
            forecast_horizon=forecast_horizon,
        )

        train_forecasting_model(
            model,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            device=device,
            subject_id=held_out  # For logging and model saving
        )

        # Restore the weights of the epoch with the best validation loss.
        checkpoint_path = f'best_teacher_model_subject_{held_out}.pth'
        model.load_state_dict(torch.load(checkpoint_path))

        evaluation_results = evaluate_model(model, test_loader, device=device)

        record = {'subject': held_out}
        for metric in reported_metrics:
            record[metric] = evaluation_results[metric]
        results.append(record)

        print(f"Subject {held_out} | Test RMSE: {evaluation_results['rmse']:.6f}, "
              f"Average PCC: {evaluation_results['pcc']:.6f}, "
              f"Inference Time per Sample: {evaluation_results['inference_time_per_sample'] * 1000:.6f} ms")

    return results


In [47]:
import random
import csv
# Subjects taking part in the benchmark: subject_1 through subject_13
subjects = [f'subject_{i}' for i in range(1, 14)]

# Train and evaluate once per test subject. The call returns a list with one
# dict per subject holding 'subject', 'rmse', 'pcc' and
# 'inference_time_per_sample'.
results = run_per_subject_training(
    subjects=subjects,
    config=config,
    window_length=25,
    forecast_horizon=5,
    batch_size=64,
    num_epochs=10,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# Split the per-subject dicts into one list per metric, preserving subject order
rmse_list, pcc_list, inference_time_list = [], [], []
for res in results:
    rmse_list.append(res['rmse'])
    pcc_list.append(res['pcc'])
    inference_time_list.append(res['inference_time_per_sample'])

# Unweighted mean of every metric across subjects
avg_rmse, avg_pcc, avg_inference_time = map(
    np.mean, (rmse_list, pcc_list, inference_time_list)
)

# Aggregate summary; inference time is multiplied by 1000 for the "ms" display
print(
    "\nFinal Results:",
    f"Average RMSE over all subjects: {avg_rmse:.6f}",
    f"Average PCC over all subjects: {avg_pcc:.6f}",
    f"Average Inference Time per Sample: {avg_inference_time * 1000:.6f} ms",
    sep="\n",
)

# One line per subject, in the order the subjects were evaluated
print("\nPer-Subject Results:")
for res in results:
    print(
        f"Subject {res['subject']} | RMSE: {res['rmse']:.6f}, "
        f"PCC: {res['pcc']:.6f}, "
        f"Inference Time per Sample: {res['inference_time_per_sample'] * 1000:.6f} ms"
    )



Starting training for test subject: subject_1
Sharded data found at /content/datasets/dataset_wl25_fh5_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_1. Skipping resharding.


Subject subject_1:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 1/10, Train RMSE: 19.171364, Val RMSE: 14.494708, Test RMSE: 20.567913, Time: 151.93s


Epoch 2/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 2/10, Train RMSE: 12.435492, Val RMSE: 10.539838, Test RMSE: 23.059329, Time: 137.92s


Epoch 3/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 3/10, Train RMSE: 11.075372, Val RMSE: 10.119015, Test RMSE: 21.773305, Time: 140.36s


Epoch 4/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 4/10, Train RMSE: 10.215973, Val RMSE: 9.114155, Test RMSE: 21.408430, Time: 137.03s


Epoch 5/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 5/10, Train RMSE: 9.697767, Val RMSE: 10.317435, Test RMSE: 23.704689, Time: 139.07s


Epoch 6/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 6/10, Train RMSE: 9.408254, Val RMSE: 8.718238, Test RMSE: 21.778524, Time: 138.18s


Epoch 7/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 7/10, Train RMSE: 9.012429, Val RMSE: 8.405160, Test RMSE: 21.649477, Time: 136.87s


Epoch 8/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 8/10, Train RMSE: 8.835929, Val RMSE: 9.144761, Test RMSE: 22.408336, Time: 136.00s


Epoch 9/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 9/10, Train RMSE: 8.734861, Val RMSE: 7.906885, Test RMSE: 21.054964, Time: 137.17s


Epoch 10/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_1 | Epoch 10/10, Train RMSE: 8.515690, Val RMSE: 8.004246, Test RMSE: 22.229743, Time: 138.45s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_1 | Test RMSE: 21.054964, Average PCC: 0.715173, Inference Time per Sample: 6.265879 ms

Starting training for test subject: subject_2
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_2. Skipping resharding.


Subject subject_2:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 1/10, Train RMSE: 18.273591, Val RMSE: 11.191251, Test RMSE: 20.101556, Time: 139.60s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 2/10, Train RMSE: 11.065052, Val RMSE: 9.196958, Test RMSE: 21.298390, Time: 135.96s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 3/10, Train RMSE: 9.522953, Val RMSE: 8.362129, Test RMSE: 21.611745, Time: 138.69s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 4/10, Train RMSE: 8.852067, Val RMSE: 8.133346, Test RMSE: 21.732257, Time: 134.78s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 5/10, Train RMSE: 8.450616, Val RMSE: 7.850637, Test RMSE: 21.843512, Time: 136.06s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 6/10, Train RMSE: 8.098772, Val RMSE: 7.510735, Test RMSE: 20.658470, Time: 133.66s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 7/10, Train RMSE: 7.725703, Val RMSE: 7.088926, Test RMSE: 22.258711, Time: 136.56s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 8/10, Train RMSE: 7.527127, Val RMSE: 6.968439, Test RMSE: 21.419355, Time: 133.39s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 9/10, Train RMSE: 7.240020, Val RMSE: 6.991644, Test RMSE: 20.844432, Time: 133.45s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_2 | Epoch 10/10, Train RMSE: 7.206359, Val RMSE: 6.690824, Test RMSE: 21.485448, Time: 134.90s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_2 | Test RMSE: 21.485448, Average PCC: 0.774237, Inference Time per Sample: 5.954742 ms

Starting training for test subject: subject_3
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_3. Skipping resharding.


Subject subject_3:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 1/10, Train RMSE: 18.818811, Val RMSE: 12.348780, Test RMSE: 20.022148, Time: 134.57s


Epoch 2/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 2/10, Train RMSE: 11.854246, Val RMSE: 10.432132, Test RMSE: 19.509477, Time: 134.36s


Epoch 3/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 3/10, Train RMSE: 10.298779, Val RMSE: 9.440566, Test RMSE: 19.904495, Time: 133.14s


Epoch 4/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 4/10, Train RMSE: 9.575619, Val RMSE: 9.052277, Test RMSE: 20.458569, Time: 132.06s


Epoch 5/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 5/10, Train RMSE: 9.124593, Val RMSE: 8.421688, Test RMSE: 17.991384, Time: 135.21s


Epoch 6/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 6/10, Train RMSE: 8.649827, Val RMSE: 8.493268, Test RMSE: 19.668693, Time: 132.99s


Epoch 7/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 7/10, Train RMSE: 8.504377, Val RMSE: 7.969604, Test RMSE: 17.739606, Time: 134.31s


Epoch 8/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 8/10, Train RMSE: 7.986587, Val RMSE: 8.402258, Test RMSE: 20.093588, Time: 133.34s


Epoch 9/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 9/10, Train RMSE: 7.921865, Val RMSE: 7.670337, Test RMSE: 19.832941, Time: 135.49s


Epoch 10/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_3 | Epoch 10/10, Train RMSE: 7.669059, Val RMSE: 7.657033, Test RMSE: 18.575491, Time: 136.96s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_3 | Test RMSE: 18.575491, Average PCC: 0.908992, Inference Time per Sample: 6.140232 ms

Starting training for test subject: subject_4
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_4. Skipping resharding.


Subject subject_4:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 1/10, Train RMSE: 19.246544, Val RMSE: 12.397064, Test RMSE: 17.674303, Time: 136.14s


Epoch 2/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 2/10, Train RMSE: 12.148950, Val RMSE: 10.504853, Test RMSE: 17.105362, Time: 136.72s


Epoch 3/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 3/10, Train RMSE: 10.664303, Val RMSE: 9.728200, Test RMSE: 17.841046, Time: 132.74s


Epoch 4/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 4/10, Train RMSE: 9.847084, Val RMSE: 8.921660, Test RMSE: 16.627731, Time: 132.29s


Epoch 5/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 5/10, Train RMSE: 9.463365, Val RMSE: 8.605624, Test RMSE: 15.567592, Time: 134.67s


Epoch 6/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 6/10, Train RMSE: 9.164692, Val RMSE: 8.590695, Test RMSE: 16.346822, Time: 133.14s


Epoch 7/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 7/10, Train RMSE: 8.728640, Val RMSE: 8.044370, Test RMSE: 15.089831, Time: 134.90s


Epoch 8/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 8/10, Train RMSE: 8.402141, Val RMSE: 7.822094, Test RMSE: 15.820991, Time: 132.84s


Epoch 9/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 9/10, Train RMSE: 8.209670, Val RMSE: 8.088012, Test RMSE: 16.428210, Time: 132.31s


Epoch 10/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_4 | Epoch 10/10, Train RMSE: 8.047748, Val RMSE: 7.778774, Test RMSE: 17.078571, Time: 132.39s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_4 | Test RMSE: 17.078571, Average PCC: 0.859510, Inference Time per Sample: 6.427765 ms

Starting training for test subject: subject_5
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_5. Skipping resharding.


Subject subject_5:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 1/10, Train RMSE: 18.915635, Val RMSE: 11.955323, Test RMSE: 23.790073, Time: 132.05s


Epoch 2/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 2/10, Train RMSE: 12.210250, Val RMSE: 9.885386, Test RMSE: 20.919282, Time: 133.43s


Epoch 3/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 3/10, Train RMSE: 10.863565, Val RMSE: 9.581300, Test RMSE: 20.803595, Time: 131.36s


Epoch 4/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 4/10, Train RMSE: 10.161020, Val RMSE: 8.909139, Test RMSE: 22.271496, Time: 132.73s


Epoch 5/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 5/10, Train RMSE: 9.776022, Val RMSE: 8.663987, Test RMSE: 21.206813, Time: 133.36s


Epoch 6/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 6/10, Train RMSE: 9.350059, Val RMSE: 8.524952, Test RMSE: 22.327843, Time: 133.41s


Epoch 7/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 7/10, Train RMSE: 9.001303, Val RMSE: 8.087134, Test RMSE: 20.588575, Time: 133.01s


Epoch 8/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 8/10, Train RMSE: 8.816247, Val RMSE: 7.790647, Test RMSE: 20.375988, Time: 131.40s


Epoch 9/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 9/10, Train RMSE: 8.621832, Val RMSE: 8.073839, Test RMSE: 21.078646, Time: 133.32s


Epoch 10/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_5 | Epoch 10/10, Train RMSE: 8.413639, Val RMSE: 8.152039, Test RMSE: 21.025679, Time: 131.57s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_5 | Test RMSE: 20.375988, Average PCC: 0.809165, Inference Time per Sample: 7.849216 ms

Starting training for test subject: subject_6
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_6. Skipping resharding.


Subject subject_6:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 1/10, Train RMSE: 19.172291, Val RMSE: 12.229351, Test RMSE: 16.074682, Time: 134.10s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 2/10, Train RMSE: 12.120621, Val RMSE: 10.921658, Test RMSE: 14.580404, Time: 135.23s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 3/10, Train RMSE: 10.800318, Val RMSE: 9.393073, Test RMSE: 14.909665, Time: 134.82s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 4/10, Train RMSE: 10.043389, Val RMSE: 9.068144, Test RMSE: 14.064421, Time: 135.77s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 5/10, Train RMSE: 9.593221, Val RMSE: 8.638303, Test RMSE: 14.567468, Time: 133.43s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 6/10, Train RMSE: 9.327785, Val RMSE: 8.476028, Test RMSE: 13.470926, Time: 133.94s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 7/10, Train RMSE: 8.864541, Val RMSE: 8.061748, Test RMSE: 14.302997, Time: 134.58s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 8/10, Train RMSE: 8.618265, Val RMSE: 8.299452, Test RMSE: 14.426586, Time: 134.51s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 9/10, Train RMSE: 8.479409, Val RMSE: 8.084838, Test RMSE: 13.885561, Time: 138.86s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_6 | Epoch 10/10, Train RMSE: 8.331574, Val RMSE: 7.987618, Test RMSE: 14.255684, Time: 135.07s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_6 | Test RMSE: 14.255684, Average PCC: 0.893101, Inference Time per Sample: 6.012678 ms

Starting training for test subject: subject_7
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_7. Skipping resharding.


Subject subject_7:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 1/10, Train RMSE: 18.933155, Val RMSE: 11.763575, Test RMSE: 16.887420, Time: 136.95s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 2/10, Train RMSE: 12.131010, Val RMSE: 9.888094, Test RMSE: 16.203203, Time: 135.08s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 3/10, Train RMSE: 10.818535, Val RMSE: 9.181366, Test RMSE: 14.969517, Time: 134.17s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 4/10, Train RMSE: 10.242433, Val RMSE: 9.059613, Test RMSE: 14.215584, Time: 134.57s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 5/10, Train RMSE: 9.800006, Val RMSE: 8.766363, Test RMSE: 14.616564, Time: 135.35s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 6/10, Train RMSE: 9.523071, Val RMSE: 8.483809, Test RMSE: 15.512834, Time: 136.58s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 7/10, Train RMSE: 9.112221, Val RMSE: 8.001080, Test RMSE: 14.637954, Time: 134.16s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 8/10, Train RMSE: 8.842271, Val RMSE: 8.079410, Test RMSE: 16.519970, Time: 135.95s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 9/10, Train RMSE: 8.798220, Val RMSE: 7.813311, Test RMSE: 15.042783, Time: 135.71s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_7 | Epoch 10/10, Train RMSE: 8.562807, Val RMSE: 8.152844, Test RMSE: 14.509980, Time: 137.04s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_7 | Test RMSE: 15.042783, Average PCC: 0.891507, Inference Time per Sample: 6.182909 ms

Starting training for test subject: subject_8
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_8. Skipping resharding.


Subject subject_8:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 1/10, Train RMSE: 19.643013, Val RMSE: 12.755993, Test RMSE: 12.773328, Time: 136.45s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 2/10, Train RMSE: 12.337869, Val RMSE: 10.689375, Test RMSE: 12.066831, Time: 134.38s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 3/10, Train RMSE: 11.030770, Val RMSE: 9.205558, Test RMSE: 12.824245, Time: 136.44s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 4/10, Train RMSE: 10.242392, Val RMSE: 9.092811, Test RMSE: 13.449331, Time: 137.29s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 5/10, Train RMSE: 9.946437, Val RMSE: 8.884854, Test RMSE: 13.803953, Time: 137.23s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 6/10, Train RMSE: 9.461782, Val RMSE: 8.272432, Test RMSE: 13.166075, Time: 134.29s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 7/10, Train RMSE: 9.187524, Val RMSE: 8.530338, Test RMSE: 13.314613, Time: 136.05s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 8/10, Train RMSE: 9.030058, Val RMSE: 8.023698, Test RMSE: 13.709402, Time: 133.47s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 9/10, Train RMSE: 8.787740, Val RMSE: 7.817387, Test RMSE: 13.642688, Time: 135.33s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_8 | Epoch 10/10, Train RMSE: 8.612727, Val RMSE: 8.251192, Test RMSE: 14.200814, Time: 133.21s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_8 | Test RMSE: 13.642688, Average PCC: 0.906970, Inference Time per Sample: 7.717133 ms

Starting training for test subject: subject_9
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_8_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_9. Skipping resharding.


Subject subject_9:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 1/10, Train RMSE: 19.867991, Val RMSE: 12.379286, Test RMSE: 11.845587, Time: 135.54s


Epoch 2/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 2/10, Train RMSE: 12.409248, Val RMSE: 11.207534, Test RMSE: 12.048791, Time: 133.81s


Epoch 3/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 3/10, Train RMSE: 11.198617, Val RMSE: 10.161941, Test RMSE: 12.152711, Time: 134.40s


Epoch 4/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 4/10, Train RMSE: 10.378680, Val RMSE: 9.994681, Test RMSE: 13.362251, Time: 133.87s


Epoch 5/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 5/10, Train RMSE: 9.822023, Val RMSE: 9.438993, Test RMSE: 13.188891, Time: 133.76s


Epoch 6/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 6/10, Train RMSE: 9.394098, Val RMSE: 9.423314, Test RMSE: 11.618048, Time: 135.28s


Epoch 7/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 7/10, Train RMSE: 9.051916, Val RMSE: 8.531269, Test RMSE: 12.493288, Time: 133.96s


Epoch 8/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 8/10, Train RMSE: 8.808331, Val RMSE: 8.992820, Test RMSE: 14.249501, Time: 135.48s


Epoch 9/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 9/10, Train RMSE: 8.529702, Val RMSE: 8.725918, Test RMSE: 12.721659, Time: 133.28s


Epoch 10/10:   0%|          | 0/336 [00:00<?, ?it/s]

Subject subject_9 | Epoch 10/10, Train RMSE: 8.391430, Val RMSE: 8.219952, Test RMSE: 11.615757, Time: 136.03s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_9 | Test RMSE: 11.615757, Average PCC: 0.878473, Inference Time per Sample: 6.474495 ms

Starting training for test subject: subject_10
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_8_9_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_10. Skipping resharding.


Subject subject_10:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 1/10, Train RMSE: 19.393413, Val RMSE: 12.290198, Test RMSE: 18.526777, Time: 134.91s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 2/10, Train RMSE: 12.370706, Val RMSE: 11.294031, Test RMSE: 20.027293, Time: 135.03s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 3/10, Train RMSE: 10.936946, Val RMSE: 10.468950, Test RMSE: 22.540350, Time: 133.99s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 4/10, Train RMSE: 10.248633, Val RMSE: 9.966397, Test RMSE: 21.672740, Time: 134.67s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 5/10, Train RMSE: 9.970255, Val RMSE: 9.817432, Test RMSE: 18.003593, Time: 134.35s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 6/10, Train RMSE: 9.400988, Val RMSE: 9.138194, Test RMSE: 20.718144, Time: 133.88s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 7/10, Train RMSE: 9.102349, Val RMSE: 9.124029, Test RMSE: 19.208595, Time: 135.44s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 8/10, Train RMSE: 8.944503, Val RMSE: 8.814508, Test RMSE: 21.053342, Time: 133.21s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 9/10, Train RMSE: 8.643077, Val RMSE: 8.821949, Test RMSE: 21.053481, Time: 135.27s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_10 | Epoch 10/10, Train RMSE: 8.547196, Val RMSE: 8.568270, Test RMSE: 20.076435, Time: 132.37s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_10 | Test RMSE: 20.076435, Average PCC: 0.822939, Inference Time per Sample: 5.614042 ms

Starting training for test subject: subject_11
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_8_9_10_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_11. Skipping resharding.


Subject subject_11:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 1/10, Train RMSE: 19.282651, Val RMSE: 13.054902, Test RMSE: 14.770755, Time: 134.92s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 2/10, Train RMSE: 12.133132, Val RMSE: 10.886893, Test RMSE: 16.200890, Time: 133.23s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 3/10, Train RMSE: 10.836621, Val RMSE: 9.990338, Test RMSE: 17.052743, Time: 134.95s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 4/10, Train RMSE: 10.155690, Val RMSE: 9.738220, Test RMSE: 16.213529, Time: 133.56s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 5/10, Train RMSE: 9.596100, Val RMSE: 9.035604, Test RMSE: 15.205196, Time: 134.94s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 6/10, Train RMSE: 9.373622, Val RMSE: 9.035747, Test RMSE: 16.175097, Time: 133.77s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 7/10, Train RMSE: 8.968953, Val RMSE: 9.062647, Test RMSE: 15.967580, Time: 134.33s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 8/10, Train RMSE: 8.671138, Val RMSE: 8.290280, Test RMSE: 15.230612, Time: 133.14s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 9/10, Train RMSE: 8.513475, Val RMSE: 8.291398, Test RMSE: 16.147234, Time: 135.31s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_11 | Epoch 10/10, Train RMSE: 8.497495, Val RMSE: 8.333545, Test RMSE: 15.088476, Time: 133.19s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_11 | Test RMSE: 15.230612, Average PCC: 0.907642, Inference Time per Sample: 5.738497 ms

Starting training for test subject: subject_12
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_8_9_10_11_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_12. Skipping resharding.


Subject subject_12:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 1/10, Train RMSE: 19.334025, Val RMSE: 12.281630, Test RMSE: 14.838007, Time: 133.96s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 2/10, Train RMSE: 12.386503, Val RMSE: 10.447440, Test RMSE: 14.436039, Time: 135.18s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 3/10, Train RMSE: 11.001598, Val RMSE: 10.108748, Test RMSE: 16.001272, Time: 134.10s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 4/10, Train RMSE: 10.226398, Val RMSE: 8.902708, Test RMSE: 15.373815, Time: 135.64s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 5/10, Train RMSE: 9.721499, Val RMSE: 9.213695, Test RMSE: 16.089692, Time: 134.63s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 6/10, Train RMSE: 9.376140, Val RMSE: 8.249103, Test RMSE: 14.087020, Time: 134.82s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 7/10, Train RMSE: 9.113479, Val RMSE: 8.516533, Test RMSE: 15.353817, Time: 133.80s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 8/10, Train RMSE: 8.976480, Val RMSE: 9.068431, Test RMSE: 15.262710, Time: 136.31s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 9/10, Train RMSE: 8.792676, Val RMSE: 8.388378, Test RMSE: 15.401522, Time: 134.50s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_12 | Epoch 10/10, Train RMSE: 8.538753, Val RMSE: 8.543263, Test RMSE: 15.318564, Time: 134.75s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_12 | Test RMSE: 14.087020, Average PCC: 0.887538, Inference Time per Sample: 6.175518 ms

Starting training for test subject: subject_13
Sharded data found at /content/datasets/dataset_wl25_fh5_train_1_2_3_4_5_6_7_8_9_10_11_12. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl25_fh5_test_13. Skipping resharding.


Subject subject_13:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 1/10, Train RMSE: 19.210120, Val RMSE: 13.001581, Test RMSE: 19.826961, Time: 135.56s


Epoch 2/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 2/10, Train RMSE: 12.323207, Val RMSE: 11.187521, Test RMSE: 19.702522, Time: 132.97s


Epoch 3/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 3/10, Train RMSE: 10.967000, Val RMSE: 9.834995, Test RMSE: 19.986776, Time: 133.34s


Epoch 4/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 4/10, Train RMSE: 10.159137, Val RMSE: 9.294633, Test RMSE: 18.678962, Time: 134.62s


Epoch 5/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 5/10, Train RMSE: 9.645126, Val RMSE: 9.089143, Test RMSE: 18.955178, Time: 142.61s


Epoch 6/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 6/10, Train RMSE: 9.509816, Val RMSE: 8.728450, Test RMSE: 18.511132, Time: 142.49s


Epoch 7/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 7/10, Train RMSE: 8.969726, Val RMSE: 8.573413, Test RMSE: 18.094718, Time: 142.59s


Epoch 8/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 8/10, Train RMSE: 8.873005, Val RMSE: 8.334978, Test RMSE: 18.258981, Time: 144.02s


Epoch 9/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 9/10, Train RMSE: 8.581857, Val RMSE: 8.178218, Test RMSE: 19.136530, Time: 141.99s


Epoch 10/10:   0%|          | 0/337 [00:00<?, ?it/s]

Subject subject_13 | Epoch 10/10, Train RMSE: 8.473141, Val RMSE: 7.921524, Test RMSE: 18.605870, Time: 140.69s


  model.load_state_dict(torch.load(f'best_teacher_model_subject_{test_subject}.pth'))


Subject subject_13 | Test RMSE: 18.605870, Average PCC: 0.869990, Inference Time per Sample: 10.337830 ms

Final Results:
Average RMSE over all subjects: 17.009793
Average PCC over all subjects: 0.855787
Average Inference Time per Sample: 6.683918 ms

Per-Subject Results:
Subject subject_1 | RMSE: 21.054964, PCC: 0.715173, Inference Time per Sample: 6.265879 ms
Subject subject_2 | RMSE: 21.485448, PCC: 0.774237, Inference Time per Sample: 5.954742 ms
Subject subject_3 | RMSE: 18.575491, PCC: 0.908992, Inference Time per Sample: 6.140232 ms
Subject subject_4 | RMSE: 17.078571, PCC: 0.859510, Inference Time per Sample: 6.427765 ms
Subject subject_5 | RMSE: 20.375988, PCC: 0.809165, Inference Time per Sample: 7.849216 ms
Subject subject_6 | RMSE: 14.255684, PCC: 0.893101, Inference Time per Sample: 6.012678 ms
Subject subject_7 | RMSE: 15.042783, PCC: 0.891507, Inference Time per Sample: 6.182909 ms
Subject subject_8 | RMSE: 13.642688, PCC: 0.906970, Inference Time per Sample: 7.717133 ms

In [48]:
import os
import zipfile
from datetime import datetime

# Bundle every checkpoint (*.pth) found at the top level of the working
# directory into a single timestamped zip archive and offer it for download
# through the Colab browser session.
notebook_name = 'teacher_forecast_benchmark'

# Timestamped names keep repeated runs from overwriting each other's artifacts
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"{notebook_name}_checkpoints_{timestamp}"

# Kept for backward compatibility: a folder carrying the artifact name is created.
# NOTE(review): the archive is written NEXT TO this folder (as
# "<folder_name>.zip"), not inside it, so the folder stays empty -- confirm
# whether it is still needed.
os.makedirs(folder_name, exist_ok=True)

checkpoint_dir = '.'
zip_filename = f"{folder_name}.zip"

# Regular files only (a directory that happens to end in ".pth" is skipped),
# no recursion into subfolders. os.listdir returns entries in arbitrary order,
# so sort to make the archive layout deterministic.
checkpoint_files = sorted(
    file for file in os.listdir(checkpoint_dir)
    if file.endswith('.pth') and os.path.isfile(os.path.join(checkpoint_dir, file))
)

if not checkpoint_files:
    # Without this guard an empty archive would be written, reported as a
    # success and downloaded.
    print(f"No .pth checkpoints found in {os.path.abspath(checkpoint_dir)}; nothing to zip.")
else:
    # zipfile defaults to ZIP_STORED (no compression); request DEFLATE explicitly
    with zipfile.ZipFile(zip_filename, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for file in checkpoint_files:
            file_path = os.path.join(checkpoint_dir, file)
            # Store under the bare file name so the archive has no directory prefix
            zipf.write(file_path, arcname=file)

    print(f"All checkpoints have been zipped and saved as {zip_filename}.")

    # Download the zip file to the local machine (only available inside Colab)
    from google.colab import files
    files.download(zip_filename)

All checkpoints have been zipped and saved as teacher_forecast_benchmark_checkpoints_20241018_132737.zip.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>