In [None]:

# Mount Google Drive. config.data_folder_name (next cell) points under
# /content/MyDrive/MyDrive/..., so that path depends on this mount point.
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
# Global seaborn theme; "paper" is passed as the first positional argument (the plotting context).
sns.set_theme("paper")



Mounted at /content/MyDrive


In [None]:
# @title Initialize Config

import torch
import numpy
class Config:
    """Plain settings bag for channel names, paths, transforms and the random seed.

    Every setting is an optional keyword argument; omitted ones fall back to the
    defaults below (list-valued defaults are created fresh for each instance, so
    instances never share them). Unrecognised keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        # (attribute name, zero-argument factory producing its default value)
        defaults = (
            ('channels_imu_acc', list),
            ('channels_imu_gyr', list),
            ('channels_joints', list),
            ('channels_emg', list),
            ('seed', lambda: 42),
            ('data_folder_name', lambda: 'default_data_folder_name'),
            ('dataset_root', lambda: 'default_dataset_root'),
            ('imu_transforms', list),
            ('joint_transforms', list),
            ('emg_transforms', list),
            ('input_format', lambda: 'csv'),
        )
        for attr_name, make_default in defaults:
            value = kwargs[attr_name] if attr_name in kwargs else make_default()
            setattr(self, attr_name, value)


import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    # Path to the single HDF5 file holding all subjects (despite the "folder" name).
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data.h5',
    # Root under which DataSharder writes (and the datasets read) the windowed CSV shards.
    dataset_root='/content/datasets',
    input_format="csv",
    # Six IMUs x (X, Y, Z), ordered per IMU: ACCX1, ACCY1, ACCZ1, ACCX2, ..., ACCZ6.
    channels_imu_acc=[f'ACC{axis}{imu}' for imu in range(1, 7) for axis in 'XYZ'],
    # Same layout for the gyroscopes: GYROX1, GYROY1, GYROZ1, ..., GYROZ6.
    channels_imu_gyr=[f'GYRO{axis}{imu}' for imu in range(1, 7) for axis in 'XYZ'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

# Seed every RNG the notebook may draw from. Python's `random` is imported by a
# later cell and previously was never seeded.
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)
random.seed(config.seed)


In [None]:
class DataSharder:
    """Cut per-subject HDF5 recordings into fixed-length windows saved as CSV shards.

    The HDF5 file is traversed as ``file[subject][session][speed]``. Each leaf is
    expected to be an ``h5py.Dataset`` that is either a compound (named-field)
    dataset or a 2-D array whose column names live in the ``'column_names'``
    attribute. For each leaf, the IMU (acc + gyr), EMG and joint channels listed in
    ``config`` are extracted and NaNs are linearly interpolated. The aligned
    signals are then windowed. Every window is written as one CSV under
    ``<dataset_root>/<sanitised dataset_name>/<split>/`` and logged in
    ``<dataset_root>/<sanitised dataset_name>/<split>_info.csv`` (columns
    ``file_name``, ``file_path``).

    NOTE(review): ``h5py``, ``os``, ``csv``, ``tqdm`` and ``pd`` are not imported in
    any cell visible here; confirm a cell imports them before this class is used.
    """

    # Recordings longer than SKIP_IF_LONGER_THAN samples have their first
    # SKIP_LEADING_SAMPLES samples dropped before windowing. These values are the
    # previously hard-coded ones; the reason for the skip is not documented here.
    SKIP_LEADING_SAMPLES = 2000
    SKIP_IF_LONGER_THAN = 4000

    def __init__(self, config, split):
        self.config = config
        # Despite the config attribute's name, this is the path to a single HDF5 file.
        self.h5_file_path = config.data_folder_name
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        """Shard ``subjects`` into windows and write them under ``dataset_name``.

        Args:
            subjects: Top-level HDF5 keys to process (missing keys are reported and skipped).
            window_length: Samples per window.
            window_overlap: Samples shared by consecutive windows; must be smaller
                than ``window_length``.
            dataset_name: Sub-folder of ``config.dataset_root`` to write into.

        Raises:
            ValueError: if ``window_overlap >= window_length``. Validated up front so
                no (empty) output folder is created that later runs would skip.
        """
        if window_overlap >= window_length:
            raise ValueError(
                f"window_overlap ({window_overlap}) must be smaller than "
                f"window_length ({window_length}); the window step would be <= 0."
            )
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        """Walk subject -> session -> speed in the HDF5 file and shard every leaf.

        Does nothing if the target split folder already exists.
        """
        # Sanitise only the dataset name. Applying the replacements to the whole
        # joined path would also rewrite a dataset_root containing "subject" or "__",
        # and the shards would land somewhere other than where the readers look.
        clean_name = dataset_name.replace("subject", "").replace("__", "_")
        dataset_folder = os.path.join(self.config.dataset_root, clean_name, self.split)
        print("Dataset folder:", dataset_folder)

        if os.path.exists(dataset_folder):
            print("Dataset Exists, Skipping...")
            return

        imu_channels = self.config.channels_imu_acc + self.config.channels_imu_gyr

        # The HDF5 file is opened before the output folder is created, so an
        # unreadable source does not leave an empty folder behind.
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_key in tqdm(subjects, desc="Processing subjects"):
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                for session_id, session_group in h5_file[subject_key].items():
                    for session_speed, session_data in session_group.items():
                        imu_data, imu_columns = self._extract_channel_data(session_data, imu_channels)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        self._save_windowed_data(
                            imu_data, emg_data, joint_data, subject_key, session_id, session_speed,
                            dataset_folder, imu_columns, emg_columns, joint_columns,
                        )

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        """Window aligned signals, write one CSV per window and log it in ``<split>_info.csv``.

        Args:
            imu_data, emg_data, joint_data: 2-D arrays indexed ``[channel, sample]``;
                only the samples common to all three (the shortest length) are used.
            subject_key, session_id, session_speed: Used to build each shard's file name.
            dataset_folder: Folder receiving the shard CSVs; the info CSV goes in its parent.
            imu_columns, emg_columns, joint_columns: Column names for the rows of each array.
        """
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        os.makedirs(dataset_folder, exist_ok=True)
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])
        start = self.SKIP_LEADING_SAMPLES if total_data_length > self.SKIP_IF_LONGER_THAN else 0

        write_header = not os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)
            if write_header:
                writer.writerow(['file_name', 'file_path'])

            # The range's upper bound guarantees each slice is exactly window_size
            # samples long in all three arrays, so no per-window size check is needed.
            for i in range(start, total_data_length - window_size + 1, step_size):
                stop = i + window_size
                combined_df = pd.concat(
                    [
                        pd.DataFrame(imu_data[:, i:stop].T, columns=imu_columns),
                        pd.DataFrame(emg_data[:, i:stop].T, columns=emg_columns),
                        pd.DataFrame(joint_data[:, i:stop].T, columns=joint_columns),
                    ],
                    axis=1,
                )

                file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                file_path = os.path.join(dataset_folder, file_name)
                combined_df.to_csv(file_path, index=False)
                writer.writerow([file_name, file_path])

    @staticmethod
    def _interpolate_nans(values):
        """Coerce a 1-D sequence to numeric (bad entries -> NaN) and fill NaNs linearly in both directions."""
        numeric = pd.to_numeric(values, errors='coerce')
        filled = pd.DataFrame(numeric).interpolate(method='linear', axis=0, limit_direction='both')
        return filled.to_numpy().flatten()

    def _extract_channel_data(self, session_data, channels):
        """Return ``(array, names)`` for the requested ``channels`` present in ``session_data``.

        ``array`` stacks one NaN-interpolated row per channel found, and ``names``
        lists exactly those channels in the same order. Missing channels are
        reported and skipped. If ``session_data`` is not an ``h5py.Dataset``, or no
        channel is found, ``array`` is an empty 1-D array and ``names`` is empty.

        Raises:
            AssertionError: for a non-compound dataset without a ``column_names`` attribute.
        """
        extracted_data = []
        # Bound on every path; it used to exist only on the plain-dataset path, so
        # compound datasets always raised UnboundLocalError.
        found_columns = []

        if isinstance(session_data, h5py.Dataset):
            if session_data.dtype.names:
                # Compound dataset: channels are named fields.
                available = session_data.dtype.names
                for channel in channels:
                    if channel in available:
                        extracted_data.append(self._interpolate_nans(session_data[channel][:]))
                        found_columns.append(channel)
                    else:
                        print(f"Channel {channel} not found in compound dataset.")
            else:
                # Plain 2-D dataset: column order comes from the 'column_names' attribute.
                # NOTE(review): if that attribute is stored as bytes, str channel names
                # will never match; confirm the stored dtype.
                column_names = list(session_data.attrs.get('column_names', []))
                assert len(column_names) > 0, "column_names not found in dataset attributes"
                for channel in channels:
                    if channel in column_names:
                        col_idx = column_names.index(channel)
                        extracted_data.append(self._interpolate_nans(session_data[:, col_idx]))
                        found_columns.append(channel)
                    else:
                        print(f"Channel {channel} not found in session data.")

        return np.array(extracted_data), found_columns

In [None]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    """Map-style dataset over windowed CSV shards of IMU, joint-angle and EMG channels.

    On construction it makes sure shards for ``subjects`` exist under
    ``config.dataset_root`` (running ``DataSharder`` if the dataset folder is
    missing) and loads the ``<split>_info.csv`` index. Each item is the tuple
    ``(imu_acc, imu_gyr, joints, emg)`` of float32 tensors, one row per CSV row of
    the shard and one column per configured channel.
    """

    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        # Only training windows overlap; every other split uses back-to-back windows.
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Path-safe subject tag, e.g. ['subject_1', 'subject_2'] -> '_1_2'.
        subjects_tag = "_".join(str(s) for s in subjects).replace('subject', '').replace('__', '_')
        split_tag = 'train' if split == 'train' else 'test'
        self.dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_{split_tag}{subjects_tag}"
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        self.data = pd.read_csv(os.path.join(self.root_dir, f"{split}_info.csv"))

    def ensure_resharded(self, subjects, dataset_name):
        """Shard the data with DataSharder unless ``self.root_dir`` already exists.

        NOTE(review): the ``dataset_name`` argument is unused; the sharder always
        receives ``self.dataset_name``.
        """
        if os.path.exists(self.root_dir):
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")
            return
        print(f"Sharded data not found at {self.root_dir}. Resharding...")
        sharder = DataSharder(self.config, self.split)
        sharder.load_data(
            subjects,
            window_length=self.window_length,
            window_overlap=self.window_overlap,
            dataset_name=self.dataset_name,
        )

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        shard_path = os.path.join(self.root_dir, self.split, self.data.iloc[idx, 0])
        if self.input_format != "csv":
            raise ValueError("Unsupported input format: {}".format(self.input_format))
        return self._extract_and_transform(pd.read_csv(shard_path))

    def _extract_and_transform(self, combined_data):
        """Select each modality's columns, then apply that modality's configured transforms."""
        acc, gyr, joints, emg = (
            self._extract_channels(combined_data, channels)
            for channels in (self.channels_imu_acc, self.channels_imu_gyr, self.channels_joints, self.channels_emg)
        )
        return (
            self.apply_transforms(acc, self.config.imu_transforms),
            self.apply_transforms(gyr, self.config.imu_transforms),
            self.apply_transforms(joints, self.config.joint_transforms),
            self.apply_transforms(emg, self.config.emg_transforms),
        )

    def _extract_channels(self, combined_data, channels):
        """Return the values of ``channels`` (DataFrame columns for CSV input)."""
        if self.input_format == "csv":
            return combined_data[channels].values
        return combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        """Apply ``transforms`` in order and convert the result to a float32 tensor."""
        result = data
        for fn in transforms:
            result = fn(result)
        return torch.tensor(result, dtype=torch.float32)

class ImuJointPairSubjectDataset(ImuJointPairDataset):
    """ImuJointPairDataset whose items also carry an integer subject label.

    Labels are positions in ``sorted(subjects)`` (lexicographic order). The subject
    is parsed from the shard file name as the token that follows ``'subject'``,
    e.g. ``subject_3_...csv`` -> ``'subject_3'``.
    """

    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # subject string (e.g. 'subject_1') -> class index
        self.subject_mapping = dict(zip(sorted(subjects), range(len(subjects))))

    def __getitem__(self, idx):
        """Return ``(imu_acc, imu_gyr, joints, emg, class_index)``.

        Raises:
            ValueError: if the file name has no ``subject_<id>`` token or the
                subject is not in ``subject_mapping``.
        """
        imu_data_acc, imu_data_gyr, joint_data, emg_data = super().__getitem__(idx)

        filename = self.data.iloc[idx, 0]
        stem = os.path.splitext(os.path.basename(filename))[0]
        tokens = stem.split('_')

        try:
            subject_str = f"subject_{tokens[tokens.index('subject') + 1]}"
        except ValueError:
            raise ValueError(f"'subject' not found in filename: {filename}")
        except IndexError:
            raise ValueError(f"Subject ID not found after 'subject' in filename: {filename}")

        if subject_str not in self.subject_mapping:
            raise ValueError(f"Subject ID {subject_str} not found in training set.")

        return imu_data_acc, imu_data_gyr, joint_data, emg_data, self.subject_mapping[subject_str]

class ImuJointPairSubjectNormalizedDataset(ImuJointPairSubjectDataset):
    """Subject-labelled dataset with per-subject z-scoring of IMU and EMG windows.

    At construction every window is loaded once. For each subject, the mean and
    standard deviation of its acc, gyr and EMG tensors are taken across that
    subject's windows (``dim=0``). Each item is normalised with its own subject's
    statistics; joint angles are returned unchanged.

    NOTE(review): ``dim=0`` over stacked ``(window, sample, channel)`` tensors gives
    statistics per (sample position, channel), not per channel. Confirm that
    ``dim=(0, 1)`` was not intended.
    """

    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Per-subject lists of window tensors. Below, each subject that has at least
        # one window has its lists replaced by mean/std entries.
        self.normalization_stats = {subject: {'imu_acc': [], 'imu_gyr': [], 'emg': []} for subject in subjects}

        # One pass over all shards to gather each subject's windows.
        for idx in range(len(self.data)):
            filename = self.data.iloc[idx]['file_name']
            subject_str = self._subject_for_filename(filename, subjects)
            if subject_str is None:
                continue
            imu_data_acc, imu_data_gyr, _joint_data, emg_data, _mapped_class = super().__getitem__(idx)
            bucket = self.normalization_stats[subject_str]
            bucket['imu_acc'].append(imu_data_acc)
            bucket['imu_gyr'].append(imu_data_gyr)
            bucket['emg'].append(emg_data)

        for subject, data in self.normalization_stats.items():
            if data['imu_acc']:
                imu_acc_mean, imu_acc_std = self._mean_std(data['imu_acc'])
                imu_gyr_mean, imu_gyr_std = self._mean_std(data['imu_gyr'])
                emg_mean, emg_std = self._mean_std(data['emg'])
                self.normalization_stats[subject] = {
                    'imu_acc_mean': imu_acc_mean,
                    'imu_acc_std': imu_acc_std,
                    'imu_gyr_mean': imu_gyr_mean,
                    'imu_gyr_std': imu_gyr_std,
                    'emg_mean': emg_mean,
                    'emg_std': emg_std,
                }

    @staticmethod
    def _subject_for_filename(filename, subjects):
        """Return the subject whose ``"<subject>_"`` prefix starts ``filename``, or ``None``.

        DataSharder names shards ``f"{subject_key}_{session_id}_..."``, so prefix
        matching is exact. The previous substring test let ``'subject_1'`` claim
        ``'subject_10_...'`` files. The longest matching subject wins.
        """
        base = os.path.basename(filename)
        matches = [s for s in subjects if base.startswith(f"{s}_")]
        return max(matches, key=len) if matches else None

    @staticmethod
    def _mean_std(windows):
        """Mean and standard deviation across a list of equally shaped window tensors (``dim=0``)."""
        stacked = torch.stack(windows)
        mean = stacked.mean(dim=0)
        # The unbiased std of a single window is NaN, which would turn every item of
        # that subject into NaN. Zeros make the normalisation give 0 instead.
        std = stacked.std(dim=0) if stacked.shape[0] > 1 else torch.zeros_like(mean)
        return mean, std

    def __getitem__(self, idx):
        """Return ``(imu_acc, imu_gyr, joints, emg, class_index)`` with IMU/EMG z-scored per subject.

        Raises:
            ValueError: if no subject with statistics matches the shard's file name.
        """
        imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class = super().__getitem__(idx)

        filename = self.data.iloc[idx]['file_name']
        subject_str = self._subject_for_filename(filename, self.normalization_stats)
        if subject_str is None:
            raise ValueError(f"Normalization stats not found for subject in filename: {filename}")

        stats = self.normalization_stats[subject_str]
        # The 1e-8 guards against division by zero for constant channels.
        imu_data_acc = (imu_data_acc - stats['imu_acc_mean']) / (stats['imu_acc_std'] + 1e-8)
        imu_data_gyr = (imu_data_gyr - stats['imu_gyr_mean']) / (stats['imu_gyr_std'] + 1e-8)
        emg_data = (emg_data - stats['emg_mean']) / (stats['emg_std'] + 1e-8)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class
def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test',
    train_fraction=0.9,
    generator=None,
):
    """Build train/val/test DataLoaders of subject-normalised windows.

    The training subjects' windows are randomly split into train and validation
    subsets. Test subjects are windowed without overlap.

    Args:
        config: Notebook ``Config`` instance.
        train_subjects, test_subjects: HDF5 subject keys for each split.
        window_length: Samples per window.
        window_overlap: Overlap between training windows (the test set uses 0).
        batch_size: Batch size of all three loaders.
        dataset_train_name, dataset_test_name: Forwarded to the dataset classes.
        train_fraction: Share of training windows (rounded down) kept for training;
            the rest form the validation set. Defaults to the previous fixed 0.9.
        generator: Optional ``torch.Generator`` for ``random_split``. When omitted,
            the split uses torch's default generator, as before.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only ``train_loader`` shuffles.

    Raises:
        ValueError: if ``train_fraction`` is outside ``[0, 1]``.

    NOTE(review): with ``window_overlap > 0`` the validation windows overlap
    training windows from the same recordings, so validation scores are likely
    optimistic compared with the held-out test subjects.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    train_dataset = ImuJointPairSubjectNormalizedDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairSubjectNormalizedDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=0,  # test windows never overlap
        split='test',
        dataset_test_name=dataset_test_name
    )

    train_size = int(train_fraction * len(train_dataset))
    val_size = len(train_dataset) - train_size
    # random_split's default generator is not None, so pass the kwarg only when given.
    split_kwargs = {} if generator is None else {'generator': generator}
    train_subset, val_subset = random_split(train_dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader




In [None]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    """Two stacked bidirectional LSTMs (128 then 64 units per direction) with dropout.

    forward(x) takes batch-first input ``(batch, time, input_dim)`` and returns
    ``(out, (h_1, h_2))``. ``out`` is ``(batch, time, 128)``, and ``h_1`` / ``h_2``
    are the final hidden states of the first and second LSTM.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        # Registration order is unchanged, so parameter initialisation consumes the
        # RNG identically and state_dict keys stay the same.
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.lstm_2 = nn.LSTM(2 * 128, 64, bidirectional=True, batch_first=True, dropout=0)
        # Not used in forward; kept so saved checkpoints still load.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2 * 64, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        seq, (h_first, _cell) = self.lstm_1(x)
        seq, (h_second, _cell) = self.lstm_2(self.dropout_1(seq))
        return self.dropout_2(seq), (h_first, h_second)

class Encoder_2(nn.Module):
    """Two stacked bidirectional GRUs (128 then 64 units per direction) with dropout.

    forward(x) maps batch-first ``(batch, time, input_dim)`` input to
    ``(out, (h_1, h_2))``. ``out`` is ``(batch, time, 128)``, and ``h_1`` / ``h_2``
    are the final hidden states of the first and second GRU.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        # Same registration order as before (RNG consumption / state_dict keys).
        self.gru_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.gru_2 = nn.GRU(2 * 128, 64, bidirectional=True, batch_first=True, dropout=0)
        # Not used in forward; kept so saved checkpoints still load.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2 * 64, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        seq, h_first = self.gru_1(x)
        seq, h_second = self.gru_2(self.dropout_1(seq))
        return self.dropout_2(seq), (h_first, h_second)


class GatingModule(nn.Module):
    """Learned element-wise convex blend of two same-shaped inputs.

    A sigmoid gate ``g`` is predicted from ``[input1, input2]`` concatenated on the
    last dimension. The output is ``input1 * g + input2 * (1 - g)``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(2 * input_size, input_size), nn.Sigmoid())

    def forward(self, input1, input2):
        g = self.gate(torch.cat((input1, input2), dim=-1))
        return input1 * g + input2 * (1 - g)
# NOTE(review): teacher's `w` argument (window length, default 100) is a stand-in value; confirm it matches the window_length used to build the datasets.


from torch.autograd import Function

class GradientReversalFunction(Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``-lambda_grl`` in backward."""

    @staticmethod
    def forward(ctx, x, lambda_grl):
        ctx.lambda_grl = lambda_grl
        # view_as returns a new tensor object, so autograd records this op.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # lambda_grl is treated as a constant: it receives no gradient.
        return grad_output.neg() * ctx.lambda_grl, None


def GradientReversalLayer(lambda_grl):
    """Return a callable ``grl(x, lambda_override=None)`` applying gradient reversal.

    The coefficient ``lambda_grl`` is bound here and used by default. An explicit
    second argument (as in ``grl(x, lam)``) overrides it, so callers that pass the
    coefficient again keep working. The previous version returned
    ``GradientReversalFunction.apply`` and silently ignored ``lambda_grl``.
    """
    def apply_grl(x, lambda_override=None):
        coeff = lambda_grl if lambda_override is None else lambda_override
        return GradientReversalFunction.apply(x, coeff)

    return apply_grl

class teacher(nn.Module):
    """Multimodal (acc / gyr / EMG) joint-angle regressor with an adversarial subject head.

    Pipeline:
    - Each modality is batch-normalised per channel.
    - Each modality is encoded by a BiLSTM stack (``Encoder_1``) and a BiGRU stack
      (``Encoder_2``).
    - The two 128-d sequences are fused by a per-modality ``GatingModule``.
    - The three fused sequences are concatenated (384-d) and fed to three parallel
      branches: self-attention, sigmoid gating, and a weighted sum of the
      per-modality 128-d chunks.
    - The branch outputs (384 + 384 + 128 = 896-d) are gated once more before the
      two heads.

    forward returns ``(out_task, out_classifier, hidden_states)``:
      * ``out_task``: ``(batch, time, 3)`` joint-angle predictions.
      * ``out_classifier``: ``(batch, num_subjects)`` subject logits, computed after a
        gradient-reversal layer on the time-averaged features.
      * ``hidden_states``: ``(h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)``.
        The ``*_1`` states are the final hidden states of Encoder_1's first LSTM;
        the ``*_2`` states are those of Encoder_2's first GRU.

    Args:
        input_acc, input_gyr, input_emg: Channel counts of each modality.
        drop_prob: Dropout used inside the encoders.
        w: Window length. Kept for backward compatibility only. forward() now takes
            the time length from its inputs, so a mismatching ``w`` no longer
            silently regroups samples into wrong windows.
        num_subjects: Number of subject classes for the adversarial head.
        lambda_grl: Gradient-reversal coefficient.
    """

    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100,num_subjects=12,lambda_grl=1.0):
        super(teacher, self).__init__()

        # NOTE: module creation order is unchanged so parameter initialisation and
        # state_dict keys match existing checkpoints.
        self.lambda_grl = lambda_grl
        self.w=w
        self.encoder_1_acc=Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr=Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg=Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc=Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr=Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg=Encoder_2(input_emg, drop_prob)

        # Non-affine per-channel input normalisation.
        self.BN_acc= nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr= nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg= nn.BatchNorm1d(input_emg, affine=False)

        # 2*3*128 (attention + gating branches) + 128 (weighted-sum branch) = 896.
        self.fc = nn.Linear(2*3*128+128,3)
        self.fc_classifier = nn.Linear(2*3*128+128,num_subjects)
        self.dropout=nn.Dropout(p=.05)  # unused in forward; kept for checkpoint compatibility

        self.gate_1=GatingModule(128)
        self.gate_2=GatingModule(128)
        self.gate_3=GatingModule(128)

        self.fc_kd = nn.Linear(3*128, 2*128)  # unused in forward; kept for checkpoint compatibility

        # Scalar weight per time step, shared by all three modality chunks.
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention=nn.MultiheadAttention(3*128,4,batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128*3, 3*128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2*3*128+128, 2*3*128+128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)  # unused in forward; kept for checkpoint compatibility

    @staticmethod
    def _bn_over_time(x, bn):
        """Apply BatchNorm1d ``bn`` per channel across every (batch, time) sample of ``x``."""
        batch, steps, channels = x.size(0), x.size(1), x.size(-1)
        flat = x.reshape(batch * steps, channels)
        return bn(flat).reshape(batch, steps, channels)

    def forward(self, x_acc, x_gyr, x_emg):
        # Inputs are batch-first (batch, time, channels).
        acc_in = self._bn_over_time(x_acc, self.BN_acc)
        gyr_in = self._bn_over_time(x_gyr, self.BN_gyr)
        emg_in = self._bn_over_time(x_emg, self.BN_emg)

        # LSTM encoders (hidden state of their first LSTM is returned).
        acc_lstm, (h_acc_1, _) = self.encoder_1_acc(acc_in)
        gyr_lstm, (h_gyr_1, _) = self.encoder_1_gyr(gyr_in)
        emg_lstm, (h_emg_1, _) = self.encoder_1_emg(emg_in)

        # GRU encoders (hidden state of their first GRU is returned).
        acc_gru, (h_acc_2, _) = self.encoder_2_acc(acc_in)
        gyr_gru, (h_gyr_2, _) = self.encoder_2_gyr(gyr_in)
        emg_gru, (h_emg_2, _) = self.encoder_2_emg(emg_in)

        # Fuse the two encodings of each modality, then stack modalities: (batch, time, 384).
        x = torch.cat((
            self.gate_1(acc_lstm, acc_gru),
            self.gate_2(gyr_lstm, gyr_gru),
            self.gate_3(emg_lstm, emg_gru),
        ), dim=-1)

        # Branch 1: self-attention over time.
        out_1, _attn_weights = self.attention(x, x, x)

        # Branch 2: element-wise sigmoid gating.
        out_2 = self.gating_net(x) * x

        # Branch 3: weight each modality's 128-d chunk, then sum the chunks.
        chunk_acc = x[:, :, 0:128]
        chunk_gyr = x[:, :, 128:2*128]
        chunk_emg = x[:, :, 2*128:3*128]
        weight_acc = self.weighted_feat(chunk_acc)
        weight_gyr = self.weighted_feat(chunk_gyr)
        weight_emg = self.weighted_feat(chunk_emg)
        out_3 = weight_acc * chunk_acc + weight_gyr * chunk_gyr + weight_emg * chunk_emg

        out = torch.cat((out_1, out_2, out_3), dim=-1)
        out = self.gating_net_1(out) * out

        out_task = self.fc(out)

        # Adversarial subject head: reverse gradients, average over time, classify.
        grl = GradientReversalLayer(self.lambda_grl)
        out_rev = grl(out, self.lambda_grl)
        out_rev = out_rev.permute(0, 2, 1)  # (batch, features, time)
        out_rev = torch.nn.functional.adaptive_avg_pool1d(out_rev, 1).squeeze(-1)  # (batch, features)
        out_classifier = self.fc_classifier(out_rev)

        return out_task, out_classifier, (h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)




In [None]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    """Root-mean-square error taken over every element of ``output`` and ``target``."""

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        residual = output - target
        return residual.pow(2).mean().sqrt()

#prediction function
def RMSE_prediction(yhat_4, test_y, output_dim, print_losses=True):
    """Per-channel RMSE and Pearson correlation between predictions and targets.

    Args:
        yhat_4: Predictions whose elements reshape to ``(-1, output_dim)``, e.g.
            ``(batch, time, output_dim)``.
        test_y: Targets with the same number of elements as ``yhat_4``.
        output_dim: Number of output channels (columns).
        print_losses: If True, print each channel's RMSE, then each channel's
            correlation, then the mean +/- standard deviation of both across channels.

    Returns:
        ``(rmse, p, Z_1, ..., Z_<output_dim>)``. ``rmse`` and ``p`` are arrays of
        length ``output_dim``, and ``Z_k`` is the flattened prediction for channel
        ``k``. For the usual ``output_dim == 3`` this is the 5-tuple
        ``rmse, p, Z_1, Z_2, Z_3``. A constant channel yields a NaN correlation
        (from ``np.corrcoef``).
    """
    # Collapse every leading axis so each row is one sample and each column one channel.
    yhat = np.asarray(yhat_4).reshape(-1, output_dim)
    test_o = np.asarray(test_y).reshape(yhat.shape[0], output_dim)

    predictions = [yhat[:, k] for k in range(output_dim)]
    targets = [test_o[:, k] for k in range(output_dim)]

    rmse = np.array(
        [np.sqrt(np.mean((y_true - y_pred) ** 2)) for y_pred, y_true in zip(predictions, targets)],
        dtype=float,
    )
    p = np.array([np.corrcoef(y_pred, y_true)[0, 1] for y_pred, y_true in zip(predictions, targets)])

    if print_losses:
        for value in rmse:
            print(value)
        print("\n")
        for value in p:
            print(value)
        # statistics.stdev needs at least two data points.
        SD = statistics.stdev(rmse) if len(rmse) > 1 else 0.0
        SD_c = statistics.stdev(p) if len(p) > 1 else 0.0
        print('Mean: %.3f' % statistics.mean(rmse), '+/- %.3f' % SD)
        print('Mean: %.3f' % statistics.mean(p), '+/- %.3f' % SD_c)

    # A 6 Hz Butterworth low-pass helper used to be defined here but was never
    # applied, so the returned Z_k have always been the unfiltered predictions.
    return (rmse, p, *predictions)

def compute_biomechanical_loss(predicted_angles, min_limits=None, max_limits=None):
    """Mean joint-limit violation of ``predicted_angles`` (the L_bio penalty).

    Each element's penalty is ``relu(min - angle) + relu(angle - max)``, with the
    limits indexed by the last dimension of ``predicted_angles`` (standard
    broadcasting). The result is the mean over all elements.

    Args:
        predicted_angles: Tensor whose last dimension indexes joints.
        min_limits, max_limits: Per-joint bounds (sequence or tensor). They default
            to the original three-joint limits ``[0, -90, -120]`` and
            ``[150, 180, 90]``. The original comment calls these degrees; confirm the
            units of the joint data.

    Returns:
        Scalar tensor.
    """
    if min_limits is None:
        min_limits = [0, -90, -120]
    if max_limits is None:
        max_limits = [150, 180, 90]

    # 1-D limits broadcast against (..., num_joints) without explicit unsqueezing.
    lower = torch.as_tensor(min_limits, device=predicted_angles.device)
    upper = torch.as_tensor(max_limits, device=predicted_angles.device)

    lower_violation = torch.relu(lower - predicted_angles)
    upper_violation = torch.relu(predicted_angles - upper)

    return torch.mean(lower_violation + upper_violation)


class BoundRmseLoss(nn.Module):
    """RMSE plus a weighted joint-limit penalty: ``RMSE(output, target) + lambda_bio * L_bio(output)``."""

    def __init__(self, lambda_bio=0.1):
        super().__init__()
        self.lambda_bio = lambda_bio
        self.rmse_loss = RMSELoss()

    def forward(self, output, target):
        data_term = self.rmse_loss(output, target)
        # The penalty depends only on the predictions, not on the targets.
        limit_penalty = compute_biomechanical_loss(output)
        return data_term + self.lambda_bio * limit_penalty

In [None]:
# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion, num_channels=None):
    """Average loss, per-channel PCC and per-channel RMSE of ``model`` over ``loader``.

    Behaviour:
    - Puts ``model`` in eval mode (it is not switched back) and runs without gradients.
    - Batches are ``(acc, gyr, target, emg, subject_label)``; the label is ignored.
    - If the model returns a tuple, only its first element (the prediction) is
      scored. ``teacher`` returns ``(prediction, subject_logits, hidden_states)``.
    - Metrics are unweighted means of per-batch values, so a smaller final batch
      counts as much as a full one.

    Args:
        device: Device the inputs are moved to.
        model: Model called as ``model(acc, gyr, emg)``.
        loader: Iterable of batches with ``len()``.
        criterion: Loss called as ``criterion(prediction, target)``.
        num_channels: Number of predicted joint channels. Defaults to
            ``len(config.channels_joints)`` from the notebook's global ``config``.

    Returns:
        ``(avg_loss, avg_pcc, avg_rmse)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if num_channels is None:
        num_channels = len(config.channels_joints)
    if len(loader) == 0:
        raise ValueError("evaluate_model: loader is empty; cannot average metrics.")

    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(num_channels)
    total_rmse = np.zeros(num_channels)

    with torch.no_grad():
        for data_acc, data_gyr, target, data_EMG, _ in loader:
            output = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            if isinstance(output, tuple):
                output = output[0]

            target = target.to(device).float()
            loss = criterion(output, target)

            batch_rmse, batch_pcc, *_ = RMSE_prediction(
                output.cpu().numpy(), target.cpu().numpy(), num_channels, print_losses=False
            )
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    n_batches = len(loader)
    return total_loss / n_batches, total_pcc / n_batches, total_rmse / n_batches



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):
    """Serialize model/optimizer state and training metrics to ``filename`` with ``torch.save``.

    Args:
        model, optimizer: objects whose ``state_dict()`` is stored.
        epoch: 0-based epoch index stored as ``'epoch'`` (printed as ``epoch + 1``).
        filename: destination path.
        train_loss, val_loss: scalar losses stored as-is.
        test_loss: optional; stored as ``'test_loss'`` when not None.
        channelwise_metrics: optional dict with ``'train'`` and ``'val'`` entries (and
            ``'test'`` when ``test_loss`` is given). These are stored under
            ``'<split>_channelwise_metrics'``. When None, those keys are omitted; before,
            calling with the default raised TypeError.
        history: optional dict, stored when truthy.
        curriculum_schedule: optional schedule, stored when truthy.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }

    if channelwise_metrics is not None:
        checkpoint['train_channelwise_metrics'] = channelwise_metrics['train']
        checkpoint['val_channelwise_metrics'] = channelwise_metrics['val']

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        if channelwise_metrics is not None:
            checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule so a resumed run can continue the same subject sequence
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



# Per-epoch history lists that train_teacher accumulates, checkpoints and restores on
# resume. The first fifteen are the keys the checkpoint history has always contained.
_TEACHER_HISTORY_KEYS = (
    'train_losses', 'val_losses', 'test_losses',
    'train_pccs', 'val_pccs', 'test_pccs',
    'train_rmses', 'val_rmses', 'test_rmses',
    'train_pccs_channelwise', 'val_pccs_channelwise', 'test_pccs_channelwise',
    'train_rmses_channelwise', 'val_rmses_channelwise', 'test_rmses_channelwise',
    'train_classification_losses',
)


def _checkpoint_epoch_number(fname):
    """Return N for a checkpoint file named ``<prefix>_N.pth``; -1 if the suffix is not an integer."""
    suffix = fname[:-len('.pth')].rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def _resume_from_checkpoint(checkpoint_path, model, optimizer, device, curriculum_loader):
    """Restore model/optimizer state from the highest-numbered checkpoint in ``checkpoint_path``.

    Returns:
        ``None`` if the directory holds no ``*_N.pth`` checkpoint. Otherwise
        ``(start_epoch, saved_history)``:
        - ``start_epoch`` is the 0-based index of the first epoch that still has to run.
        - ``saved_history`` is the checkpoint's history dict, or an empty dict if it was
          saved without one.
    """
    candidates = [f for f in os.listdir(checkpoint_path)
                  if f.endswith('.pth') and _checkpoint_epoch_number(f) >= 0]
    if not candidates:
        return None

    latest_checkpoint = max(candidates, key=_checkpoint_epoch_number)
    print(f"Loading model from checkpoint: {latest_checkpoint}")
    # map_location: allows resuming a GPU-written checkpoint on a CPU-only runtime.
    # weights_only=False: these are checkpoints this function's own training loop wrote,
    # and they contain numpy arrays (channel-wise metrics). The weights-only unpickler,
    # the torch.load default in recent PyTorch releases, does not accept those.
    checkpoint = torch.load(
        os.path.join(checkpoint_path, latest_checkpoint),
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if 'curriculum_schedule' in checkpoint and curriculum_loader is not None:
        curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']

    # 'epoch' is the 0-based index of the last epoch that finished; resume at the next one
    # (re-running it would train it twice and duplicate its history entries).
    return checkpoint['epoch'] + 1, checkpoint.get('history', {})


def train_teacher(
    device,
    train_loader,
    val_loader,
    test_loader,
    learn_rate,
    epochs,
    model,
    filename,
    loss_function,
    optimizer=None,
    l1_lambda=None,
    train_from_last_epoch=False,
    curriculum_loader=None,
    num_classes=12,  # Accepted for interface compatibility; not used below
    alpha=1.0,       # Weighting factor for classification loss
):
    """Jointly train ``model`` for joint-angle regression and subject classification.

    Each training step minimises::

        loss_function(output, target)
            + alpha * CrossEntropy(out_classifier, subject_id - 1)
            (+ l1_lambda * sum(|p|) when l1_lambda is set)

    The model must return ``(output, out_classifier, _)``. Subject ids equal to 0 become
    -1 after the shift and are ignored by the classification loss (``ignore_index=-1``).

    After every epoch:
    - the model is evaluated on the validation and test loaders with ``loss_function``
      only, so the reported validation/test "total" loss is the regression loss;
    - a checkpoint is written to
      ``/content/MyDrive/MyDrive/models/<filename>/<filename>_epoch_<N>.pth``;
    - early stopping (patience 10) is checked against the validation loss, and the best
      state_dict is saved to the path ``filename``, relative to the working directory.

    Args:
        device: torch device for the model and batches.
        train_loader, val_loader, test_loader: iterables of
            ``(acc, gyr, target, emg, subject_id)`` batches. They are ignored when
            ``curriculum_loader`` is given.
        learn_rate: learning rate for the default Adam optimizer.
        epochs: total number of epochs, counting any resumed epochs.
        model: network to train.
        filename: run name, used for the checkpoint directory and the best-model file.
        loss_function: regression criterion; must not be None.
        optimizer: optional optimizer; defaults to Adam over ``model.parameters()``.
        l1_lambda: optional L1 penalty weight over all parameters.
        train_from_last_epoch: resume from the highest-numbered checkpoint. This restores
            the model/optimizer state, metric history and early-stopping state, and
            continues with the epoch after the saved one.
        curriculum_loader: optional object whose ``update_epoch(epoch)`` and
            ``get_loaders()`` supply the loaders each epoch, and whose
            ``curriculum_schedule`` is checkpointed.
        num_classes: not used by this function.
        alpha: weight of the subject-classification loss.

    Returns:
        ``(model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs,
        train_rmses, val_rmses, test_rmses)``. These are per-epoch lists, including any
        resumed epochs.
    """
    model.to(device)
    criterion_regression = loss_function
    criterion_classification = nn.CrossEntropyLoss(ignore_index=-1)

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    history = {key: [] for key in _TEACHER_HISTORY_KEYS}
    start_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        resumed = _resume_from_checkpoint(checkpoint_path, model, optimizer, device, curriculum_loader)
        if resumed is None:
            print("No checkpoints found, starting from scratch.")
        else:
            start_epoch, saved_history = resumed
            for key in _TEACHER_HISTORY_KEYS:
                history[key] = list(saved_history.get(key, []))
    else:
        print("Starting from scratch.")

    patience = 10
    best_val_loss = float('inf')
    patience_counter = 0
    # Replay early-stopping bookkeeping over resumed epochs. Without this, a resumed run
    # would always overwrite the best-model file after its first epoch and lose its
    # accumulated patience.
    for past_val_loss in history['val_losses']:
        if past_val_loss < best_val_loss:
            best_val_loss = past_val_loss
            patience_counter = 0
        else:
            patience_counter += 1

    os.makedirs(checkpoint_path, exist_ok=True)
    num_joints = len(config.channels_joints)
    start_time = time.time()

    for epoch in range(start_epoch, epochs):
        model.train()

        if curriculum_loader is not None:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()

        epoch_train_total_loss = 0.0
        epoch_train_regression_loss = 0.0
        epoch_train_classification_loss = 0.0
        epoch_train_pcc = np.zeros(num_joints)
        epoch_train_rmse = np.zeros(num_joints)

        for data_acc, data_gyr, target, data_EMG, subject_id in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training"):
            optimizer.zero_grad()

            data_acc = data_acc.to(device).float()
            data_gyr = data_gyr.to(device).float()
            data_EMG = data_EMG.to(device).float()
            target = target.to(device).float()
            subject_id = subject_id.to(device).long() - 1  # Shift subject IDs to start from 0

            output, out_classifier, _ = model(data_acc, data_gyr, data_EMG)

            loss_regression = criterion_regression(output, target)
            loss_classification = criterion_classification(out_classifier, subject_id)
            total_loss = loss_regression + alpha * loss_classification

            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss = total_loss + l1_lambda * l1_norm

            total_loss.backward()
            optimizer.step()

            # Metrics are computed on detached CPU copies so they never touch the graph
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(
                output.detach().cpu().numpy(),
                target.detach().cpu().numpy(),
                num_joints,
                print_losses=False
            )

            epoch_train_total_loss += total_loss.item()
            epoch_train_regression_loss += loss_regression.item()
            epoch_train_classification_loss += loss_classification.item()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        num_train_batches = len(train_loader)
        avg_train_total_loss = epoch_train_total_loss / num_train_batches
        avg_train_regression_loss = epoch_train_regression_loss / num_train_batches
        avg_train_classification_loss = epoch_train_classification_loss / num_train_batches
        avg_train_pcc = epoch_train_pcc / num_train_batches
        avg_train_rmse = epoch_train_rmse / num_train_batches

        # Validation/test losses use the regression criterion only
        avg_val_total_loss, avg_val_pcc, avg_val_rmse = evaluate_model(
            device, model, val_loader, criterion_regression
        )
        avg_test_total_loss, avg_test_pcc, avg_test_rmse = evaluate_model(
            device, model, test_loader, criterion_regression
        )

        for split, loss_value, pcc, rmse in (
            ('train', avg_train_total_loss, avg_train_pcc, avg_train_rmse),
            ('val', avg_val_total_loss, avg_val_pcc, avg_val_rmse),
            ('test', avg_test_total_loss, avg_test_pcc, avg_test_rmse),
        ):
            history[f'{split}_losses'].append(loss_value)
            history[f'{split}_pccs'].append(np.mean(pcc))    # Overall average PCC
            history[f'{split}_rmses'].append(np.mean(rmse))  # Overall average RMSE
            history[f'{split}_pccs_channelwise'].append(pcc)    # Per channel
            history[f'{split}_rmses_channelwise'].append(rmse)  # Per channel
        history['train_classification_losses'].append(avg_train_classification_loss)

        print(f"Epoch: {epoch + 1}, Training Total Loss: {avg_train_total_loss:.4f}, Validation Total Loss: {avg_val_total_loss:.4f}, Test Total Loss: {avg_test_total_loss:.4f}")
        print(f"Training Regression Loss: {avg_train_regression_loss:.4f}, Validation Regression Loss: {avg_val_total_loss:.4f}, Test Regression Loss: {avg_test_total_loss:.4f}")
        print(f"Training Classification Loss: {avg_train_classification_loss:.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse):.4f}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc):.4f}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Keep the legacy scalar key alongside the per-epoch list for older readers
        checkpoint_history = dict(history)
        checkpoint_history['train_classification_loss'] = avg_train_classification_loss

        save_checkpoint(
            model,
            optimizer,
            epoch,
            os.path.join(checkpoint_path, f"{filename}_epoch_{epoch + 1}.pth"),
            train_loss=avg_train_total_loss,
            val_loss=avg_val_total_loss,
            test_loss=avg_test_total_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=checkpoint_history,
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader is not None else None
        )

        # Early stopping on validation loss; the best weights go to `filename` (working directory)
        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    return (
        model,
        history['train_losses'], history['val_losses'], history['test_losses'],
        history['train_pccs'], history['val_pccs'], history['test_pccs'],
        history['train_rmses'], history['val_rmses'], history['test_rmses'],
    )







In [None]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100,num_subjects=None,lambda_grl=None):
    """Build a ``teacher`` network and optionally initialise it from saved weights.

    Args:
        input_acc, input_gyr, input_emg: input sizes forwarded to ``teacher``.
        base_weights_path: optional path to a state_dict written with ``torch.save``.
            When given, it is loaded into the new model with ``load_state_dict``
            (strict key matching).
        drop_prob, w, num_subjects, lambda_grl: forwarded unchanged to ``teacher``.

    Returns:
        The constructed ``teacher`` instance.
    """
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w,
                    num_subjects=num_subjects, lambda_grl=lambda_grl)

    if base_weights_path:
        # Deserialize onto the CPU. load_state_dict copies the values into the model's own
        # parameters, and this keeps GPU-saved files loadable on CPU-only runtimes.
        state_dict = torch.load(base_weights_path, map_location='cpu')
        model.load_state_dict(state_dict)

    return model


In [None]:

import os
import h5py
import csv
from tqdm.notebook import tqdm
import pandas as pd
def create_curriculum_schedule(all_subjects, num_subjects, epochs, rng=None):
    """Draw a random subject subset for every epoch.

    Args:
        all_subjects: sequence of subject identifiers to sample from.
        num_subjects: size of each epoch's subset.
        epochs: number of epochs to schedule.
        rng: optional ``random.Random`` instance for reproducible schedules. Defaults to
            the global ``random`` module state.

    Returns:
        ``[(epoch, sorted_subset), ...]`` for every ``epoch`` in ``range(epochs)``.

    Raises:
        ValueError: from ``random.sample`` when ``num_subjects`` exceeds the population
            or is negative.
    """
    import random  # local so this helper does not depend on an import made in another cell

    sampler = random if rng is None else rng
    return [(epoch, sorted(sampler.sample(all_subjects, num_subjects))) for epoch in range(epochs)]

# Define the list of all subjects
all_subjects = ['subject_1', 'subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6',
                'subject_7', 'subject_8', 'subject_9', 'subject_10',
                'subject_11', 'subject_12', 'subject_13']

# Reference teacher whose initial weights are saved to disk.
# NOTE(review): base_weights_path is never passed to create_teacher_model below, so each
# configuration gets its own random initialisation. This model is also built without
# num_subjects / lambda_grl, so its state_dict may not fit those models. Confirm before
# wiring it in as a shared initialisation.
input_acc, input_gyr, input_emg = 18, 18, 3  # accelerometer, gyroscope and EMG channel counts
base_model = teacher(input_acc, input_gyr, input_emg)
base_weights_path = 'base_teacher_weights.pth'
torch.save(base_model.state_dict(), base_weights_path)
window_overlap = 0  # not used by the configurations built in this cell
batch_size = 64
curriculum_epochs = 20  # epoch budget for every configuration below

# Training-split windowing, shared by the loader arguments and the model names.
train_window_length = 100
train_window_overlap = 75

# One configuration per (lambda, alpha) pair; all share the same subject split.
model_configs = {}

test_subject = 'subject_1'
# Exact comparison. `subject not in test_subject` was a substring test on a string, so
# test_subject='subject_10' would also have dropped 'subject_1' from training.
train_subjects = [subject for subject in all_subjects if subject != test_subject]

# Number of subject classes for the classification head (12 training subjects here).
num_classes = 12

# Weights for the classification loss, and values passed to the teacher as lambda_grl.
alpha_values = [0.1, 0.2, 0.3, 0.4, 0.5]
lambda_values = [1.25, 1.5, 2]  # full sweep: [0.5, 0.75, 1, 1.25, 1.5, 2]
for lamda in lambda_values:
    for alpha in alpha_values:
        model_name = (
            f'TeacherModel_DomainInvariant_alpha_{alpha}_lambda_{lamda}'
            f'_wl{train_window_length}_ol{train_window_overlap}_normalizedbysubject'
        )
        model_configs[model_name] = {
            'model': create_teacher_model(
                input_acc=input_acc,
                input_gyr=input_gyr,
                input_emg=input_emg,
                w=train_window_length,
                num_subjects=num_classes,  # Ensure the model includes the classification head
                lambda_grl=lamda
            ),
            'loss': RMSELoss(),
            'loaders': create_base_data_loaders(
                config=config,
                train_subjects=train_subjects,
                test_subjects=[test_subject],
                window_length=train_window_length,
                window_overlap=train_window_overlap,
                batch_size=batch_size
            ),
            'epochs': curriculum_epochs,
            'use_curriculum': False,
            'alpha': alpha,  # Weighting factor for classification loss
            'num_classes': num_classes,  # Number of classes for the classification task
        }

# List the configured model names
for model_name, model_conf in model_configs.items():
    print(f"Model: {model_name}")


Sharded data not found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Resharding...
Processing subjects: ['subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13'] with window length: 100, overlap: 75
Dataset folder: /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13/train
Dataset folder created:  /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13/train


Processing subjects:   0%|          | 0/12 [00:00<?, ?it/s]

Sharded data not found at /content/datasets/dataset_wl100_ol0_test_1. Resharding...
Processing subjects: ['subject_1'] with window length: 100, overlap: 0
Dataset folder: /content/datasets/dataset_wl100_ol0_test_1/test
Dataset folder created:  /content/datasets/dataset_wl100_ol0_test_1/test


Processing subjects:   0%|          | 0/1 [00:00<?, ?it/s]

Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping re

In [10]:
 # @title run models

# Release memory held in PyTorch's CUDA caching allocator before the training runs below.
# Tensors that are still referenced (e.g. models kept in model_configs) are not freed by this.
torch.cuda.empty_cache()

def ask_run():
    """Prompt until the user answers yes or no.

    Returns:
        True for 'yes'/'y' and False for 'no'/'n'. Matching is case-insensitive and
        ignores surrounding whitespace.

    Re-prompts in a loop rather than recursing, so repeated invalid answers cannot hit
    Python's recursion limit.
    """
    while True:
        response = input("Do you want to run models? (yes/no): ").strip().lower()
        if response in ('yes', 'y'):
            return True
        if response in ('no', 'n'):
            return False
        print("Invalid input. Please enter 'yes' or 'no'.")

# Set to ask_run() to get an interactive confirmation prompt before training.
run_models = True

if run_models:
    for model_name, model_config in model_configs.items():
        model = model_config['model']
        # Regression criterion; train_teacher calls it directly, so it must not be None.
        loss_function = model_config['loss']
        epochs = model_config.get("epochs", 100)
        device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        learn_rate = model_config.get("learn_rate", 0.001)
        use_curriculum = model_config.get("use_curriculum", False)

        optimizer = model_config.get("optimizer", None)
        l1_lambda = model_config.get("l1_lambda", None)
        alpha = model_config.get('alpha', 1.0)
        num_classes = model_config.get('num_classes', 12)  # Default to 12 if not specified

        print(f"Running model: {model_name} with alpha: {alpha}")

        if use_curriculum:
            # The curriculum loader supplies fresh loaders every epoch inside train_teacher.
            curriculum_loader = model_config['loader']
            train_loader, val_loader, test_loader = None, None, None
        else:
            curriculum_loader = None
            train_loader, val_loader, test_loader = model_config['loaders']

        (model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs,
         train_rmses, val_rmses, test_rmses) = train_teacher(
            device=device,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            learn_rate=learn_rate,
            epochs=epochs,
            model=model,
            filename=model_name,
            loss_function=loss_function,
            curriculum_loader=curriculum_loader,
            optimizer=optimizer,
            l1_lambda=l1_lambda,
            train_from_last_epoch=model_config.get("train_from_last_epoch", False),
            alpha=alpha,
            num_classes=num_classes,
        )

        print(f"Finished training for {model_name}.")


Running model: TeacherModel_DomainInvariant_alpha_0.1_lambda_1.25_wl100_ol75_normalizedbysubject with alpha: 0.1
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.3024, Validation Total Loss: 12.3323, Test Total Loss: 15.3634
Training Regression Loss: 19.0535, Validation Regression Loss: 12.3323, Test Regression Loss: 15.3634
Training Classification Loss: 2.4882
Training RMSE: 18.5180, Validation RMSE: 12.0035, Test RMSE: 14.4695
Training PCC: 0.8010, Validation PCC: 0.9373, Test PCC: 0.6947
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.2478, Validation Total Loss: 9.2493, Test Total Loss: 13.9701
Training Regression Loss: 11.0312, Validation Regression Loss: 9.2493, Test Regression Loss: 13.9701
Training Classification Loss: 2.1667
Training RMSE: 10.5885, Validation RMSE: 8.8420, Test RMSE: 13.1484
Training PCC: 0.9500, Validation PCC: 0.9659, Test PCC: 0.7214
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.6366, Validation Total Loss: 8.2320, Test Total Loss: 13.0337
Training Regression Loss: 9.4400, Validation Regression Loss: 8.2320, Test Regression Loss: 13.0337
Training Classification Loss: 1.9658
Training RMSE: 9.0431, Validation RMSE: 7.8772, Test RMSE: 12.0698
Training PCC: 0.9640, Validation PCC: 0.9731, Test PCC: 0.7189
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 8.8821, Validation Total Loss: 7.8287, Test Total Loss: 13.9621
Training Regression Loss: 8.6912, Validation Regression Loss: 7.8287, Test Regression Loss: 13.9621
Training Classification Loss: 1.9094
Training RMSE: 8.3115, Validation RMSE: 7.3719, Test RMSE: 12.9730
Training PCC: 0.9698, Validation PCC: 0.9773, Test PCC: 0.7436
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.3202, Validation Total Loss: 7.3121, Test Total Loss: 14.1122
Training Regression Loss: 8.1312, Validation Regression Loss: 7.3121, Test Regression Loss: 14.1122
Training Classification Loss: 1.8899
Training RMSE: 7.7742, Validation RMSE: 6.9216, Test RMSE: 12.8407
Training PCC: 0.9735, Validation PCC: 0.9801, Test PCC: 0.7617
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 7.7869, Validation Total Loss: 6.9639, Test Total Loss: 13.0292
Training Regression Loss: 7.5919, Validation Regression Loss: 6.9639, Test Regression Loss: 13.0292
Training Classification Loss: 1.9503
Training RMSE: 7.2527, Validation RMSE: 6.5591, Test RMSE: 12.2580
Training PCC: 0.9766, Validation PCC: 0.9819, Test PCC: 0.7571
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.4981, Validation Total Loss: 6.8988, Test Total Loss: 13.2028
Training Regression Loss: 7.2988, Validation Regression Loss: 6.8988, Test Regression Loss: 13.2028
Training Classification Loss: 1.9931
Training RMSE: 6.9693, Validation RMSE: 6.5542, Test RMSE: 12.4252
Training PCC: 0.9784, Validation PCC: 0.9822, Test PCC: 0.7828
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.2452, Validation Total Loss: 6.4037, Test Total Loss: 13.0583
Training Regression Loss: 7.0377, Validation Regression Loss: 6.4037, Test Regression Loss: 13.0583
Training Classification Loss: 2.0755
Training RMSE: 6.7286, Validation RMSE: 6.1333, Test RMSE: 11.9017
Training PCC: 0.9796, Validation PCC: 0.9848, Test PCC: 0.7513
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 6.8706, Validation Total Loss: 5.8304, Test Total Loss: 13.5668
Training Regression Loss: 6.6541, Validation Regression Loss: 5.8304, Test Regression Loss: 13.5668
Training Classification Loss: 2.1653
Training RMSE: 6.3816, Validation RMSE: 5.5523, Test RMSE: 12.4425
Training PCC: 0.9816, Validation PCC: 0.9865, Test PCC: 0.7722
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6.6818, Validation Total Loss: 5.5726, Test Total Loss: 13.4445
Training Regression Loss: 6.4585, Validation Regression Loss: 5.5726, Test Regression Loss: 13.4445
Training Classification Loss: 2.2335
Training RMSE: 6.2004, Validation RMSE: 5.3802, Test RMSE: 12.3214
Training PCC: 0.9828, Validation PCC: 0.9866, Test PCC: 0.7834
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.3694, Validation Total Loss: 5.4191, Test Total Loss: 14.1164
Training Regression Loss: 6.1365, Validation Regression Loss: 5.4191, Test Regression Loss: 14.1164
Training Classification Loss: 2.3292
Training RMSE: 5.9016, Validation RMSE: 5.2065, Test RMSE: 12.6671
Training PCC: 0.9841, Validation PCC: 0.9884, Test PCC: 0.7792
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.1725, Validation Total Loss: 6.0604, Test Total Loss: 13.8726
Training Regression Loss: 5.9395, Validation Regression Loss: 6.0604, Test Regression Loss: 13.8726
Training Classification Loss: 2.3293
Training RMSE: 5.7273, Validation RMSE: 5.9084, Test RMSE: 13.0249
Training PCC: 0.9852, Validation PCC: 0.9882, Test PCC: 0.7348
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.0712, Validation Total Loss: 5.2367, Test Total Loss: 14.1934
Training Regression Loss: 5.8357, Validation Regression Loss: 5.2367, Test Regression Loss: 14.1934
Training Classification Loss: 2.3547
Training RMSE: 5.6314, Validation RMSE: 5.0866, Test RMSE: 13.0065
Training PCC: 0.9856, Validation PCC: 0.9890, Test PCC: 0.7893
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 5.9914, Validation Total Loss: 5.6654, Test Total Loss: 13.3341
Training Regression Loss: 5.7472, Validation Regression Loss: 5.6654, Test Regression Loss: 13.3341
Training Classification Loss: 2.4420
Training RMSE: 5.5476, Validation RMSE: 5.4145, Test RMSE: 12.3435
Training PCC: 0.9860, Validation PCC: 0.9882, Test PCC: 0.7890
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 5.6009, Validation Total Loss: 5.4930, Test Total Loss: 13.4048
Training Regression Loss: 5.3557, Validation Regression Loss: 5.4930, Test Regression Loss: 13.4048
Training Classification Loss: 2.4518
Training RMSE: 5.1960, Validation RMSE: 5.3256, Test RMSE: 12.2892
Training PCC: 0.9877, Validation PCC: 0.9890, Test PCC: 0.7787
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 5.6461, Validation Total Loss: 5.0135, Test Total Loss: 12.9195
Training Regression Loss: 5.4002, Validation Regression Loss: 5.0135, Test Regression Loss: 12.9195
Training Classification Loss: 2.4586
Training RMSE: 5.2253, Validation RMSE: 4.8451, Test RMSE: 11.6232
Training PCC: 0.9877, Validation PCC: 0.9897, Test PCC: 0.8001
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.4050, Validation Total Loss: 5.4279, Test Total Loss: 14.0730
Training Regression Loss: 5.1561, Validation Regression Loss: 5.4279, Test Regression Loss: 14.0730
Training Classification Loss: 2.4892
Training RMSE: 5.0018, Validation RMSE: 5.1891, Test RMSE: 12.9358
Training PCC: 0.9885, Validation PCC: 0.9899, Test PCC: 0.7462
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.2513, Validation Total Loss: 4.6120, Test Total Loss: 14.0304
Training Regression Loss: 4.9961, Validation Regression Loss: 4.6120, Test Regression Loss: 14.0304
Training Classification Loss: 2.5514
Training RMSE: 4.8679, Validation RMSE: 4.4309, Test RMSE: 12.7285
Training PCC: 0.9893, Validation PCC: 0.9917, Test PCC: 0.7613
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.1004, Validation Total Loss: 4.9855, Test Total Loss: 14.4837
Training Regression Loss: 4.8484, Validation Regression Loss: 4.9855, Test Regression Loss: 14.4837
Training Classification Loss: 2.5198
Training RMSE: 4.7248, Validation RMSE: 4.8864, Test RMSE: 13.2570
Training PCC: 0.9900, Validation PCC: 0.9905, Test PCC: 0.7660
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.1143, Validation Total Loss: 4.7802, Test Total Loss: 13.6747
Training Regression Loss: 4.8545, Validation Regression Loss: 4.7802, Test Regression Loss: 13.6747
Training Classification Loss: 2.5980
Training RMSE: 4.7243, Validation RMSE: 4.5991, Test RMSE: 12.2993
Training PCC: 0.9899, Validation PCC: 0.9913, Test PCC: 0.7889
Checkpoint saved for epoch 20
Total training time: 2575.88 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.1_lambda_1.25_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.2_lambda_1.25_wl100_ol75_normalizedbysubject with alpha: 0.2
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.5539, Validation Total Loss: 12.4039, Test Total Loss: 13.9937
Training Regression Loss: 19.0498, Validation Regression Loss: 12.4039, Test Regression Loss: 13.9937
Training Classification Loss: 2.5204
Training RMSE: 18.4908, Validation RMSE: 11.8791, Test RMSE: 13.1748
Training PCC: 0.8045, Validation PCC: 0.9353, Test PCC: 0.7261
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.2524, Validation Total Loss: 10.1029, Test Total Loss: 12.5496
Training Regression Loss: 10.7877, Validation Regression Loss: 10.1029, Test Regression Loss: 12.5496
Training Classification Loss: 2.3234
Training RMSE: 10.3340, Validation RMSE: 9.5508, Test RMSE: 11.9287
Training PCC: 0.9533, Validation PCC: 0.9628, Test PCC: 0.7337
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.8810, Validation Total Loss: 8.9816, Test Total Loss: 13.8310
Training Regression Loss: 9.4240, Validation Regression Loss: 8.9816, Test Regression Loss: 13.8310
Training Classification Loss: 2.2850
Training RMSE: 9.0005, Validation RMSE: 8.5261, Test RMSE: 13.0938
Training PCC: 0.9650, Validation PCC: 0.9692, Test PCC: 0.7085
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.1146, Validation Total Loss: 8.2628, Test Total Loss: 13.2620
Training Regression Loss: 8.6714, Validation Regression Loss: 8.2628, Test Regression Loss: 13.2620
Training Classification Loss: 2.2162
Training RMSE: 8.2757, Validation RMSE: 7.8252, Test RMSE: 12.4482
Training PCC: 0.9705, Validation PCC: 0.9732, Test PCC: 0.6835
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.5382, Validation Total Loss: 8.0867, Test Total Loss: 13.7214
Training Regression Loss: 8.0834, Validation Regression Loss: 8.0867, Test Regression Loss: 13.7214
Training Classification Loss: 2.2737
Training RMSE: 7.7111, Validation RMSE: 7.5977, Test RMSE: 12.8066
Training PCC: 0.9743, Validation PCC: 0.9742, Test PCC: 0.7008
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.0461, Validation Total Loss: 7.3476, Test Total Loss: 12.3339
Training Regression Loss: 7.5654, Validation Regression Loss: 7.3476, Test Regression Loss: 12.3339
Training Classification Loss: 2.4033
Training RMSE: 7.2319, Validation RMSE: 6.9846, Test RMSE: 11.6101
Training PCC: 0.9769, Validation PCC: 0.9788, Test PCC: 0.7243
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.8517, Validation Total Loss: 6.9325, Test Total Loss: 13.6244
Training Regression Loss: 7.3577, Validation Regression Loss: 6.9325, Test Regression Loss: 13.6244
Training Classification Loss: 2.4699
Training RMSE: 7.0493, Validation RMSE: 6.5797, Test RMSE: 12.5682
Training PCC: 0.9782, Validation PCC: 0.9802, Test PCC: 0.7101
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.5535, Validation Total Loss: 6.9885, Test Total Loss: 13.1280
Training Regression Loss: 7.0399, Validation Regression Loss: 6.9885, Test Regression Loss: 13.1280
Training Classification Loss: 2.5679
Training RMSE: 6.7378, Validation RMSE: 6.6640, Test RMSE: 12.1236
Training PCC: 0.9797, Validation PCC: 0.9801, Test PCC: 0.7362
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.3028, Validation Total Loss: 6.5496, Test Total Loss: 13.0524
Training Regression Loss: 6.7813, Validation Regression Loss: 6.5496, Test Regression Loss: 13.0524
Training Classification Loss: 2.6075
Training RMSE: 6.4804, Validation RMSE: 6.1692, Test RMSE: 12.0314
Training PCC: 0.9812, Validation PCC: 0.9816, Test PCC: 0.7255
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.0363, Validation Total Loss: 6.6353, Test Total Loss: 14.6160
Training Regression Loss: 6.5227, Validation Regression Loss: 6.6353, Test Regression Loss: 14.6160
Training Classification Loss: 2.5681
Training RMSE: 6.2653, Validation RMSE: 6.2061, Test RMSE: 13.2759
Training PCC: 0.9827, Validation PCC: 0.9823, Test PCC: 0.7288
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.7494, Validation Total Loss: 6.5974, Test Total Loss: 14.4305
Training Regression Loss: 6.2214, Validation Regression Loss: 6.5974, Test Regression Loss: 14.4305
Training Classification Loss: 2.6402
Training RMSE: 5.9834, Validation RMSE: 6.2142, Test RMSE: 13.2043
Training PCC: 0.9839, Validation PCC: 0.9829, Test PCC: 0.7062
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.7292, Validation Total Loss: 6.2117, Test Total Loss: 12.9570
Training Regression Loss: 6.1978, Validation Regression Loss: 6.2117, Test Regression Loss: 12.9570
Training Classification Loss: 2.6570
Training RMSE: 5.9625, Validation RMSE: 5.8693, Test RMSE: 11.7800
Training PCC: 0.9842, Validation PCC: 0.9843, Test PCC: 0.7747
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.4060, Validation Total Loss: 5.7619, Test Total Loss: 13.1478
Training Regression Loss: 5.8782, Validation Regression Loss: 5.7619, Test Regression Loss: 13.1478
Training Classification Loss: 2.6388
Training RMSE: 5.6788, Validation RMSE: 5.4882, Test RMSE: 11.7621
Training PCC: 0.9856, Validation PCC: 0.9846, Test PCC: 0.7755
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.2222, Validation Total Loss: 5.8477, Test Total Loss: 13.4865
Training Regression Loss: 5.6698, Validation Regression Loss: 5.8477, Test Regression Loss: 13.4865
Training Classification Loss: 2.7622
Training RMSE: 5.4920, Validation RMSE: 5.6261, Test RMSE: 12.3428
Training PCC: 0.9865, Validation PCC: 0.9856, Test PCC: 0.7514
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.0776, Validation Total Loss: 5.7636, Test Total Loss: 13.5806
Training Regression Loss: 5.5367, Validation Regression Loss: 5.7636, Test Regression Loss: 13.5806
Training Classification Loss: 2.7046
Training RMSE: 5.3544, Validation RMSE: 5.4658, Test RMSE: 12.1745
Training PCC: 0.9872, Validation PCC: 0.9858, Test PCC: 0.7683
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.0514, Validation Total Loss: 5.7412, Test Total Loss: 13.3004
Training Regression Loss: 5.4981, Validation Regression Loss: 5.7412, Test Regression Loss: 13.3004
Training Classification Loss: 2.7668
Training RMSE: 5.3161, Validation RMSE: 5.4674, Test RMSE: 12.1918
Training PCC: 0.9873, Validation PCC: 0.9868, Test PCC: 0.7871
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.7932, Validation Total Loss: 5.5477, Test Total Loss: 13.1618
Training Regression Loss: 5.2474, Validation Regression Loss: 5.5477, Test Regression Loss: 13.1618
Training Classification Loss: 2.7293
Training RMSE: 5.0881, Validation RMSE: 5.3016, Test RMSE: 11.8803
Training PCC: 0.9884, Validation PCC: 0.9871, Test PCC: 0.7951
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.6032, Validation Total Loss: 5.4876, Test Total Loss: 13.0453
Training Regression Loss: 5.0430, Validation Regression Loss: 5.4876, Test Regression Loss: 13.0453
Training Classification Loss: 2.8012
Training RMSE: 4.9027, Validation RMSE: 5.1824, Test RMSE: 11.7504
Training PCC: 0.9891, Validation PCC: 0.9871, Test PCC: 0.7899
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.5258, Validation Total Loss: 5.2118, Test Total Loss: 13.3244
Training Regression Loss: 4.9484, Validation Regression Loss: 5.2118, Test Regression Loss: 13.3244
Training Classification Loss: 2.8869
Training RMSE: 4.8140, Validation RMSE: 4.9428, Test RMSE: 12.0864
Training PCC: 0.9896, Validation PCC: 0.9878, Test PCC: 0.7720
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.7519, Validation Total Loss: 5.9006, Test Total Loss: 13.1112
Training Regression Loss: 5.1214, Validation Regression Loss: 5.9006, Test Regression Loss: 13.1112
Training Classification Loss: 3.1525
Training RMSE: 4.9805, Validation RMSE: 5.5606, Test RMSE: 11.9962
Training PCC: 0.9888, Validation PCC: 0.9872, Test PCC: 0.7774
Checkpoint saved for epoch 20
Total training time: 2564.98 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.2_lambda_1.25_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.3_lambda_1.25_wl100_ol75_normalizedbysubject with alpha: 0.3
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.1760, Validation Total Loss: 11.7312, Test Total Loss: 13.4811
Training Regression Loss: 19.3835, Validation Regression Loss: 11.7312, Test Regression Loss: 13.4811
Training Classification Loss: 2.6417
Training RMSE: 18.8305, Validation RMSE: 11.2440, Test RMSE: 12.8480
Training PCC: 0.7936, Validation PCC: 0.9411, Test PCC: 0.6631
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.6518, Validation Total Loss: 9.8075, Test Total Loss: 13.7269
Training Regression Loss: 10.9013, Validation Regression Loss: 9.8075, Test Regression Loss: 13.7269
Training Classification Loss: 2.5019
Training RMSE: 10.4589, Validation RMSE: 9.3020, Test RMSE: 12.7090
Training PCC: 0.9516, Validation PCC: 0.9620, Test PCC: 0.7165
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.2832, Validation Total Loss: 8.9912, Test Total Loss: 14.7609
Training Regression Loss: 9.5591, Validation Regression Loss: 8.9912, Test Regression Loss: 14.7609
Training Classification Loss: 2.4138
Training RMSE: 9.1409, Validation RMSE: 8.4772, Test RMSE: 13.4250
Training PCC: 0.9640, Validation PCC: 0.9686, Test PCC: 0.7203
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.5213, Validation Total Loss: 8.2864, Test Total Loss: 14.2834
Training Regression Loss: 8.7972, Validation Regression Loss: 8.2864, Test Regression Loss: 14.2834
Training Classification Loss: 2.4134
Training RMSE: 8.3864, Validation RMSE: 7.8388, Test RMSE: 13.3673
Training PCC: 0.9693, Validation PCC: 0.9712, Test PCC: 0.7136
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.7783, Validation Total Loss: 7.8558, Test Total Loss: 14.8655
Training Regression Loss: 8.0162, Validation Regression Loss: 7.8558, Test Regression Loss: 14.8655
Training Classification Loss: 2.5404
Training RMSE: 7.6858, Validation RMSE: 7.4715, Test RMSE: 13.9113
Training PCC: 0.9738, Validation PCC: 0.9771, Test PCC: 0.7418
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.5168, Validation Total Loss: 7.5790, Test Total Loss: 12.9052
Training Regression Loss: 7.7523, Validation Regression Loss: 7.5790, Test Regression Loss: 12.9052
Training Classification Loss: 2.5481
Training RMSE: 7.4005, Validation RMSE: 7.1677, Test RMSE: 11.9595
Training PCC: 0.9759, Validation PCC: 0.9774, Test PCC: 0.7539
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.2052, Validation Total Loss: 7.4648, Test Total Loss: 15.1157
Training Regression Loss: 7.4169, Validation Regression Loss: 7.4648, Test Regression Loss: 15.1157
Training Classification Loss: 2.6276
Training RMSE: 7.0837, Validation RMSE: 6.9678, Test RMSE: 13.6970
Training PCC: 0.9777, Validation PCC: 0.9797, Test PCC: 0.7646
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.8946, Validation Total Loss: 6.8891, Test Total Loss: 14.3077
Training Regression Loss: 7.1042, Validation Regression Loss: 6.8891, Test Regression Loss: 14.3077
Training Classification Loss: 2.6349
Training RMSE: 6.7984, Validation RMSE: 6.5028, Test RMSE: 13.2461
Training PCC: 0.9796, Validation PCC: 0.9810, Test PCC: 0.7267
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.7488, Validation Total Loss: 7.1174, Test Total Loss: 13.7612
Training Regression Loss: 6.9251, Validation Regression Loss: 7.1174, Test Regression Loss: 13.7612
Training Classification Loss: 2.7458
Training RMSE: 6.6318, Validation RMSE: 6.7572, Test RMSE: 12.8926
Training PCC: 0.9803, Validation PCC: 0.9794, Test PCC: 0.7704
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.4351, Validation Total Loss: 6.4863, Test Total Loss: 13.7139
Training Regression Loss: 6.6188, Validation Regression Loss: 6.4863, Test Regression Loss: 13.7139
Training Classification Loss: 2.7208
Training RMSE: 6.3262, Validation RMSE: 6.1035, Test RMSE: 12.7277
Training PCC: 0.9821, Validation PCC: 0.9835, Test PCC: 0.7586
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.2285, Validation Total Loss: 6.4322, Test Total Loss: 15.3511
Training Regression Loss: 6.3746, Validation Regression Loss: 6.4322, Test Regression Loss: 15.3511
Training Classification Loss: 2.8462
Training RMSE: 6.1246, Validation RMSE: 6.1045, Test RMSE: 14.1028
Training PCC: 0.9830, Validation PCC: 0.9840, Test PCC: 0.7089
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 7.2082, Validation Total Loss: 6.1594, Test Total Loss: 13.5843
Training Regression Loss: 6.3418, Validation Regression Loss: 6.1594, Test Regression Loss: 13.5843
Training Classification Loss: 2.8881
Training RMSE: 6.0887, Validation RMSE: 5.7998, Test RMSE: 12.6023
Training PCC: 0.9833, Validation PCC: 0.9850, Test PCC: 0.7546
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 7.0494, Validation Total Loss: 6.1425, Test Total Loss: 14.7564
Training Regression Loss: 6.0848, Validation Regression Loss: 6.1425, Test Regression Loss: 14.7564
Training Classification Loss: 3.2152
Training RMSE: 5.8577, Validation RMSE: 5.9110, Test RMSE: 13.8353
Training PCC: 0.9844, Validation PCC: 0.9835, Test PCC: 0.7264
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.8395, Validation Total Loss: 5.9573, Test Total Loss: 13.3952
Training Regression Loss: 5.9447, Validation Regression Loss: 5.9573, Test Regression Loss: 13.3952
Training Classification Loss: 2.9829
Training RMSE: 5.7307, Validation RMSE: 5.6661, Test RMSE: 12.1780
Training PCC: 0.9854, Validation PCC: 0.9860, Test PCC: 0.7760
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.7665, Validation Total Loss: 5.8482, Test Total Loss: 14.4185
Training Regression Loss: 5.8412, Validation Regression Loss: 5.8482, Test Regression Loss: 14.4185
Training Classification Loss: 3.0844
Training RMSE: 5.6275, Validation RMSE: 5.5501, Test RMSE: 13.1892
Training PCC: 0.9856, Validation PCC: 0.9870, Test PCC: 0.7649
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.7340, Validation Total Loss: 5.4402, Test Total Loss: 14.1536
Training Regression Loss: 5.7910, Validation Regression Loss: 5.4402, Test Regression Loss: 14.1536
Training Classification Loss: 3.1431
Training RMSE: 5.5760, Validation RMSE: 5.1869, Test RMSE: 13.0361
Training PCC: 0.9861, Validation PCC: 0.9876, Test PCC: 0.7665
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 6.6837, Validation Total Loss: 5.4144, Test Total Loss: 13.4253
Training Regression Loss: 5.6764, Validation Regression Loss: 5.4144, Test Regression Loss: 13.4253
Training Classification Loss: 3.3575
Training RMSE: 5.4767, Validation RMSE: 5.0976, Test RMSE: 12.4999
Training PCC: 0.9867, Validation PCC: 0.9887, Test PCC: 0.7574
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 6.5223, Validation Total Loss: 5.6381, Test Total Loss: 13.8677
Training Regression Loss: 5.4730, Validation Regression Loss: 5.6381, Test Regression Loss: 13.8677
Training Classification Loss: 3.4978
Training RMSE: 5.2958, Validation RMSE: 5.3676, Test RMSE: 12.4215
Training PCC: 0.9874, Validation PCC: 0.9886, Test PCC: 0.7838
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 6.4673, Validation Total Loss: 5.3431, Test Total Loss: 14.0855
Training Regression Loss: 5.3990, Validation Regression Loss: 5.3431, Test Regression Loss: 14.0855
Training Classification Loss: 3.5610
Training RMSE: 5.2414, Validation RMSE: 5.0677, Test RMSE: 13.0707
Training PCC: 0.9879, Validation PCC: 0.9891, Test PCC: 0.7680
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 6.3002, Validation Total Loss: 5.4696, Test Total Loss: 14.0475
Training Regression Loss: 5.1813, Validation Regression Loss: 5.4696, Test Regression Loss: 14.0475
Training Classification Loss: 3.7297
Training RMSE: 5.0340, Validation RMSE: 5.2778, Test RMSE: 13.0468
Training PCC: 0.9887, Validation PCC: 0.9871, Test PCC: 0.7360
Checkpoint saved for epoch 20
Total training time: 2570.91 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.3_lambda_1.25_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.4_lambda_1.25_wl100_ol75_normalizedbysubject with alpha: 0.4
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.0372, Validation Total Loss: 11.5922, Test Total Loss: 14.8147
Training Regression Loss: 18.9656, Validation Regression Loss: 11.5922, Test Regression Loss: 14.8147
Training Classification Loss: 2.6790
Training RMSE: 18.4239, Validation RMSE: 11.2334, Test RMSE: 14.1192
Training PCC: 0.8046, Validation PCC: 0.9429, Test PCC: 0.7477
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 12.1460, Validation Total Loss: 9.2762, Test Total Loss: 14.1250
Training Regression Loss: 11.0583, Validation Regression Loss: 9.2762, Test Regression Loss: 14.1250
Training Classification Loss: 2.7192
Training RMSE: 10.5966, Validation RMSE: 8.8950, Test RMSE: 13.1138
Training PCC: 0.9507, Validation PCC: 0.9649, Test PCC: 0.7355
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.6772, Validation Total Loss: 8.6252, Test Total Loss: 14.7484
Training Regression Loss: 9.6400, Validation Regression Loss: 8.6252, Test Regression Loss: 14.7484
Training Classification Loss: 2.5932
Training RMSE: 9.2201, Validation RMSE: 8.2354, Test RMSE: 13.4278
Training PCC: 0.9629, Validation PCC: 0.9724, Test PCC: 0.7353
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.9150, Validation Total Loss: 8.5805, Test Total Loss: 14.4913
Training Regression Loss: 8.8320, Validation Regression Loss: 8.5805, Test Regression Loss: 14.4913
Training Classification Loss: 2.7076
Training RMSE: 8.4589, Validation RMSE: 8.0709, Test RMSE: 12.8860
Training PCC: 0.9690, Validation PCC: 0.9768, Test PCC: 0.7554
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.4696, Validation Total Loss: 8.0520, Test Total Loss: 12.8176
Training Regression Loss: 8.3654, Validation Regression Loss: 8.0520, Test Regression Loss: 12.8176
Training Classification Loss: 2.7605
Training RMSE: 7.9808, Validation RMSE: 7.5468, Test RMSE: 12.0952
Training PCC: 0.9721, Validation PCC: 0.9793, Test PCC: 0.7524
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.9626, Validation Total Loss: 7.0636, Test Total Loss: 14.1907
Training Regression Loss: 7.8293, Validation Regression Loss: 7.0636, Test Regression Loss: 14.1907
Training Classification Loss: 2.8333
Training RMSE: 7.4873, Validation RMSE: 6.7604, Test RMSE: 13.1172
Training PCC: 0.9756, Validation PCC: 0.9803, Test PCC: 0.7430
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.8682, Validation Total Loss: 7.1061, Test Total Loss: 16.1739
Training Regression Loss: 7.5854, Validation Regression Loss: 7.1061, Test Regression Loss: 16.1739
Training Classification Loss: 3.2072
Training RMSE: 7.2519, Validation RMSE: 6.7559, Test RMSE: 14.4183
Training PCC: 0.9767, Validation PCC: 0.9828, Test PCC: 0.7608
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 8.5655, Validation Total Loss: 6.9451, Test Total Loss: 12.8424
Training Regression Loss: 7.2898, Validation Regression Loss: 6.9451, Test Regression Loss: 12.8424
Training Classification Loss: 3.1893
Training RMSE: 6.9734, Validation RMSE: 6.5547, Test RMSE: 11.6366
Training PCC: 0.9788, Validation PCC: 0.9837, Test PCC: 0.7702
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 8.7376, Validation Total Loss: 6.2999, Test Total Loss: 14.7163
Training Regression Loss: 7.0032, Validation Regression Loss: 6.2999, Test Regression Loss: 14.7163
Training Classification Loss: 4.3360
Training RMSE: 6.7158, Validation RMSE: 6.0364, Test RMSE: 13.3372
Training PCC: 0.9802, Validation PCC: 0.9845, Test PCC: 0.7649
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6783.0487, Validation Total Loss: 2114.0165, Test Total Loss: 2118.7362
Training Regression Loss: 1434.0227, Validation Regression Loss: 2114.0165, Test Regression Loss: 2118.7362
Training Classification Loss: 13372.5648
Training RMSE: 1244.5994, Validation RMSE: 1686.4127, Test RMSE: 1687.2713
Training PCC: 0.2606, Validation PCC: 0.0139, Test PCC: 0.0112
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 15907.0783, Validation Total Loss: 5926.5966, Test Total Loss: 5924.7485
Training Regression Loss: 3496.1251, Validation Regression Loss: 5926.5966, Test Regression Loss: 5924.7485
Training Classification Loss: 31027.3825
Training RMSE: 2543.4703, Validation RMSE: 4828.2704, Test RMSE: 4828.9109
Training PCC: -0.0048, Validation PCC: 0.0014, Test PCC: 0.0013
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 20172.7319, Validation Total Loss: 3207.3801, Test Total Loss: 3207.9535
Training Regression Loss: 4582.9709, Validation Regression Loss: 3207.3801, Test Regression Loss: 3207.9535
Training Classification Loss: 38974.4020
Training RMSE: 3841.8859, Validation RMSE: 3116.9436, Test RMSE: 3116.4036
Training PCC: 0.0024, Validation PCC: 0.0009, Test PCC: 0.0009
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 23106.4453, Validation Total Loss: 2984.9889, Test Total Loss: 2985.4861
Training Regression Loss: 4676.0330, Validation Regression Loss: 2984.9889, Test Regression Loss: 2985.4861
Training Classification Loss: 46076.0297
Training RMSE: 3643.0857, Validation RMSE: 1772.0372, Test RMSE: 1771.9797
Training PCC: 0.0019, Validation PCC: 0.0010, Test PCC: 0.0016
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 27161.1662, Validation Total Loss: 3443.1201, Test Total Loss: 3437.5919
Training Regression Loss: 4297.0201, Validation Regression Loss: 3443.1201, Test Regression Loss: 3437.5919
Training Classification Loss: 57160.3643
Training RMSE: 2858.6948, Validation RMSE: 2121.0917, Test RMSE: 2112.2220
Training PCC: 0.0040, Validation PCC: 0.0009, Test PCC: 0.0013
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 32419.7256, Validation Total Loss: 5958.3621, Test Total Loss: 5961.4712
Training Regression Loss: 8078.9165, Validation Regression Loss: 5958.3621, Test Regression Loss: 5961.4712
Training Classification Loss: 60852.0218
Training RMSE: 5670.9345, Validation RMSE: 4682.0656, Test RMSE: 4682.7677
Training PCC: 0.0002, Validation PCC: 0.0012, Test PCC: 0.0011
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 36922.7398, Validation Total Loss: 9083.7081, Test Total Loss: 9089.6215
Training Regression Loss: 7935.1503, Validation Regression Loss: 9083.7081, Test Regression Loss: 9089.6215
Training Classification Loss: 72468.9727
Training RMSE: 6708.8575, Validation RMSE: 6299.0545, Test RMSE: 6303.8171
Training PCC: -0.0032, Validation PCC: 0.0003, Test PCC: 0.0012
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 41826.0430, Validation Total Loss: 5041.1148, Test Total Loss: 5040.5063
Training Regression Loss: 8524.5735, Validation Regression Loss: 5041.1148, Test Regression Loss: 5040.5063
Training Classification Loss: 83253.6725
Training RMSE: 6378.1728, Validation RMSE: 2991.5328, Test RMSE: 2989.7234
Training PCC: -0.0016, Validation PCC: 0.0001, Test PCC: 0.0001
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 41999.8162, Validation Total Loss: 5488.8226, Test Total Loss: 5494.1065
Training Regression Loss: 7470.4336, Validation Regression Loss: 5488.8226, Test Regression Loss: 5494.1065
Training Classification Loss: 86323.4549
Training RMSE: 5006.6189, Validation RMSE: 3229.9042, Test RMSE: 3232.2466
Training PCC: -0.0011, Validation PCC: 0.0005, Test PCC: 0.0009
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 46973.3570, Validation Total Loss: 5447.0380, Test Total Loss: 5446.6608
Training Regression Loss: 9897.0517, Validation Regression Loss: 5447.0380, Test Regression Loss: 5446.6608
Training Classification Loss: 92690.7615
Training RMSE: 7351.5211, Validation RMSE: 3262.5793, Test RMSE: 3266.9911
Training PCC: 0.0003, Validation PCC: 0.0020, Test PCC: 0.0040
Checkpoint saved for epoch 19
Stopping early after 19 epochs
Total training time: 2437.22 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.4_lambda_1.25_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.5_lambda_1.25_wl100_ol75_normalizedbysubject with alpha: 0.5
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.9357, Validation Total Loss: 12.5036, Test Total Loss: 14.0225
Training Regression Loss: 19.5677, Validation Regression Loss: 12.5036, Test Regression Loss: 14.0225
Training Classification Loss: 2.7360
Training RMSE: 18.9856, Validation RMSE: 12.0137, Test RMSE: 13.2673
Training PCC: 0.7936, Validation PCC: 0.9327, Test PCC: 0.6979
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 12.8333, Validation Total Loss: 9.9501, Test Total Loss: 13.8145
Training Regression Loss: 11.3531, Validation Regression Loss: 9.9501, Test Regression Loss: 13.8145
Training Classification Loss: 2.9604
Training RMSE: 10.9128, Validation RMSE: 9.5332, Test RMSE: 12.7527
Training PCC: 0.9484, Validation PCC: 0.9565, Test PCC: 0.7173
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 11.0552, Validation Total Loss: 9.3348, Test Total Loss: 14.9561
Training Regression Loss: 9.7113, Validation Regression Loss: 9.3348, Test Regression Loss: 14.9561
Training Classification Loss: 2.6879
Training RMSE: 9.3136, Validation RMSE: 8.9192, Test RMSE: 13.8925
Training PCC: 0.9622, Validation PCC: 0.9641, Test PCC: 0.7446
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 10.1389, Validation Total Loss: 8.2150, Test Total Loss: 13.0337
Training Regression Loss: 8.7726, Validation Regression Loss: 8.2150, Test Regression Loss: 13.0337
Training Classification Loss: 2.7325
Training RMSE: 8.3973, Validation RMSE: 7.8794, Test RMSE: 11.9335
Training PCC: 0.9691, Validation PCC: 0.9725, Test PCC: 0.7703
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.9342, Validation Total Loss: 8.0249, Test Total Loss: 13.2787
Training Regression Loss: 8.3375, Validation Regression Loss: 8.0249, Test Regression Loss: 13.2787
Training Classification Loss: 3.1933
Training RMSE: 7.9744, Validation RMSE: 7.6311, Test RMSE: 12.3198
Training PCC: 0.9726, Validation PCC: 0.9729, Test PCC: 0.7538
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 9.7651, Validation Total Loss: 8.3129, Test Total Loss: 14.3258
Training Regression Loss: 8.0276, Validation Regression Loss: 8.3129, Test Regression Loss: 14.3258
Training Classification Loss: 3.4749
Training RMSE: 7.6808, Validation RMSE: 7.7288, Test RMSE: 13.1811
Training PCC: 0.9748, Validation PCC: 0.9751, Test PCC: 0.7624
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 1999.8267, Validation Total Loss: 2775.4420, Test Total Loss: 2778.8367
Training Regression Loss: 467.8120, Validation Regression Loss: 2775.4420, Test Regression Loss: 2778.8367
Training Classification Loss: 3064.0295
Training RMSE: 377.3319, Validation RMSE: 2383.4136, Test RMSE: 2384.0026
Training PCC: 0.6575, Validation PCC: -0.0229, Test PCC: -0.0224
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 13335.0147, Validation Total Loss: 2623.4498, Test Total Loss: 2617.4945
Training Regression Loss: 2411.7140, Validation Regression Loss: 2623.4498, Test Regression Loss: 2617.4945
Training Classification Loss: 21846.6016
Training RMSE: 1956.0945, Validation RMSE: 1783.3607, Test RMSE: 1777.3618
Training PCC: 0.0011, Validation PCC: 0.0039, Test PCC: 0.0029
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 21936.3464, Validation Total Loss: 3450.8447, Test Total Loss: 3446.7503
Training Regression Loss: 4753.0526, Validation Regression Loss: 3450.8447, Test Regression Loss: 3446.7503
Training Classification Loss: 34366.5877
Training RMSE: 3839.6434, Validation RMSE: 3030.2077, Test RMSE: 3029.4060
Training PCC: -0.0020, Validation PCC: 0.0051, Test PCC: 0.0034
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 25633.9790, Validation Total Loss: 2993.9267, Test Total Loss: 2998.2690
Training Regression Loss: 3530.4512, Validation Regression Loss: 2993.9267, Test Regression Loss: 2998.2690
Training Classification Loss: 44207.0557
Training RMSE: 2403.1917, Validation RMSE: 1780.6269, Test RMSE: 1784.9056
Training PCC: -0.0036, Validation PCC: 0.0023, Test PCC: 0.0041
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 28068.1108, Validation Total Loss: 3339.0419, Test Total Loss: 3333.1337
Training Regression Loss: 4286.3689, Validation Regression Loss: 3339.0419, Test Regression Loss: 3333.1337
Training Classification Loss: 47563.4838
Training RMSE: 3080.4734, Validation RMSE: 2003.3548, Test RMSE: 2000.3789
Training PCC: -0.0014, Validation PCC: 0.0027, Test PCC: 0.0031
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 31132.8292, Validation Total Loss: 3100.4940, Test Total Loss: 3100.3003
Training Regression Loss: 5448.6379, Validation Regression Loss: 3100.4940, Test Regression Loss: 3100.3003
Training Classification Loss: 51368.3826
Training RMSE: 4087.2151, Validation RMSE: 1872.2533, Test RMSE: 1873.5116
Training PCC: 0.0050, Validation PCC: 0.0030, Test PCC: 0.0004
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 36962.0870, Validation Total Loss: 4323.9315, Test Total Loss: 4329.8131
Training Regression Loss: 4898.9357, Validation Regression Loss: 4323.9315, Test Regression Loss: 4329.8131
Training Classification Loss: 64126.3025
Training RMSE: 3161.5655, Validation RMSE: 2545.2535, Test RMSE: 2544.1391
Training PCC: 0.0042, Validation PCC: 0.0023, Test PCC: 0.0065
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 36919.4825, Validation Total Loss: 12145.8359, Test Total Loss: 12145.4301
Training Regression Loss: 5957.7845, Validation Regression Loss: 12145.8359, Test Regression Loss: 12145.4301
Training Classification Loss: 61923.3960
Training RMSE: 3688.5287, Validation RMSE: 8176.3260, Test RMSE: 8170.6073
Training PCC: 0.0018, Validation PCC: 0.0035, Test PCC: 0.0021
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 46146.2760, Validation Total Loss: 4922.3534, Test Total Loss: 4928.5209
Training Regression Loss: 7997.9955, Validation Regression Loss: 4922.3534, Test Regression Loss: 4928.5209
Training Classification Loss: 76296.5610
Training RMSE: 6332.1748, Validation RMSE: 3030.0878, Test RMSE: 3033.3743
Training PCC: -0.0003, Validation PCC: 0.0012, Test PCC: 0.0197
Checkpoint saved for epoch 15
Stopping early after 15 epochs
Total training time: 1926.26 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.5_lambda_1.25_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.1_lambda_1.5_wl100_ol75_normalizedbysubject with alpha: 0.1
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.3277, Validation Total Loss: 12.6068, Test Total Loss: 14.7869
Training Regression Loss: 19.0722, Validation Regression Loss: 12.6068, Test Regression Loss: 14.7869
Training Classification Loss: 2.5553
Training RMSE: 18.5415, Validation RMSE: 11.9615, Test RMSE: 13.8946
Training PCC: 0.8044, Validation PCC: 0.9374, Test PCC: 0.7255
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.3161, Validation Total Loss: 10.3626, Test Total Loss: 13.9624
Training Regression Loss: 11.0998, Validation Regression Loss: 10.3626, Test Regression Loss: 13.9624
Training Classification Loss: 2.1628
Training RMSE: 10.6278, Validation RMSE: 9.8046, Test RMSE: 13.1514
Training PCC: 0.9504, Validation PCC: 0.9607, Test PCC: 0.7510
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.7079, Validation Total Loss: 8.7016, Test Total Loss: 13.9383
Training Regression Loss: 9.5075, Validation Regression Loss: 8.7016, Test Regression Loss: 13.9383
Training Classification Loss: 2.0047
Training RMSE: 9.0902, Validation RMSE: 8.2457, Test RMSE: 13.1471
Training PCC: 0.9643, Validation PCC: 0.9703, Test PCC: 0.6967
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 8.8932, Validation Total Loss: 8.5006, Test Total Loss: 13.4503
Training Regression Loss: 8.6980, Validation Regression Loss: 8.5006, Test Regression Loss: 13.4503
Training Classification Loss: 1.9522
Training RMSE: 8.2962, Validation RMSE: 7.9288, Test RMSE: 12.6649
Training PCC: 0.9702, Validation PCC: 0.9739, Test PCC: 0.7659
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.1933, Validation Total Loss: 7.9605, Test Total Loss: 13.6666
Training Regression Loss: 7.9979, Validation Regression Loss: 7.9605, Test Regression Loss: 13.6666
Training Classification Loss: 1.9539
Training RMSE: 7.6634, Validation RMSE: 7.4890, Test RMSE: 12.9224
Training PCC: 0.9743, Validation PCC: 0.9759, Test PCC: 0.7627
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 7.9995, Validation Total Loss: 7.6394, Test Total Loss: 15.1714
Training Regression Loss: 7.7973, Validation Regression Loss: 7.6394, Test Regression Loss: 15.1714
Training Classification Loss: 2.0219
Training RMSE: 7.4587, Validation RMSE: 7.2506, Test RMSE: 14.0189
Training PCC: 0.9756, Validation PCC: 0.9762, Test PCC: 0.7397
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.4574, Validation Total Loss: 6.7991, Test Total Loss: 14.7526
Training Regression Loss: 7.2512, Validation Regression Loss: 6.7991, Test Regression Loss: 14.7526
Training Classification Loss: 2.0617
Training RMSE: 6.9427, Validation RMSE: 6.4596, Test RMSE: 13.7601
Training PCC: 0.9784, Validation PCC: 0.9818, Test PCC: 0.7591
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.1194, Validation Total Loss: 7.0292, Test Total Loss: 14.4511
Training Regression Loss: 6.9062, Validation Regression Loss: 7.0292, Test Regression Loss: 14.4511
Training Classification Loss: 2.1319
Training RMSE: 6.6216, Validation RMSE: 6.5961, Test RMSE: 13.2227
Training PCC: 0.9804, Validation PCC: 0.9803, Test PCC: 0.7395
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 6.9595, Validation Total Loss: 6.3674, Test Total Loss: 14.6575
Training Regression Loss: 6.7412, Validation Regression Loss: 6.3674, Test Regression Loss: 14.6575
Training Classification Loss: 2.1829
Training RMSE: 6.4736, Validation RMSE: 6.0601, Test RMSE: 13.6565
Training PCC: 0.9812, Validation PCC: 0.9835, Test PCC: 0.7594
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6.6443, Validation Total Loss: 6.5516, Test Total Loss: 14.1570
Training Regression Loss: 6.4213, Validation Regression Loss: 6.5516, Test Regression Loss: 14.1570
Training Classification Loss: 2.2300
Training RMSE: 6.1871, Validation RMSE: 6.1987, Test RMSE: 13.3073
Training PCC: 0.9829, Validation PCC: 0.9832, Test PCC: 0.7205
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.3586, Validation Total Loss: 6.2458, Test Total Loss: 13.7926
Training Regression Loss: 6.1243, Validation Regression Loss: 6.2458, Test Regression Loss: 13.7926
Training Classification Loss: 2.3427
Training RMSE: 5.9001, Validation RMSE: 5.9457, Test RMSE: 12.5484
Training PCC: 0.9843, Validation PCC: 0.9848, Test PCC: 0.7736
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.2181, Validation Total Loss: 5.8470, Test Total Loss: 15.3697
Training Regression Loss: 5.9797, Validation Regression Loss: 5.8470, Test Regression Loss: 15.3697
Training Classification Loss: 2.3846
Training RMSE: 5.7633, Validation RMSE: 5.5525, Test RMSE: 14.1691
Training PCC: 0.9849, Validation PCC: 0.9859, Test PCC: 0.7722
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.0121, Validation Total Loss: 5.9780, Test Total Loss: 14.2562
Training Regression Loss: 5.7714, Validation Regression Loss: 5.9780, Test Regression Loss: 14.2562
Training Classification Loss: 2.4067
Training RMSE: 5.5830, Validation RMSE: 5.6700, Test RMSE: 13.3441
Training PCC: 0.9858, Validation PCC: 0.9864, Test PCC: 0.7663
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 5.9868, Validation Total Loss: 5.5684, Test Total Loss: 13.7919
Training Regression Loss: 5.7321, Validation Regression Loss: 5.5684, Test Regression Loss: 13.7919
Training Classification Loss: 2.5463
Training RMSE: 5.5393, Validation RMSE: 5.3071, Test RMSE: 12.7564
Training PCC: 0.9862, Validation PCC: 0.9863, Test PCC: 0.7827
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 5.7085, Validation Total Loss: 5.6497, Test Total Loss: 14.0112
Training Regression Loss: 5.4640, Validation Regression Loss: 5.6497, Test Regression Loss: 14.0112
Training Classification Loss: 2.4444
Training RMSE: 5.2969, Validation RMSE: 5.3090, Test RMSE: 13.0016
Training PCC: 0.9873, Validation PCC: 0.9866, Test PCC: 0.7953
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 5.4548, Validation Total Loss: 5.5813, Test Total Loss: 13.4925
Training Regression Loss: 5.2027, Validation Regression Loss: 5.5813, Test Regression Loss: 13.4925
Training Classification Loss: 2.5215
Training RMSE: 5.0521, Validation RMSE: 5.2522, Test RMSE: 12.5398
Training PCC: 0.9883, Validation PCC: 0.9873, Test PCC: 0.8100
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.5707, Validation Total Loss: 5.6947, Test Total Loss: 13.3222
Training Regression Loss: 5.3192, Validation Regression Loss: 5.6947, Test Regression Loss: 13.3222
Training Classification Loss: 2.5148
Training RMSE: 5.1561, Validation RMSE: 5.3361, Test RMSE: 12.3957
Training PCC: 0.9880, Validation PCC: 0.9867, Test PCC: 0.7930
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.3419, Validation Total Loss: 5.6295, Test Total Loss: 14.1630
Training Regression Loss: 5.0892, Validation Regression Loss: 5.6295, Test Regression Loss: 14.1630
Training Classification Loss: 2.5268
Training RMSE: 4.9520, Validation RMSE: 5.3775, Test RMSE: 13.1266
Training PCC: 0.9890, Validation PCC: 0.9878, Test PCC: 0.7881
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.2221, Validation Total Loss: 5.2194, Test Total Loss: 12.8799
Training Regression Loss: 4.9628, Validation Regression Loss: 5.2194, Test Regression Loss: 12.8799
Training Classification Loss: 2.5930
Training RMSE: 4.8266, Validation RMSE: 4.9807, Test RMSE: 11.7555
Training PCC: 0.9896, Validation PCC: 0.9887, Test PCC: 0.8175
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.1209, Validation Total Loss: 5.3602, Test Total Loss: 13.9174
Training Regression Loss: 4.8657, Validation Regression Loss: 5.3602, Test Regression Loss: 13.9174
Training Classification Loss: 2.5520
Training RMSE: 4.7280, Validation RMSE: 5.0511, Test RMSE: 12.7798
Training PCC: 0.9899, Validation PCC: 0.9878, Test PCC: 0.8003
Checkpoint saved for epoch 20
Total training time: 2615.81 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.1_lambda_1.5_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.2_lambda_1.5_wl100_ol75_normalizedbysubject with alpha: 0.2
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.4710, Validation Total Loss: 11.9017, Test Total Loss: 14.1187
Training Regression Loss: 18.9522, Validation Regression Loss: 11.9017, Test Regression Loss: 14.1187
Training Classification Loss: 2.5941
Training RMSE: 18.3940, Validation RMSE: 11.4464, Test RMSE: 13.3963
Training PCC: 0.8073, Validation PCC: 0.9432, Test PCC: 0.7054
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.4679, Validation Total Loss: 9.2587, Test Total Loss: 14.8874
Training Regression Loss: 10.9776, Validation Regression Loss: 9.2587, Test Regression Loss: 14.8874
Training Classification Loss: 2.4516
Training RMSE: 10.5066, Validation RMSE: 8.7930, Test RMSE: 13.7657
Training PCC: 0.9521, Validation PCC: 0.9656, Test PCC: 0.7055
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.9827, Validation Total Loss: 8.1134, Test Total Loss: 13.7034
Training Regression Loss: 9.5089, Validation Regression Loss: 8.1134, Test Regression Loss: 13.7034
Training Classification Loss: 2.3692
Training RMSE: 9.0535, Validation RMSE: 7.7921, Test RMSE: 12.8696
Training PCC: 0.9646, Validation PCC: 0.9739, Test PCC: 0.7388
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.2231, Validation Total Loss: 7.7184, Test Total Loss: 14.0708
Training Regression Loss: 8.7600, Validation Regression Loss: 7.7184, Test Regression Loss: 14.0708
Training Classification Loss: 2.3155
Training RMSE: 8.3637, Validation RMSE: 7.2873, Test RMSE: 13.0195
Training PCC: 0.9701, Validation PCC: 0.9765, Test PCC: 0.7281
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.6112, Validation Total Loss: 7.5449, Test Total Loss: 14.5695
Training Regression Loss: 8.1346, Validation Regression Loss: 7.5449, Test Regression Loss: 14.5695
Training Classification Loss: 2.3831
Training RMSE: 7.7622, Validation RMSE: 7.2881, Test RMSE: 13.3763
Training PCC: 0.9738, Validation PCC: 0.9781, Test PCC: 0.7228
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.3038, Validation Total Loss: 6.8260, Test Total Loss: 14.2599
Training Regression Loss: 7.8063, Validation Regression Loss: 6.8260, Test Regression Loss: 14.2599
Training Classification Loss: 2.4874
Training RMSE: 7.4628, Validation RMSE: 6.5249, Test RMSE: 13.1641
Training PCC: 0.9759, Validation PCC: 0.9799, Test PCC: 0.7156
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.7630, Validation Total Loss: 6.6566, Test Total Loss: 13.5381
Training Regression Loss: 7.2706, Validation Regression Loss: 6.6566, Test Regression Loss: 13.5381
Training Classification Loss: 2.4623
Training RMSE: 6.9437, Validation RMSE: 6.3114, Test RMSE: 12.4861
Training PCC: 0.9783, Validation PCC: 0.9827, Test PCC: 0.7697
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.6212, Validation Total Loss: 6.2938, Test Total Loss: 13.5035
Training Regression Loss: 7.1024, Validation Regression Loss: 6.2938, Test Regression Loss: 13.5035
Training Classification Loss: 2.5940
Training RMSE: 6.8069, Validation RMSE: 6.0277, Test RMSE: 12.2169
Training PCC: 0.9794, Validation PCC: 0.9838, Test PCC: 0.7573
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.3624, Validation Total Loss: 6.2958, Test Total Loss: 13.2791
Training Regression Loss: 6.8435, Validation Regression Loss: 6.2958, Test Regression Loss: 13.2791
Training Classification Loss: 2.5946
Training RMSE: 6.5500, Validation RMSE: 6.0748, Test RMSE: 12.0456
Training PCC: 0.9808, Validation PCC: 0.9840, Test PCC: 0.7602
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.0741, Validation Total Loss: 6.2156, Test Total Loss: 15.1393
Training Regression Loss: 6.5387, Validation Regression Loss: 6.2156, Test Regression Loss: 15.1393
Training Classification Loss: 2.6769
Training RMSE: 6.2604, Validation RMSE: 5.8878, Test RMSE: 13.9405
Training PCC: 0.9825, Validation PCC: 0.9842, Test PCC: 0.7209
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.8637, Validation Total Loss: 5.9067, Test Total Loss: 15.5743
Training Regression Loss: 6.3176, Validation Regression Loss: 5.9067, Test Regression Loss: 15.5743
Training Classification Loss: 2.7307
Training RMSE: 6.0601, Validation RMSE: 5.7029, Test RMSE: 14.2601
Training PCC: 0.9832, Validation PCC: 0.9852, Test PCC: 0.7501
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.8347, Validation Total Loss: 6.0933, Test Total Loss: 14.0036
Training Regression Loss: 6.2542, Validation Regression Loss: 6.0933, Test Regression Loss: 14.0036
Training Classification Loss: 2.9023
Training RMSE: 6.0141, Validation RMSE: 5.8246, Test RMSE: 12.5040
Training PCC: 0.9833, Validation PCC: 0.9849, Test PCC: 0.7681
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.5428, Validation Total Loss: 5.6534, Test Total Loss: 13.8699
Training Regression Loss: 5.9945, Validation Regression Loss: 5.6534, Test Regression Loss: 13.8699
Training Classification Loss: 2.7417
Training RMSE: 5.7644, Validation RMSE: 5.3823, Test RMSE: 12.5151
Training PCC: 0.9851, Validation PCC: 0.9874, Test PCC: 0.7701
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.3015, Validation Total Loss: 5.3676, Test Total Loss: 14.8111
Training Regression Loss: 5.7461, Validation Regression Loss: 5.3676, Test Regression Loss: 14.8111
Training Classification Loss: 2.7770
Training RMSE: 5.5388, Validation RMSE: 5.1727, Test RMSE: 13.4197
Training PCC: 0.9860, Validation PCC: 0.9889, Test PCC: 0.7751
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.2153, Validation Total Loss: 5.5282, Test Total Loss: 13.4870
Training Regression Loss: 5.6535, Validation Regression Loss: 5.5282, Test Regression Loss: 13.4870
Training Classification Loss: 2.8092
Training RMSE: 5.4694, Validation RMSE: 5.3457, Test RMSE: 12.0752
Training PCC: 0.9864, Validation PCC: 0.9866, Test PCC: 0.7796
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 5.9667, Validation Total Loss: 5.3698, Test Total Loss: 13.8053
Training Regression Loss: 5.4074, Validation Regression Loss: 5.3698, Test Regression Loss: 13.8053
Training Classification Loss: 2.7965
Training RMSE: 5.2472, Validation RMSE: 5.1900, Test RMSE: 12.6732
Training PCC: 0.9874, Validation PCC: 0.9894, Test PCC: 0.7786
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.9226, Validation Total Loss: 5.4246, Test Total Loss: 13.8334
Training Regression Loss: 5.3453, Validation Regression Loss: 5.4246, Test Regression Loss: 13.8334
Training Classification Loss: 2.8866
Training RMSE: 5.1839, Validation RMSE: 5.2036, Test RMSE: 12.5329
Training PCC: 0.9880, Validation PCC: 0.9885, Test PCC: 0.7965
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.6832, Validation Total Loss: 5.0637, Test Total Loss: 14.1827
Training Regression Loss: 5.1156, Validation Regression Loss: 5.0637, Test Regression Loss: 14.1827
Training Classification Loss: 2.8380
Training RMSE: 4.9821, Validation RMSE: 4.9148, Test RMSE: 12.7767
Training PCC: 0.9887, Validation PCC: 0.9894, Test PCC: 0.7871
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.6428, Validation Total Loss: 5.0899, Test Total Loss: 13.8483
Training Regression Loss: 5.0622, Validation Regression Loss: 5.0899, Test Regression Loss: 13.8483
Training Classification Loss: 2.9032
Training RMSE: 4.9259, Validation RMSE: 4.9540, Test RMSE: 12.5891
Training PCC: 0.9892, Validation PCC: 0.9893, Test PCC: 0.7987
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.4320, Validation Total Loss: 4.9680, Test Total Loss: 14.4570
Training Regression Loss: 4.8427, Validation Regression Loss: 4.9680, Test Regression Loss: 14.4570
Training Classification Loss: 2.9468
Training RMSE: 4.7305, Validation RMSE: 4.8216, Test RMSE: 13.0954
Training PCC: 0.9899, Validation PCC: 0.9903, Test PCC: 0.8101
Checkpoint saved for epoch 20
Total training time: 2614.50 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.2_lambda_1.5_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.3_lambda_1.5_wl100_ol75_normalizedbysubject with alpha: 0.3
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.0721, Validation Total Loss: 11.3240, Test Total Loss: 15.3617
Training Regression Loss: 19.2674, Validation Regression Loss: 11.3240, Test Regression Loss: 15.3617
Training Classification Loss: 2.6821
Training RMSE: 18.7189, Validation RMSE: 10.9702, Test RMSE: 14.4889
Training PCC: 0.7988, Validation PCC: 0.9461, Test PCC: 0.6666
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 12.1219, Validation Total Loss: 9.5285, Test Total Loss: 15.6856
Training Regression Loss: 11.3372, Validation Regression Loss: 9.5285, Test Regression Loss: 15.6856
Training Classification Loss: 2.6157
Training RMSE: 10.8658, Validation RMSE: 9.1884, Test RMSE: 14.2128
Training PCC: 0.9480, Validation PCC: 0.9645, Test PCC: 0.6901
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.4549, Validation Total Loss: 8.5431, Test Total Loss: 14.9091
Training Regression Loss: 9.7055, Validation Regression Loss: 8.5431, Test Regression Loss: 14.9091
Training Classification Loss: 2.4981
Training RMSE: 9.2973, Validation RMSE: 8.2163, Test RMSE: 13.7757
Training PCC: 0.9623, Validation PCC: 0.9705, Test PCC: 0.6681
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.8142, Validation Total Loss: 7.4035, Test Total Loss: 14.7671
Training Regression Loss: 9.0546, Validation Regression Loss: 7.4035, Test Regression Loss: 14.7671
Training Classification Loss: 2.5322
Training RMSE: 8.6575, Validation RMSE: 7.1081, Test RMSE: 13.6154
Training PCC: 0.9678, Validation PCC: 0.9772, Test PCC: 0.6729
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.0780, Validation Total Loss: 7.3541, Test Total Loss: 15.0795
Training Regression Loss: 8.3007, Validation Regression Loss: 7.3541, Test Regression Loss: 15.0795
Training Classification Loss: 2.5910
Training RMSE: 7.9332, Validation RMSE: 7.0800, Test RMSE: 14.0180
Training PCC: 0.9722, Validation PCC: 0.9783, Test PCC: 0.7116
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.6560, Validation Total Loss: 6.8786, Test Total Loss: 15.1098
Training Regression Loss: 7.8559, Validation Regression Loss: 6.8786, Test Regression Loss: 15.1098
Training Classification Loss: 2.6667
Training RMSE: 7.5168, Validation RMSE: 6.6500, Test RMSE: 13.9926
Training PCC: 0.9752, Validation PCC: 0.9814, Test PCC: 0.7455
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.2873, Validation Total Loss: 6.6546, Test Total Loss: 15.0666
Training Regression Loss: 7.4810, Validation Regression Loss: 6.6546, Test Regression Loss: 15.0666
Training Classification Loss: 2.6877
Training RMSE: 7.1727, Validation RMSE: 6.4180, Test RMSE: 13.9081
Training PCC: 0.9771, Validation PCC: 0.9818, Test PCC: 0.7081
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.8492, Validation Total Loss: 6.4969, Test Total Loss: 15.0928
Training Regression Loss: 7.0557, Validation Regression Loss: 6.4969, Test Regression Loss: 15.0928
Training Classification Loss: 2.6450
Training RMSE: 6.7640, Validation RMSE: 6.2513, Test RMSE: 13.9080
Training PCC: 0.9795, Validation PCC: 0.9833, Test PCC: 0.7408
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.7864, Validation Total Loss: 6.1614, Test Total Loss: 14.6570
Training Regression Loss: 6.9271, Validation Regression Loss: 6.1614, Test Regression Loss: 14.6570
Training Classification Loss: 2.8644
Training RMSE: 6.6364, Validation RMSE: 5.9248, Test RMSE: 13.3094
Training PCC: 0.9804, Validation PCC: 0.9838, Test PCC: 0.7183
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.6426, Validation Total Loss: 6.0562, Test Total Loss: 14.5023
Training Regression Loss: 6.7255, Validation Regression Loss: 6.0562, Test Regression Loss: 14.5023
Training Classification Loss: 3.0570
Training RMSE: 6.4608, Validation RMSE: 5.7851, Test RMSE: 13.0410
Training PCC: 0.9813, Validation PCC: 0.9848, Test PCC: 0.7685
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.5114, Validation Total Loss: 6.3526, Test Total Loss: 14.6808
Training Regression Loss: 6.5867, Validation Regression Loss: 6.3526, Test Regression Loss: 14.6808
Training Classification Loss: 3.0823
Training RMSE: 6.3289, Validation RMSE: 6.2154, Test RMSE: 13.2884
Training PCC: 0.9820, Validation PCC: 0.9845, Test PCC: 0.7409
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 7.4125, Validation Total Loss: 5.7171, Test Total Loss: 15.3795
Training Regression Loss: 6.4094, Validation Regression Loss: 5.7171, Test Regression Loss: 15.3795
Training Classification Loss: 3.3436
Training RMSE: 6.1584, Validation RMSE: 5.4866, Test RMSE: 13.9101
Training PCC: 0.9831, Validation PCC: 0.9865, Test PCC: 0.7420
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 7.3920, Validation Total Loss: 6.3589, Test Total Loss: 16.5947
Training Regression Loss: 6.3000, Validation Regression Loss: 6.3589, Test Regression Loss: 16.5947
Training Classification Loss: 3.6401
Training RMSE: 6.0560, Validation RMSE: 6.1098, Test RMSE: 15.0848
Training PCC: 0.9837, Validation PCC: 0.9859, Test PCC: 0.7473
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 7.5016, Validation Total Loss: 5.7014, Test Total Loss: 15.2154
Training Regression Loss: 6.3227, Validation Regression Loss: 5.7014, Test Regression Loss: 15.2154
Training Classification Loss: 3.9296
Training RMSE: 6.0890, Validation RMSE: 5.4100, Test RMSE: 13.5730
Training PCC: 0.9836, Validation PCC: 0.9871, Test PCC: 0.7537
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 7.8435, Validation Total Loss: 6.0740, Test Total Loss: 14.8090
Training Regression Loss: 6.4580, Validation Regression Loss: 6.0740, Test Regression Loss: 14.8090
Training Classification Loss: 4.6183
Training RMSE: 6.2162, Validation RMSE: 5.8486, Test RMSE: 13.3726
Training PCC: 0.9833, Validation PCC: 0.9855, Test PCC: 0.7348
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 8.4644, Validation Total Loss: 5.8275, Test Total Loss: 14.9795
Training Regression Loss: 6.5429, Validation Regression Loss: 5.8275, Test Regression Loss: 14.9795
Training Classification Loss: 6.4049
Training RMSE: 6.3025, Validation RMSE: 5.6594, Test RMSE: 13.6959
Training PCC: 0.9838, Validation PCC: 0.9869, Test PCC: 0.7500
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 7076.2339, Validation Total Loss: 4022.3116, Test Total Loss: 4016.0474
Training Regression Loss: 1707.1870, Validation Regression Loss: 4022.3116, Test Regression Loss: 4016.0474
Training Classification Loss: 17896.8221
Training RMSE: 1412.2922, Validation RMSE: 3887.9582, Test RMSE: 3881.7728
Training PCC: 0.2924, Validation PCC: 0.0017, Test PCC: 0.0007
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 12886.5267, Validation Total Loss: 2762.4950, Test Total Loss: 2755.9193
Training Regression Loss: 3582.3783, Validation Regression Loss: 2762.4950, Test Regression Loss: 2755.9193
Training Classification Loss: 31013.8267
Training RMSE: 3039.2673, Validation RMSE: 2226.8067, Test RMSE: 2220.5236
Training PCC: 0.0044, Validation PCC: 0.0033, Test PCC: 0.0017
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 18198.7905, Validation Total Loss: 2526.4204, Test Total Loss: 2521.9604
Training Regression Loss: 3914.5457, Validation Regression Loss: 2526.4204, Test Regression Loss: 2521.9604
Training Classification Loss: 47614.1470
Training RMSE: 2752.3065, Validation RMSE: 2118.5471, Test RMSE: 2112.5659
Training PCC: -0.0013, Validation PCC: 0.0019, Test PCC: 0.0021
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 22276.7817, Validation Total Loss: 6126.0127, Test Total Loss: 6132.1291
Training Regression Loss: 5979.3371, Validation Regression Loss: 6126.0127, Test Regression Loss: 6132.1291
Training Classification Loss: 54324.8129
Training RMSE: 4159.1485, Validation RMSE: 4741.8506, Test RMSE: 4747.6094
Training PCC: -0.0051, Validation PCC: 0.0004, Test PCC: 0.0022
Checkpoint saved for epoch 20
Total training time: 2562.83 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.3_lambda_1.5_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.4_lambda_1.5_wl100_ol75_normalizedbysubject with alpha: 0.4
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.4950, Validation Total Loss: 12.0203, Test Total Loss: 14.9761
Training Regression Loss: 19.3858, Validation Regression Loss: 12.0203, Test Regression Loss: 14.9761
Training Classification Loss: 2.7730
Training RMSE: 18.8383, Validation RMSE: 11.4166, Test RMSE: 13.9489
Training PCC: 0.7933, Validation PCC: 0.9415, Test PCC: 0.7398
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 12.2560, Validation Total Loss: 9.7731, Test Total Loss: 15.9779
Training Regression Loss: 11.0801, Validation Regression Loss: 9.7731, Test Regression Loss: 15.9779
Training Classification Loss: 2.9397
Training RMSE: 10.6354, Validation RMSE: 9.3248, Test RMSE: 14.6742
Training PCC: 0.9504, Validation PCC: 0.9593, Test PCC: 0.6939
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.7732, Validation Total Loss: 8.8144, Test Total Loss: 14.3344
Training Regression Loss: 9.6818, Validation Regression Loss: 8.8144, Test Regression Loss: 14.3344
Training Classification Loss: 2.7286
Training RMSE: 9.2549, Validation RMSE: 8.3273, Test RMSE: 13.2729
Training PCC: 0.9632, Validation PCC: 0.9698, Test PCC: 0.7269
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.9720, Validation Total Loss: 7.9422, Test Total Loss: 13.7974
Training Regression Loss: 8.8318, Validation Regression Loss: 7.9422, Test Regression Loss: 13.7974
Training Classification Loss: 2.8504
Training RMSE: 8.4483, Validation RMSE: 7.6248, Test RMSE: 12.6676
Training PCC: 0.9689, Validation PCC: 0.9747, Test PCC: 0.7579
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.5713, Validation Total Loss: 8.3446, Test Total Loss: 12.9683
Training Regression Loss: 8.3275, Validation Regression Loss: 8.3446, Test Regression Loss: 12.9683
Training Classification Loss: 3.1093
Training RMSE: 7.9978, Validation RMSE: 7.7323, Test RMSE: 12.1149
Training PCC: 0.9719, Validation PCC: 0.9775, Test PCC: 0.7453
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 9.7474, Validation Total Loss: 7.8167, Test Total Loss: 14.7733
Training Regression Loss: 8.1032, Validation Regression Loss: 7.8167, Test Regression Loss: 14.7733
Training Classification Loss: 4.1104
Training RMSE: 7.7809, Validation RMSE: 7.4968, Test RMSE: 13.7252
Training PCC: 0.9741, Validation PCC: 0.9721, Test PCC: 0.7159
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 4844.0760, Validation Total Loss: 848.9419, Test Total Loss: 848.8275
Training Regression Loss: 1055.3811, Validation Regression Loss: 848.9419, Test Regression Loss: 848.8275
Training Classification Loss: 9471.7372
Training RMSE: 927.9055, Validation RMSE: 772.4663, Test RMSE: 771.4943
Training PCC: 0.3553, Validation PCC: 0.0033, Test PCC: 0.0035
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 14372.6945, Validation Total Loss: 5762.9440, Test Total Loss: 5764.0171
Training Regression Loss: 3249.2225, Validation Regression Loss: 5762.9440, Test Regression Loss: 5764.0171
Training Classification Loss: 27808.6795
Training RMSE: 2668.7426, Validation RMSE: 4994.4568, Test RMSE: 4993.3375
Training PCC: -0.0069, Validation PCC: 0.0011, Test PCC: 0.0029
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 19158.8956, Validation Total Loss: 3683.8494, Test Total Loss: 3688.6669
Training Regression Loss: 5248.8693, Validation Regression Loss: 3683.8494, Test Regression Loss: 3688.6669
Training Classification Loss: 34775.0651
Training RMSE: 4198.2671, Validation RMSE: 3045.5055, Test RMSE: 3046.4610
Training PCC: 0.0048, Validation PCC: 0.0004, Test PCC: 0.0026
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 21069.2098, Validation Total Loss: 2929.3781, Test Total Loss: 2932.8131
Training Regression Loss: 3982.2167, Validation Regression Loss: 2929.3781, Test Regression Loss: 2932.8131
Training Classification Loss: 42717.4818
Training RMSE: 2802.1185, Validation RMSE: 1916.8122, Test RMSE: 1914.9073
Training PCC: 0.0024, Validation PCC: 0.0002, Test PCC: 0.0027
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 23735.8469, Validation Total Loss: 3292.7013, Test Total Loss: 3296.7795
Training Regression Loss: 4795.1709, Validation Regression Loss: 3292.7013, Test Regression Loss: 3296.7795
Training Classification Loss: 47351.6894
Training RMSE: 3774.6567, Validation RMSE: 2693.3517, Test RMSE: 2696.4311
Training PCC: 0.0044, Validation PCC: -0.0018, Test PCC: 0.0021
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 28461.9907, Validation Total Loss: 4605.7294, Test Total Loss: 4605.0780
Training Regression Loss: 4819.8920, Validation Regression Loss: 4605.7294, Test Regression Loss: 4605.0780
Training Classification Loss: 59105.2458
Training RMSE: 3514.6200, Validation RMSE: 2875.2673, Test RMSE: 2869.1440
Training PCC: 0.0033, Validation PCC: -0.0010, Test PCC: 0.0046
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 32597.1148, Validation Total Loss: 4660.4146, Test Total Loss: 4654.9689
Training Regression Loss: 5314.8299, Validation Regression Loss: 4660.4146, Test Regression Loss: 4654.9689
Training Classification Loss: 68205.7112
Training RMSE: 3564.3801, Validation RMSE: 2866.4578, Test RMSE: 2861.8029
Training PCC: 0.0038, Validation PCC: 0.0001, Test PCC: 0.0018
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 39140.1190, Validation Total Loss: 6834.9973, Test Total Loss: 6834.3636
Training Regression Loss: 8009.9024, Validation Regression Loss: 6834.9973, Test Regression Loss: 6834.3636
Training Classification Loss: 77825.5400
Training RMSE: 6479.3456, Validation RMSE: 4230.5966, Test RMSE: 4224.5341
Training PCC: 0.0028, Validation PCC: -0.0018, Test PCC: 0.0047
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 40012.3533, Validation Total Loss: 6800.2798, Test Total Loss: 6796.1814
Training Regression Loss: 8156.3623, Validation Regression Loss: 6800.2798, Test Regression Loss: 6796.1814
Training Classification Loss: 79639.9759
Training RMSE: 5862.5714, Validation RMSE: 4102.4627, Test RMSE: 4102.5678
Training PCC: 0.0021, Validation PCC: -0.0033, Test PCC: 0.0023
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 44815.3518, Validation Total Loss: 11934.9405, Test Total Loss: 11929.2279
Training Regression Loss: 9338.6782, Validation Regression Loss: 11934.9405, Test Regression Loss: 11929.2279
Training Classification Loss: 88691.6823
Training RMSE: 5960.5476, Validation RMSE: 7027.4271, Test RMSE: 7021.2186
Training PCC: 0.0040, Validation PCC: -0.0023, Test PCC: 0.0057
Checkpoint saved for epoch 16
Stopping early after 16 epochs
Total training time: 2035.41 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.4_lambda_1.5_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.5_lambda_1.5_wl100_ol75_normalizedbysubject with alpha: 0.5
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 21.0193, Validation Total Loss: 12.3369, Test Total Loss: 13.4088
Training Regression Loss: 19.5779, Validation Regression Loss: 12.3369, Test Regression Loss: 13.4088
Training Classification Loss: 2.8826
Training RMSE: 19.0196, Validation RMSE: 11.7717, Test RMSE: 12.7080
Training PCC: 0.7913, Validation PCC: 0.9366, Test PCC: 0.7330
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 13.1172, Validation Total Loss: 10.1183, Test Total Loss: 14.1371
Training Regression Loss: 11.4298, Validation Regression Loss: 10.1183, Test Regression Loss: 14.1371
Training Classification Loss: 3.3747
Training RMSE: 10.9911, Validation RMSE: 9.6781, Test RMSE: 13.0657
Training PCC: 0.9465, Validation PCC: 0.9573, Test PCC: 0.7436
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 11.8670, Validation Total Loss: 9.1034, Test Total Loss: 13.6224
Training Regression Loss: 10.1341, Validation Regression Loss: 9.1034, Test Regression Loss: 13.6224
Training Classification Loss: 3.4658
Training RMSE: 9.7290, Validation RMSE: 8.5923, Test RMSE: 12.8270
Training PCC: 0.9593, Validation PCC: 0.9696, Test PCC: 0.7483
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 24.1471, Validation Total Loss: 45.1313, Test Total Loss: 38.9954
Training Regression Loss: 12.5930, Validation Regression Loss: 45.1313, Test Regression Loss: 38.9954
Training Classification Loss: 23.1082
Training RMSE: 11.9107, Validation RMSE: 41.6333, Test RMSE: 34.7711
Training PCC: 0.9559, Validation PCC: 0.7751, Test PCC: 0.5196
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 10528.3354, Validation Total Loss: 2047.4802, Test Total Loss: 2042.7762
Training Regression Loss: 2042.6047, Validation Regression Loss: 2047.4802, Test Regression Loss: 2042.7762
Training Classification Loss: 16971.4614
Training RMSE: 1681.3965, Validation RMSE: 1564.4193, Test RMSE: 1563.5764
Training PCC: 0.0230, Validation PCC: 0.0035, Test PCC: -0.0000
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 18490.1572, Validation Total Loss: 1352.1441, Test Total Loss: 1346.9362
Training Regression Loss: 3521.2372, Validation Regression Loss: 1352.1441, Test Regression Loss: 1346.9362
Training Classification Loss: 29937.8401
Training RMSE: 2614.0558, Validation RMSE: 1144.1275, Test RMSE: 1138.3675
Training PCC: -0.0090, Validation PCC: 0.0015, Test PCC: 0.0028
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 23784.2480, Validation Total Loss: 2015.7301, Test Total Loss: 2009.2248
Training Regression Loss: 5207.0619, Validation Regression Loss: 2015.7301, Test Regression Loss: 2009.2248
Training Classification Loss: 37154.3723
Training RMSE: 4284.3518, Validation RMSE: 1914.5708, Test RMSE: 1908.9037
Training PCC: -0.0010, Validation PCC: 0.0018, Test PCC: 0.0039
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 29789.3822, Validation Total Loss: 3377.9274, Test Total Loss: 3377.7914
Training Regression Loss: 5678.0485, Validation Regression Loss: 3377.9274, Test Regression Loss: 3377.7914
Training Classification Loss: 48222.6676
Training RMSE: 4203.2473, Validation RMSE: 2010.6526, Test RMSE: 2008.8679
Training PCC: -0.0005, Validation PCC: 0.0006, Test PCC: 0.0066
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 34654.6574, Validation Total Loss: 5720.7549, Test Total Loss: 5720.7204
Training Regression Loss: 6929.6217, Validation Regression Loss: 5720.7549, Test Regression Loss: 5720.7204
Training Classification Loss: 55450.0715
Training RMSE: 4830.1597, Validation RMSE: 3360.9449, Test RMSE: 3359.8906
Training PCC: -0.0003, Validation PCC: 0.0094, Test PCC: 0.0122
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 39117.7364, Validation Total Loss: 3513.7754, Test Total Loss: 3518.8185
Training Regression Loss: 6136.4301, Validation Regression Loss: 3513.7754, Test Regression Loss: 3518.8185
Training Classification Loss: 65962.6125
Training RMSE: 4196.0871, Validation RMSE: 2166.7882, Test RMSE: 2165.1120
Training PCC: -0.0033, Validation PCC: 0.0010, Test PCC: 0.0031
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 50257.7453, Validation Total Loss: 5762.7070, Test Total Loss: 5758.6162
Training Regression Loss: 8431.2769, Validation Regression Loss: 5762.7070, Test Regression Loss: 5758.6162
Training Classification Loss: 83652.9366
Training RMSE: 6780.6216, Validation RMSE: 3431.2288, Test RMSE: 3431.1184
Training PCC: 0.0053, Validation PCC: -0.0005, Test PCC: 0.0031
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 58502.7710, Validation Total Loss: 6845.3063, Test Total Loss: 6839.9994
Training Regression Loss: 9272.8388, Validation Regression Loss: 6845.3063, Test Regression Loss: 6839.9994
Training Classification Loss: 98459.8642
Training RMSE: 6897.2611, Validation RMSE: 4831.1415, Test RMSE: 4825.4713
Training PCC: -0.0011, Validation PCC: 0.0038, Test PCC: 0.0010
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 59717.9580, Validation Total Loss: 9163.2345, Test Total Loss: 9163.2053
Training Regression Loss: 9777.2268, Validation Regression Loss: 9163.2345, Test Regression Loss: 9163.2053
Training Classification Loss: 99881.4626
Training RMSE: 6954.1765, Validation RMSE: 5458.4107, Test RMSE: 5457.5391
Training PCC: -0.0009, Validation PCC: -0.0055, Test PCC: 0.0064
Checkpoint saved for epoch 13
Stopping early after 13 epochs
Total training time: 1663.62 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.5_lambda_1.5_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.1_lambda_2_wl100_ol75_normalizedbysubject with alpha: 0.1
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.5403, Validation Total Loss: 11.6397, Test Total Loss: 14.8073
Training Regression Loss: 19.2857, Validation Regression Loss: 11.6397, Test Regression Loss: 14.8073
Training Classification Loss: 2.5459
Training RMSE: 18.7207, Validation RMSE: 11.1932, Test RMSE: 13.8089
Training PCC: 0.7983, Validation PCC: 0.9402, Test PCC: 0.7498
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.2835, Validation Total Loss: 9.4023, Test Total Loss: 13.5738
Training Regression Loss: 11.0586, Validation Regression Loss: 9.4023, Test Regression Loss: 13.5738
Training Classification Loss: 2.2485
Training RMSE: 10.6113, Validation RMSE: 9.0038, Test RMSE: 12.6926
Training PCC: 0.9504, Validation PCC: 0.9623, Test PCC: 0.7658
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.7397, Validation Total Loss: 9.2995, Test Total Loss: 15.4511
Training Regression Loss: 9.5236, Validation Regression Loss: 9.2995, Test Regression Loss: 15.4511
Training Classification Loss: 2.1605
Training RMSE: 9.1041, Validation RMSE: 8.9781, Test RMSE: 14.7308
Training PCC: 0.9640, Validation PCC: 0.9707, Test PCC: 0.7182
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 8.9504, Validation Total Loss: 7.9745, Test Total Loss: 14.7565
Training Regression Loss: 8.7371, Validation Regression Loss: 7.9745, Test Regression Loss: 14.7565
Training Classification Loss: 2.1330
Training RMSE: 8.3367, Validation RMSE: 7.6948, Test RMSE: 13.4425
Training PCC: 0.9701, Validation PCC: 0.9747, Test PCC: 0.6937
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.4762, Validation Total Loss: 7.5256, Test Total Loss: 14.0284
Training Regression Loss: 8.2627, Validation Regression Loss: 7.5256, Test Regression Loss: 14.0284
Training Classification Loss: 2.1343
Training RMSE: 7.8925, Validation RMSE: 7.1758, Test RMSE: 12.7234
Training PCC: 0.9727, Validation PCC: 0.9776, Test PCC: 0.7773
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.0039, Validation Total Loss: 7.0871, Test Total Loss: 15.0689
Training Regression Loss: 7.7805, Validation Regression Loss: 7.0871, Test Regression Loss: 15.0689
Training Classification Loss: 2.2347
Training RMSE: 7.4315, Validation RMSE: 6.7572, Test RMSE: 13.6823
Training PCC: 0.9754, Validation PCC: 0.9805, Test PCC: 0.7767
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.6588, Validation Total Loss: 6.3127, Test Total Loss: 14.4394
Training Regression Loss: 7.4304, Validation Regression Loss: 6.3127, Test Regression Loss: 14.4394
Training Classification Loss: 2.2841
Training RMSE: 7.0934, Validation RMSE: 6.0704, Test RMSE: 13.0085
Training PCC: 0.9779, Validation PCC: 0.9828, Test PCC: 0.7651
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.2079, Validation Total Loss: 6.8479, Test Total Loss: 13.9699
Training Regression Loss: 6.9715, Validation Regression Loss: 6.8479, Test Regression Loss: 13.9699
Training Classification Loss: 2.3636
Training RMSE: 6.6575, Validation RMSE: 6.4137, Test RMSE: 12.7725
Training PCC: 0.9801, Validation PCC: 0.9841, Test PCC: 0.7787
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 6.9168, Validation Total Loss: 5.8514, Test Total Loss: 15.0444
Training Regression Loss: 6.6683, Validation Regression Loss: 5.8514, Test Regression Loss: 15.0444
Training Classification Loss: 2.4853
Training RMSE: 6.3892, Validation RMSE: 5.6387, Test RMSE: 13.3937
Training PCC: 0.9816, Validation PCC: 0.9846, Test PCC: 0.7591
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6.6836, Validation Total Loss: 5.7491, Test Total Loss: 13.9050
Training Regression Loss: 6.4361, Validation Regression Loss: 5.7491, Test Regression Loss: 13.9050
Training Classification Loss: 2.4750
Training RMSE: 6.1858, Validation RMSE: 5.5182, Test RMSE: 12.3533
Training PCC: 0.9827, Validation PCC: 0.9859, Test PCC: 0.7869
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.5104, Validation Total Loss: 5.3372, Test Total Loss: 14.1650
Training Regression Loss: 6.2518, Validation Regression Loss: 5.3372, Test Regression Loss: 14.1650
Training Classification Loss: 2.5859
Training RMSE: 6.0140, Validation RMSE: 5.1620, Test RMSE: 12.7925
Training PCC: 0.9838, Validation PCC: 0.9881, Test PCC: 0.8037
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.2191, Validation Total Loss: 5.4103, Test Total Loss: 12.9459
Training Regression Loss: 5.9687, Validation Regression Loss: 5.4103, Test Regression Loss: 12.9459
Training Classification Loss: 2.5048
Training RMSE: 5.7583, Validation RMSE: 5.2184, Test RMSE: 11.6510
Training PCC: 0.9850, Validation PCC: 0.9873, Test PCC: 0.7943
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.0249, Validation Total Loss: 5.3288, Test Total Loss: 13.6617
Training Regression Loss: 5.7669, Validation Regression Loss: 5.3288, Test Regression Loss: 13.6617
Training Classification Loss: 2.5799
Training RMSE: 5.5625, Validation RMSE: 5.1196, Test RMSE: 12.3675
Training PCC: 0.9860, Validation PCC: 0.9886, Test PCC: 0.8049
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 5.8655, Validation Total Loss: 5.0774, Test Total Loss: 13.6494
Training Regression Loss: 5.6018, Validation Regression Loss: 5.0774, Test Regression Loss: 13.6494
Training Classification Loss: 2.6368
Training RMSE: 5.4209, Validation RMSE: 4.9128, Test RMSE: 12.3955
Training PCC: 0.9867, Validation PCC: 0.9890, Test PCC: 0.7923
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 5.6951, Validation Total Loss: 5.1594, Test Total Loss: 14.2093
Training Regression Loss: 5.4258, Validation Regression Loss: 5.1594, Test Regression Loss: 14.2093
Training Classification Loss: 2.6924
Training RMSE: 5.2555, Validation RMSE: 4.9374, Test RMSE: 12.6182
Training PCC: 0.9877, Validation PCC: 0.9888, Test PCC: 0.7867
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 5.6042, Validation Total Loss: 4.9787, Test Total Loss: 13.2208
Training Regression Loss: 5.3425, Validation Regression Loss: 4.9787, Test Regression Loss: 13.2208
Training Classification Loss: 2.6179
Training RMSE: 5.1770, Validation RMSE: 4.7902, Test RMSE: 12.0750
Training PCC: 0.9880, Validation PCC: 0.9901, Test PCC: 0.8081
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.4717, Validation Total Loss: 4.6404, Test Total Loss: 13.1503
Training Regression Loss: 5.2011, Validation Regression Loss: 4.6404, Test Regression Loss: 13.1503
Training Classification Loss: 2.7053
Training RMSE: 5.0488, Validation RMSE: 4.5422, Test RMSE: 11.8356
Training PCC: 0.9885, Validation PCC: 0.9901, Test PCC: 0.8058
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.3576, Validation Total Loss: 4.6475, Test Total Loss: 13.0346
Training Regression Loss: 5.0970, Validation Regression Loss: 4.6475, Test Regression Loss: 13.0346
Training Classification Loss: 2.6058
Training RMSE: 4.9546, Validation RMSE: 4.4664, Test RMSE: 11.7515
Training PCC: 0.9889, Validation PCC: 0.9913, Test PCC: 0.8099
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.2432, Validation Total Loss: 4.7201, Test Total Loss: 13.7307
Training Regression Loss: 4.9733, Validation Regression Loss: 4.7201, Test Regression Loss: 13.7307
Training Classification Loss: 2.6995
Training RMSE: 4.8351, Validation RMSE: 4.5857, Test RMSE: 12.1886
Training PCC: 0.9894, Validation PCC: 0.9907, Test PCC: 0.8051
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.0732, Validation Total Loss: 4.7600, Test Total Loss: 13.4868
Training Regression Loss: 4.8180, Validation Regression Loss: 4.7600, Test Regression Loss: 13.4868
Training Classification Loss: 2.5519
Training RMSE: 4.6944, Validation RMSE: 4.5906, Test RMSE: 11.8836
Training PCC: 0.9901, Validation PCC: 0.9911, Test PCC: 0.8090
Checkpoint saved for epoch 20
Total training time: 2513.49 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.1_lambda_2_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.2_lambda_2_wl100_ol75_normalizedbysubject with alpha: 0.2
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.6464, Validation Total Loss: 11.8558, Test Total Loss: 14.4115
Training Regression Loss: 19.1245, Validation Regression Loss: 11.8558, Test Regression Loss: 14.4115
Training Classification Loss: 2.6093
Training RMSE: 18.5885, Validation RMSE: 11.2957, Test RMSE: 13.6367
Training PCC: 0.8008, Validation PCC: 0.9420, Test PCC: 0.6976
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.5837, Validation Total Loss: 10.2635, Test Total Loss: 15.2338
Training Regression Loss: 11.0849, Validation Regression Loss: 10.2635, Test Regression Loss: 15.2338
Training Classification Loss: 2.4942
Training RMSE: 10.6382, Validation RMSE: 9.7840, Test RMSE: 14.3702
Training PCC: 0.9497, Validation PCC: 0.9594, Test PCC: 0.7529
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.0656, Validation Total Loss: 8.8002, Test Total Loss: 14.2674
Training Regression Loss: 9.5702, Validation Regression Loss: 8.8002, Test Regression Loss: 14.2674
Training Classification Loss: 2.4769
Training RMSE: 9.1573, Validation RMSE: 8.3211, Test RMSE: 13.4099
Training PCC: 0.9636, Validation PCC: 0.9684, Test PCC: 0.6868
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.3077, Validation Total Loss: 8.3253, Test Total Loss: 15.0957
Training Regression Loss: 8.8193, Validation Regression Loss: 8.3253, Test Regression Loss: 15.0957
Training Classification Loss: 2.4418
Training RMSE: 8.4321, Validation RMSE: 7.8589, Test RMSE: 13.8404
Training PCC: 0.9690, Validation PCC: 0.9733, Test PCC: 0.7693
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.7076, Validation Total Loss: 7.5626, Test Total Loss: 13.8127
Training Regression Loss: 8.1903, Validation Regression Loss: 7.5626, Test Regression Loss: 13.8127
Training Classification Loss: 2.5867
Training RMSE: 7.8208, Validation RMSE: 7.1474, Test RMSE: 12.6500
Training PCC: 0.9731, Validation PCC: 0.9780, Test PCC: 0.7559
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.3225, Validation Total Loss: 7.3381, Test Total Loss: 13.7817
Training Regression Loss: 7.7758, Validation Regression Loss: 7.3381, Test Regression Loss: 13.7817
Training Classification Loss: 2.7333
Training RMSE: 7.4400, Validation RMSE: 6.8786, Test RMSE: 12.5960
Training PCC: 0.9756, Validation PCC: 0.9786, Test PCC: 0.7714
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.0890, Validation Total Loss: 6.6377, Test Total Loss: 14.8553
Training Regression Loss: 7.5484, Validation Regression Loss: 6.6377, Test Regression Loss: 14.8553
Training Classification Loss: 2.7031
Training RMSE: 7.1973, Validation RMSE: 6.3238, Test RMSE: 13.6746
Training PCC: 0.9772, Validation PCC: 0.9815, Test PCC: 0.7666
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.6382, Validation Total Loss: 7.1066, Test Total Loss: 13.3142
Training Regression Loss: 7.0861, Validation Regression Loss: 7.1066, Test Regression Loss: 13.3142
Training Classification Loss: 2.7605
Training RMSE: 6.7925, Validation RMSE: 6.7255, Test RMSE: 12.1240
Training PCC: 0.9796, Validation PCC: 0.9804, Test PCC: 0.7525
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.3833, Validation Total Loss: 7.1315, Test Total Loss: 15.5570
Training Regression Loss: 6.8299, Validation Regression Loss: 7.1315, Test Regression Loss: 15.5570
Training Classification Loss: 2.7670
Training RMSE: 6.5470, Validation RMSE: 6.7421, Test RMSE: 14.2172
Training PCC: 0.9809, Validation PCC: 0.9828, Test PCC: 0.7709
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.3109, Validation Total Loss: 6.6299, Test Total Loss: 12.9146
Training Regression Loss: 6.7146, Validation Regression Loss: 6.6299, Test Regression Loss: 12.9146
Training Classification Loss: 2.9813
Training RMSE: 6.4331, Validation RMSE: 6.2198, Test RMSE: 12.0727
Training PCC: 0.9814, Validation PCC: 0.9839, Test PCC: 0.7690
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.1970, Validation Total Loss: 6.0234, Test Total Loss: 13.5227
Training Regression Loss: 6.5844, Validation Regression Loss: 6.0234, Test Regression Loss: 13.5227
Training Classification Loss: 3.0633
Training RMSE: 6.3230, Validation RMSE: 5.7279, Test RMSE: 12.3713
Training PCC: 0.9823, Validation PCC: 0.9850, Test PCC: 0.7652
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.8365, Validation Total Loss: 6.2101, Test Total Loss: 14.7773
Training Regression Loss: 6.2341, Validation Regression Loss: 6.2101, Test Regression Loss: 14.7773
Training Classification Loss: 3.0121
Training RMSE: 6.0006, Validation RMSE: 5.9927, Test RMSE: 13.7922
Training PCC: 0.9839, Validation PCC: 0.9855, Test PCC: 0.7581
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.7169, Validation Total Loss: 6.2086, Test Total Loss: 14.4481
Training Regression Loss: 6.1058, Validation Regression Loss: 6.2086, Test Regression Loss: 14.4481
Training Classification Loss: 3.0557
Training RMSE: 5.8942, Validation RMSE: 6.0397, Test RMSE: 13.1744
Training PCC: 0.9845, Validation PCC: 0.9861, Test PCC: 0.7767
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.5715, Validation Total Loss: 5.5921, Test Total Loss: 14.1619
Training Regression Loss: 5.9423, Validation Regression Loss: 5.5921, Test Regression Loss: 14.1619
Training Classification Loss: 3.1461
Training RMSE: 5.7453, Validation RMSE: 5.3193, Test RMSE: 12.9388
Training PCC: 0.9853, Validation PCC: 0.9873, Test PCC: 0.7798
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.4769, Validation Total Loss: 5.3141, Test Total Loss: 13.4622
Training Regression Loss: 5.8431, Validation Regression Loss: 5.3141, Test Regression Loss: 13.4622
Training Classification Loss: 3.1688
Training RMSE: 5.6543, Validation RMSE: 5.0888, Test RMSE: 12.2464
Training PCC: 0.9860, Validation PCC: 0.9883, Test PCC: 0.7963
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.3011, Validation Total Loss: 5.7222, Test Total Loss: 13.2059
Training Regression Loss: 5.6414, Validation Regression Loss: 5.7222, Test Regression Loss: 13.2059
Training Classification Loss: 3.2986
Training RMSE: 5.4634, Validation RMSE: 5.4506, Test RMSE: 12.1750
Training PCC: 0.9868, Validation PCC: 0.9878, Test PCC: 0.7963
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 6.2435, Validation Total Loss: 5.1061, Test Total Loss: 14.3539
Training Regression Loss: 5.5465, Validation Regression Loss: 5.1061, Test Regression Loss: 14.3539
Training Classification Loss: 3.4853
Training RMSE: 5.3743, Validation RMSE: 4.9474, Test RMSE: 13.3252
Training PCC: 0.9874, Validation PCC: 0.9899, Test PCC: 0.7672
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 6.1118, Validation Total Loss: 5.2474, Test Total Loss: 15.3598
Training Regression Loss: 5.3635, Validation Regression Loss: 5.2474, Test Regression Loss: 15.3598
Training Classification Loss: 3.7417
Training RMSE: 5.2312, Validation RMSE: 5.0795, Test RMSE: 13.9157
Training PCC: 0.9880, Validation PCC: 0.9885, Test PCC: 0.7792
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 6.3614, Validation Total Loss: 5.1778, Test Total Loss: 13.4956
Training Regression Loss: 5.4432, Validation Regression Loss: 5.1778, Test Regression Loss: 13.4956
Training Classification Loss: 4.5909
Training RMSE: 5.2902, Validation RMSE: 5.0682, Test RMSE: 12.3596
Training PCC: 0.9882, Validation PCC: 0.9871, Test PCC: 0.7763
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 8.1309, Validation Total Loss: 54.4659, Test Total Loss: 57.1063
Training Regression Loss: 6.0171, Validation Regression Loss: 54.4659, Test Regression Loss: 57.1063
Training Classification Loss: 10.5690
Training RMSE: 5.7985, Validation RMSE: 42.2983, Test RMSE: 42.9189
Training PCC: 0.9875, Validation PCC: 0.9539, Test PCC: 0.6840
Checkpoint saved for epoch 20
Total training time: 2516.00 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.2_lambda_2_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.3_lambda_2_wl100_ol75_normalizedbysubject with alpha: 0.3
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.9042, Validation Total Loss: 11.5280, Test Total Loss: 14.7853
Training Regression Loss: 19.0638, Validation Regression Loss: 11.5280, Test Regression Loss: 14.7853
Training Classification Loss: 2.8012
Training RMSE: 18.5194, Validation RMSE: 11.0392, Test RMSE: 14.0160
Training PCC: 0.8051, Validation PCC: 0.9433, Test PCC: 0.6643
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 12.1685, Validation Total Loss: 9.2958, Test Total Loss: 14.7131
Training Regression Loss: 11.3000, Validation Regression Loss: 9.2958, Test Regression Loss: 14.7131
Training Classification Loss: 2.8947
Training RMSE: 10.8335, Validation RMSE: 8.9746, Test RMSE: 13.7014
Training PCC: 0.9491, Validation PCC: 0.9610, Test PCC: 0.7512
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.6522, Validation Total Loss: 8.8635, Test Total Loss: 16.2583
Training Regression Loss: 9.8075, Validation Regression Loss: 8.8635, Test Regression Loss: 16.2583
Training Classification Loss: 2.8156
Training RMSE: 9.3868, Validation RMSE: 8.4472, Test RMSE: 14.7514
Training PCC: 0.9617, Validation PCC: 0.9727, Test PCC: 0.7317
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.8464, Validation Total Loss: 7.7094, Test Total Loss: 13.9981
Training Regression Loss: 9.0065, Validation Regression Loss: 7.7094, Test Regression Loss: 13.9981
Training Classification Loss: 2.7997
Training RMSE: 8.6233, Validation RMSE: 7.3887, Test RMSE: 12.9303
Training PCC: 0.9678, Validation PCC: 0.9754, Test PCC: 0.7598
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.4523, Validation Total Loss: 7.4745, Test Total Loss: 13.5427
Training Regression Loss: 8.4955, Validation Regression Loss: 7.4745, Test Regression Loss: 13.5427
Training Classification Loss: 3.1894
Training RMSE: 8.1259, Validation RMSE: 7.0327, Test RMSE: 12.5048
Training PCC: 0.9716, Validation PCC: 0.9794, Test PCC: 0.7554
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 9.0992, Validation Total Loss: 7.9429, Test Total Loss: 15.6662
Training Regression Loss: 8.0077, Validation Regression Loss: 7.9429, Test Regression Loss: 15.6662
Training Classification Loss: 3.6381
Training RMSE: 7.6657, Validation RMSE: 7.3537, Test RMSE: 14.0657
Training PCC: 0.9747, Validation PCC: 0.9811, Test PCC: 0.7389
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 346.0851, Validation Total Loss: 1275.8546, Test Total Loss: 1281.9416
Training Regression Loss: 92.9707, Validation Regression Loss: 1275.8546, Test Regression Loss: 1281.9416
Training Classification Loss: 843.7146
Training RMSE: 86.0882, Validation RMSE: 1175.0031, Test RMSE: 1180.0421
Training PCC: 0.8807, Validation PCC: 0.1763, Test PCC: 0.1123
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 9944.2264, Validation Total Loss: 2264.5592, Test Total Loss: 2263.7680
Training Regression Loss: 2956.9377, Validation Regression Loss: 2264.5592, Test Regression Loss: 2263.7680
Training Classification Loss: 23290.9616
Training RMSE: 2475.3972, Validation RMSE: 1839.1221, Test RMSE: 1838.4966
Training PCC: 0.0016, Validation PCC: 0.0029, Test PCC: 0.0062
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 15283.2995, Validation Total Loss: 2725.0022, Test Total Loss: 2720.0757
Training Regression Loss: 4374.4155, Validation Regression Loss: 2725.0022, Test Regression Loss: 2720.0757
Training Classification Loss: 36362.9451
Training RMSE: 3570.7660, Validation RMSE: 2034.5187, Test RMSE: 2028.9636
Training PCC: -0.0046, Validation PCC: 0.0008, Test PCC: 0.0013
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 18118.6122, Validation Total Loss: 7488.4935, Test Total Loss: 7493.2917
Training Regression Loss: 4668.5448, Validation Regression Loss: 7488.4935, Test Regression Loss: 7493.2917
Training Classification Loss: 44833.5564
Training RMSE: 3670.0243, Validation RMSE: 6531.8745, Test RMSE: 6537.0912
Training PCC: -0.0013, Validation PCC: 0.0014, Test PCC: 0.0008
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 24249.7020, Validation Total Loss: 4452.0589, Test Total Loss: 4447.7153
Training Regression Loss: 7936.7238, Validation Regression Loss: 4452.0589, Test Regression Loss: 4447.7153
Training Classification Loss: 54376.5918
Training RMSE: 6519.2344, Validation RMSE: 3147.7733, Test RMSE: 3147.0581
Training PCC: 0.0050, Validation PCC: 0.0015, Test PCC: 0.0010
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 22802.9242, Validation Total Loss: 4650.2624, Test Total Loss: 4655.9163
Training Regression Loss: 5123.8856, Validation Regression Loss: 4650.2624, Test Regression Loss: 4655.9163
Training Classification Loss: 58930.1265
Training RMSE: 3519.4276, Validation RMSE: 3412.2494, Test RMSE: 3417.4344
Training PCC: -0.0003, Validation PCC: 0.0012, Test PCC: 0.0009
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 29072.8879, Validation Total Loss: 6212.0514, Test Total Loss: 6208.6394
Training Regression Loss: 6820.5967, Validation Regression Loss: 6212.0514, Test Regression Loss: 6208.6394
Training Classification Loss: 74174.3007
Training RMSE: 5099.9380, Validation RMSE: 5285.8156, Test RMSE: 5280.5060
Training PCC: -0.0012, Validation PCC: 0.0014, Test PCC: 0.0024
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 30360.3580, Validation Total Loss: 5058.7548, Test Total Loss: 5058.7319
Training Regression Loss: 6855.9232, Validation Regression Loss: 5058.7548, Test Regression Loss: 5058.7319
Training Classification Loss: 78348.1128
Training RMSE: 5090.4299, Validation RMSE: 2956.1127, Test RMSE: 2958.6761
Training PCC: 0.0025, Validation PCC: -0.0001, Test PCC: 0.0029
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 37071.9305, Validation Total Loss: 5211.6986, Test Total Loss: 5217.2511
Training Regression Loss: 10209.4331, Validation Regression Loss: 5211.6986, Test Regression Loss: 5217.2511
Training Classification Loss: 89541.6543
Training RMSE: 7807.7319, Validation RMSE: 3410.0915, Test RMSE: 3413.8024
Training PCC: 0.0017, Validation PCC: 0.0028, Test PCC: 0.0022
Checkpoint saved for epoch 15
Stopping early after 15 epochs
Total training time: 1875.25 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.3_lambda_2_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.4_lambda_2_wl100_ol75_normalizedbysubject with alpha: 0.4
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.3770, Validation Total Loss: 12.9512, Test Total Loss: 15.8407
Training Regression Loss: 19.1588, Validation Regression Loss: 12.9512, Test Regression Loss: 15.8407
Training Classification Loss: 3.0456
Training RMSE: 18.6289, Validation RMSE: 12.3352, Test RMSE: 14.9085
Training PCC: 0.8027, Validation PCC: 0.9392, Test PCC: 0.6746
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 13.4697, Validation Total Loss: 10.5374, Test Total Loss: 13.0921
Training Regression Loss: 11.8425, Validation Regression Loss: 10.5374, Test Regression Loss: 13.0921
Training Classification Loss: 4.0681
Training RMSE: 11.3812, Validation RMSE: 10.1310, Test RMSE: 12.3779
Training PCC: 0.9439, Validation PCC: 0.9551, Test PCC: 0.7513
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 81.0658, Validation Total Loss: 224.8709, Test Total Loss: 229.2741
Training Regression Loss: 23.9152, Validation Regression Loss: 224.8709, Test Regression Loss: 229.2741
Training Classification Loss: 142.8764
Training RMSE: 22.5071, Validation RMSE: 224.6983, Test RMSE: 228.7306
Training PCC: 0.9025, Validation PCC: 0.3490, Test PCC: 0.1773
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 5920.0476, Validation Total Loss: 1281.9524, Test Total Loss: 1282.6444
Training Regression Loss: 1367.6892, Validation Regression Loss: 1281.9524, Test Regression Loss: 1282.6444
Training Classification Loss: 11380.8958
Training RMSE: 1078.6516, Validation RMSE: 1266.1841, Test RMSE: 1265.9103
Training PCC: 0.0097, Validation PCC: 0.0181, Test PCC: 0.0089
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 11631.8754, Validation Total Loss: 961.6911, Test Total Loss: 957.6136
Training Regression Loss: 2316.3556, Validation Regression Loss: 961.6911, Test Regression Loss: 957.6136
Training Classification Loss: 23288.7993
Training RMSE: 1855.8433, Validation RMSE: 935.5569, Test RMSE: 930.2703
Training PCC: -0.0016, Validation PCC: -0.0037, Test PCC: -0.0031
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 14658.3900, Validation Total Loss: 3108.3558, Test Total Loss: 3105.8172
Training Regression Loss: 3513.4133, Validation Regression Loss: 3108.3558, Test Regression Loss: 3105.8172
Training Classification Loss: 27862.4413
Training RMSE: 2885.0000, Validation RMSE: 2511.9574, Test RMSE: 2512.6573
Training PCC: -0.0039, Validation PCC: 0.0024, Test PCC: 0.0037
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 21383.2435, Validation Total Loss: 5512.4705, Test Total Loss: 5517.2289
Training Regression Loss: 4410.3933, Validation Regression Loss: 5512.4705, Test Regression Loss: 5517.2289
Training Classification Loss: 42432.1244
Training RMSE: 3347.9453, Validation RMSE: 3485.0410, Test RMSE: 3489.9868
Training PCC: 0.0039, Validation PCC: 0.0001, Test PCC: 0.0089
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 23558.0444, Validation Total Loss: 3603.2774, Test Total Loss: 3608.3034
Training Regression Loss: 4231.6607, Validation Regression Loss: 3603.2774, Test Regression Loss: 3608.3034
Training Classification Loss: 48315.9586
Training RMSE: 2793.8163, Validation RMSE: 2195.9318, Test RMSE: 2195.1428
Training PCC: -0.0015, Validation PCC: -0.0043, Test PCC: 0.0009
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 31719.1797, Validation Total Loss: 8788.0723, Test Total Loss: 8786.1133
Training Regression Loss: 7725.1300, Validation Regression Loss: 8788.0723, Test Regression Loss: 8786.1133
Training Classification Loss: 59985.1233
Training RMSE: 5967.8035, Validation RMSE: 7035.2543, Test RMSE: 7034.5084
Training PCC: -0.0008, Validation PCC: 0.0010, Test PCC: 0.0056
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 36100.3094, Validation Total Loss: 8916.3000, Test Total Loss: 8921.4435
Training Regression Loss: 7620.4553, Validation Regression Loss: 8916.3000, Test Regression Loss: 8921.4435
Training Classification Loss: 71199.6342
Training RMSE: 5651.0976, Validation RMSE: 5227.6957, Test RMSE: 5229.6609
Training PCC: 0.0009, Validation PCC: 0.0058, Test PCC: 0.0020
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 42220.3362, Validation Total Loss: 6760.5278, Test Total Loss: 6764.5318
Training Regression Loss: 10922.0115, Validation Regression Loss: 6760.5278, Test Regression Loss: 6764.5318
Training Classification Loss: 78245.8106
Training RMSE: 7042.2650, Validation RMSE: 5645.1033, Test RMSE: 5645.8674
Training PCC: 0.0029, Validation PCC: -0.0010, Test PCC: 0.0093
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 44591.0203, Validation Total Loss: 5877.9103, Test Total Loss: 5882.2202
Training Regression Loss: 8415.4105, Validation Regression Loss: 5877.9103, Test Regression Loss: 5882.2202
Training Classification Loss: 90439.0228
Training RMSE: 5374.2585, Validation RMSE: 3466.3610, Test RMSE: 3468.0596
Training PCC: 0.0044, Validation PCC: 0.0002, Test PCC: 0.0127
Checkpoint saved for epoch 12
Stopping early after 12 epochs
Total training time: 1508.51 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.4_lambda_2_wl100_ol75_normalizedbysubject.
Running model: TeacherModel_DomainInvariant_alpha_0.5_lambda_2_wl100_ol75_normalizedbysubject with alpha: 0.5
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 22.2001, Validation Total Loss: 14.4240, Test Total Loss: 15.6232
Training Regression Loss: 20.0100, Validation Regression Loss: 14.4240, Test Regression Loss: 15.6232
Training Classification Loss: 4.3803
Training RMSE: 19.4644, Validation RMSE: 14.0240, Test RMSE: 14.7493
Training PCC: 0.7834, Validation PCC: 0.9064, Test PCC: 0.6337
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 428.4333, Validation Total Loss: 563.4738, Test Total Loss: 567.4930
Training Regression Loss: 97.8880, Validation Regression Loss: 563.4738, Test Regression Loss: 567.4930
Training Classification Loss: 661.0905
Training RMSE: 82.3196, Validation RMSE: 536.3237, Test RMSE: 540.9347
Training PCC: 0.6822, Validation PCC: -0.0386, Test PCC: -0.1154
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 4900.3153, Validation Total Loss: 1704.7744, Test Total Loss: 1705.4134
Training Regression Loss: 940.4593, Validation Regression Loss: 1704.7744, Test Regression Loss: 1705.4134
Training Classification Loss: 7919.7119
Training RMSE: 783.9445, Validation RMSE: 1411.8526, Test RMSE: 1410.5052
Training PCC: -0.0134, Validation PCC: -0.0247, Test PCC: -0.0384
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9144.3101, Validation Total Loss: 1104.6791, Test Total Loss: 1108.9703
Training Regression Loss: 2350.7557, Validation Regression Loss: 1104.6791, Test Regression Loss: 1108.9703
Training Classification Loss: 13587.1087
Training RMSE: 1702.0603, Validation RMSE: 919.7565, Test RMSE: 920.3095
Training PCC: -0.0388, Validation PCC: -0.0019, Test PCC: -0.0217
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 11250.1334, Validation Total Loss: 1925.0044, Test Total Loss: 1921.3191
Training Regression Loss: 2086.1353, Validation Regression Loss: 1925.0044, Test Regression Loss: 1921.3191
Training Classification Loss: 18327.9962
Training RMSE: 1534.8541, Validation RMSE: 1433.2228, Test RMSE: 1433.7359
Training PCC: 0.0224, Validation PCC: 0.0874, Test PCC: 0.1175
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 16231.9847, Validation Total Loss: 2098.4635, Test Total Loss: 2097.0995
Training Regression Loss: 2286.5335, Validation Regression Loss: 2098.4635, Test Regression Loss: 2097.0995
Training Classification Loss: 27890.9023
Training RMSE: 1588.1542, Validation RMSE: 1373.6851, Test RMSE: 1366.9889
Training PCC: 0.0313, Validation PCC: 0.0878, Test PCC: 0.0849
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 22323.9934, Validation Total Loss: 2217.7659, Test Total Loss: 2223.0020
Training Regression Loss: 3645.7978, Validation Regression Loss: 2217.7659, Test Regression Loss: 2223.0020
Training Classification Loss: 37356.3911
Training RMSE: 2650.7363, Validation RMSE: 1381.4933, Test RMSE: 1385.3574
Training PCC: 0.0061, Validation PCC: 0.0470, Test PCC: 0.0239
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 27073.2578, Validation Total Loss: 8070.6808, Test Total Loss: 8077.1066
Training Regression Loss: 6340.6493, Validation Regression Loss: 8070.6808, Test Regression Loss: 8077.1066
Training Classification Loss: 41465.2171
Training RMSE: 4599.2340, Validation RMSE: 4714.3175, Test RMSE: 4718.1730
Training PCC: 0.0008, Validation PCC: -0.0482, Test PCC: -0.0082
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 31865.1398, Validation Total Loss: 4656.8213, Test Total Loss: 4660.1762
Training Regression Loss: 7453.1664, Validation Regression Loss: 4656.8213, Test Regression Loss: 4660.1762
Training Classification Loss: 48823.9469
Training RMSE: 5563.9093, Validation RMSE: 3019.0328, Test RMSE: 3014.8300
Training PCC: 0.0058, Validation PCC: 0.0846, Test PCC: 0.1027
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 34774.5082, Validation Total Loss: 8861.3293, Test Total Loss: 8854.1266
Training Regression Loss: 6766.9389, Validation Regression Loss: 8861.3293, Test Regression Loss: 8854.1266
Training Classification Loss: 56015.1387
Training RMSE: 5426.8774, Validation RMSE: 5369.4982, Test RMSE: 5362.4089
Training PCC: 0.0227, Validation PCC: 0.1181, Test PCC: 0.1250
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 36842.7708, Validation Total Loss: 5217.6905, Test Total Loss: 5218.6158
Training Regression Loss: 7547.4747, Validation Regression Loss: 5217.6905, Test Regression Loss: 5218.6158
Training Classification Loss: 58590.5921
Training RMSE: 4666.0764, Validation RMSE: 3588.9998, Test RMSE: 3593.8057
Training PCC: 0.0177, Validation PCC: 0.0578, Test PCC: 0.0266
Checkpoint saved for epoch 11
Stopping early after 11 epochs
Total training time: 1410.68 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.5_lambda_2_wl100_ol75_normalizedbysubject.


In [11]:
# @title Plotting
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
def load_most_recent_checkpoint(folder, map_location=None):
    """
    Load the checkpoint with the highest epoch number from ``folder``.

    Only files named like ``<anything>_epoch_<N>.pth`` are considered. Other
    ``.pth`` files in the folder (e.g. a ``best_model.pth``) are ignored instead
    of crashing the epoch-number parse.

    Args:
        folder: Directory holding the checkpoint files.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to open a
            checkpoint that was saved on a GPU from a CPU-only runtime. The
            default ``None`` keeps ``torch.load``'s own behaviour.

    Returns:
        The object stored with ``torch.save`` (the caller in this notebook reads
        its ``'history'`` entry), or ``None`` when the folder does not exist or
        contains no epoch checkpoints.
    """
    if not os.path.isdir(folder):
        print(f"No checkpoints found in {folder}")
        return None

    epoch_pattern = re.compile(r"_epoch_(\d+)\.pth$")
    epoch_files = []
    for name in os.listdir(folder):
        match = epoch_pattern.search(name)
        if match:
            epoch_files.append((int(match.group(1)), name))

    if not epoch_files:
        print(f"No checkpoints found in {folder}")
        return None

    # Compare epochs as integers so that epoch 10 beats epoch 9 (a string sort would not).
    _, latest_checkpoint = max(epoch_files)
    checkpoint_path = os.path.join(folder, latest_checkpoint)
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    print(f"Loaded checkpoint: {latest_checkpoint} from {folder}")

    # Return the loaded checkpoint object unchanged
    return checkpoint

def plot_metrics_subplots(aggregated_metrics, folder):
    """
    Plot one model's train/val/test metric history as a 3x2 grid of line charts.

    Panels are laid out as follows:
      - losses (top-left)
      - average PCC (top-right)
      - per-channel PCC (middle-left)
      - average RMSE (middle-right)
      - per-channel RMSE (bottom-left)

    The sixth grid cell has no metric and is hidden rather than left as an
    empty axes box.

    Args:
        aggregated_metrics: History dict loaded from a checkpoint. Reads
            ``{train,val,test}_losses``, ``{train,val,test}_pccs`` and
            ``{train,val,test}_rmses`` (one value per epoch), plus
            ``{train,val,test}_pccs_channelwise`` and
            ``{train,val,test}_rmses_channelwise`` (one indexable per-channel
            sequence per epoch).
        folder: Text shown in the figure title (the caller passes the model name).
    """
    n_epochs = len(aggregated_metrics['train_losses'])
    if n_epochs == 0:
        # The per-channel panels index epoch 0, which would raise on an empty history.
        print(f"No recorded epochs for {folder}; nothing to plot")
        return

    epochs = list(range(1, n_epochs + 1))
    line_styles = ['-', '--', '-.', ':']  # one style per channel, so at most 4 channels are drawn

    # (legend label, history-key prefix, colour) for each data split
    splits = [
        ('Train', 'train', '#17becf'),  # teal blue
        ('Val', 'val', '#bcbd22'),      # mustard yellow
        ('Test', 'test', '#e377c2'),    # pastel magenta
    ]

    fig, axes = plt.subplots(3, 2, figsize=(12, 14))
    fig.suptitle(f'Metrics over Epochs for {folder}', fontsize=16)

    def _finish(ax, title, ylabel):
        """Apply the title, axis labels and legend shared by every panel."""
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()

    def _plot_per_epoch(ax, key, label_template):
        """Plot one already-averaged per-epoch series for each split."""
        for split, prefix, color in splits:
            ax.plot(epochs, aggregated_metrics[f'{prefix}_{key}'],
                    label=label_template.format(split=split), color=color, linestyle='-')

    def _plot_per_channel(ax, key, metric_name):
        """Plot one series per (channel, split); line style distinguishes channels."""
        n_channels = len(aggregated_metrics[f'train_{key}_channelwise'][0])
        for i, style in enumerate(line_styles[:n_channels]):
            for split, prefix, color in splits:
                series = [values[i] for values in aggregated_metrics[f'{prefix}_{key}_channelwise']]
                ax.plot(epochs, series, label=f'{split} {metric_name} (Ch {i+1})',
                        color=color, linestyle=style)

    _plot_per_epoch(axes[0, 0], 'losses', '{split} Loss')
    _finish(axes[0, 0], 'Losses over Epochs', 'Loss')

    _plot_per_epoch(axes[0, 1], 'pccs', 'Avg {split} PCC')
    _finish(axes[0, 1], 'Average PCC over Epochs', 'PCC')

    _plot_per_channel(axes[1, 0], 'pccs', 'PCC')
    _finish(axes[1, 0], 'PCCs over Epochs (Per Channel)', 'PCC')

    _plot_per_epoch(axes[1, 1], 'rmses', 'Avg {split} RMSE')
    _finish(axes[1, 1], 'Average RMSE over Epochs', 'RMSE')

    _plot_per_channel(axes[2, 0], 'rmses', 'RMSE')
    _finish(axes[2, 0], 'RMSEs over Epochs (Per Channel)', 'RMSE')

    # Only five panels are used; hide the leftover grid cell.
    axes[2, 1].axis('off')

    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    plt.show()
    # Release the figure; this function is called once per model in a loop.
    plt.close(fig)

def _best_test_metrics(test_rmses, test_pccs):
    """
    Find one model's best test RMSE and best test PCC, each paired with the other
    metric from the same epoch.

    Non-finite epochs (NaN/inf, e.g. an undefined PCC from constant predictions)
    are ignored. Returns ``None`` when either series has no finite value, so the
    caller can skip the model instead of plotting or sorting NaNs.
    """
    # np.array (not asarray) copies, so the masking below never mutates the caller's history.
    rmses = np.array(test_rmses, dtype=float)
    pccs = np.array(test_pccs, dtype=float)
    rmses[~np.isfinite(rmses)] = np.nan
    pccs[~np.isfinite(pccs)] = np.nan
    if np.all(np.isnan(rmses)) or np.all(np.isnan(pccs)):
        return None

    best_rmse_epoch = int(np.nanargmin(rmses))
    best_pcc_epoch = int(np.nanargmax(pccs))
    return {
        'best_rmse': rmses[best_rmse_epoch],
        'corresponding_pcc_for_best_rmse': pccs[best_rmse_epoch],
        'best_pcc': pccs[best_pcc_epoch],
        'corresponding_rmse_for_best_pcc': rmses[best_pcc_epoch],
    }


def _plot_sorted_bars(models, values, companions, colors, title, ylabel, ylim, companion_name):
    """
    Draw one bar chart with each bar's value printed above it and the companion
    metric (taken from the same epoch) printed just below its top.

    ``ylim`` is widened, never narrowed, so that no bar is clipped. Label offsets
    are 2% / 8% of the final y-span, which reproduces the original offsets
    (0.5 / 2 on a 0–25 axis, 0.02 / 0.08 on a 0–1 axis).
    """
    lo, hi = ylim
    vmin, vmax = min(values), max(values)
    margin = 0.1 * max(hi - lo, vmax - vmin)
    if vmin < lo:
        lo = vmin - margin
    if vmax > hi:
        hi = vmax + margin
    span = hi - lo

    plt.figure(figsize=(10, 6))
    bars = plt.bar(models, values, color=colors)
    plt.title(title)
    plt.xlabel('Model')
    plt.ylabel(ylabel)
    plt.ylim(lo, hi)
    plt.xticks(rotation=45, ha='right')

    for bar, companion in zip(bars, companions):
        x_center = bar.get_x() + bar.get_width() / 2
        height = bar.get_height()
        plt.text(x_center, height + 0.02 * span, f'{height:.2f}', ha='center', va='bottom')
        plt.text(x_center, height - 0.08 * span, f'{companion_name}: {companion:.2f}',
                 ha='center', va='top', color='black')

    # Run the layout pass after the annotations exist so they are accounted for.
    plt.tight_layout()
    plt.show()


def plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict, rmse_ylim=(0, 25), pcc_ylim=(0, 1)):
    """
    Compare models by their best test metrics and plot their test-RMSE curves.

    Draws three figures:
      1. Bar chart of each model's lowest test RMSE (ascending), annotated with the
         test PCC from that same epoch.
      2. Bar chart of each model's highest test PCC (descending), annotated with the
         test RMSE from that same epoch.
      3. Line plot of each model's per-epoch average test RMSE (epochs numbered from 1).

    Each model keeps the same colour in all three figures.

    Args:
        aggregated_metrics_dict: ``{model_name: {'test_rmses': [...], 'test_pccs': [...]}}``
            with one pre-averaged value per epoch.
        rmse_ylim: Default y-limits of the RMSE bar chart (widened if a bar would be clipped).
        pcc_ylim: Default y-limits of the PCC bar chart (widened if a bar would be clipped).

    Models whose test RMSE or PCC history has no finite value are skipped with a message.
    """
    best_metrics = {}
    for model_name, metrics in aggregated_metrics_dict.items():
        summary = _best_test_metrics(metrics['test_rmses'], metrics['test_pccs'])
        if summary is None:
            print(f"Skipping {model_name}: no finite test RMSE/PCC recorded")
            continue
        best_metrics[model_name] = summary

    if not best_metrics:
        print("No models with finite test metrics to compare")
        return

    models = list(best_metrics)
    ranked_by_rmse = sorted(models, key=lambda m: best_metrics[m]['best_rmse'])
    ranked_by_pcc = sorted(models, key=lambda m: best_metrics[m]['best_pcc'], reverse=True)

    # Assign colours once (by RMSE rank) and reuse them in every figure.
    palette = sns.color_palette("husl", len(models))
    model_colors = {model: palette[i] for i, model in enumerate(ranked_by_rmse)}

    _plot_sorted_bars(
        ranked_by_rmse,
        [best_metrics[m]['best_rmse'] for m in ranked_by_rmse],
        [best_metrics[m]['corresponding_pcc_for_best_rmse'] for m in ranked_by_rmse],
        [model_colors[m] for m in ranked_by_rmse],
        title='Best RMSE per Model (Sorted)',
        ylabel='Best RMSE',
        ylim=rmse_ylim,
        companion_name='PCC',
    )

    _plot_sorted_bars(
        ranked_by_pcc,
        [best_metrics[m]['best_pcc'] for m in ranked_by_pcc],
        [best_metrics[m]['corresponding_rmse_for_best_pcc'] for m in ranked_by_pcc],
        [model_colors[m] for m in ranked_by_pcc],
        title='Best PCC per Model (Sorted)',
        ylabel='Best PCC',
        ylim=pcc_ylim,
        companion_name='RMSE',
    )

    # Per-epoch average test RMSE for every model, on the same 1-based epoch axis
    # used by plot_metrics_subplots.
    plt.figure(figsize=(10, 6))
    for model in models:
        test_rmse_curve = aggregated_metrics_dict[model]['test_rmses']
        plt.plot(range(1, len(test_rmse_curve) + 1), test_rmse_curve,
                 label=f'{model} Average Test RMSE', color=model_colors[model])

    plt.title('Average Test RMSE per Epoch for All Models')
    plt.xlabel('Epoch')
    plt.ylabel('Average RMSE')
    plt.legend()
    plt.tight_layout()
    plt.show()


# Root folder that holds one checkpoint sub-folder per model (named after the model).
MODELS_ROOT = "/content/MyDrive/MyDrive/models"

# Test-metric history of every model, collected for the cross-model comparison
aggregated_metrics_dict = {}

for model_name in model_configs:
    folder = os.path.join(MODELS_ROOT, model_name)
    # A model that was never trained has no folder; skip it instead of aborting the loop.
    if not os.path.isdir(folder):
        print(f"No checkpoint folder for {model_name} at {folder}; skipping")
        continue

    checkpoint = load_most_recent_checkpoint(folder)
    if not checkpoint:
        continue

    history = checkpoint.get('history')
    if history is None:
        print(f"Checkpoint for {model_name} has no 'history'; skipping")
        continue

    aggregated_metrics_dict[model_name] = {
        'test_rmses': history['test_rmses'],
        'test_pccs': history['test_pccs'],
    }

    # Per-model history figure
    plot_metrics_subplots(history, model_name)

# After collecting all model metrics, plot the cross-model comparison
if aggregated_metrics_dict:
    plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict)


Output hidden; open in https://colab.research.google.com to view.