In [21]:

# Mount Google Drive at /content/MyDrive; the HDF5 dataset path and the checkpoint
# directory used further down both live under this mount point.
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
# "paper" is passed positionally; NOTE(review): this is presumably seaborn's
# plotting `context` argument — confirm against the installed seaborn version.
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [22]:
# @title Initialize Config

import torch
import numpy
class Config:
    """Plain attribute bag for the experiment settings.

    Every recognised option is read from ``kwargs``; options that are not
    supplied fall back to the defaults listed in ``__init__``. Keyword
    arguments that are not listed there are ignored.
    """

    def __init__(self, **kwargs):
        # (attribute name, factory producing its default). Factories are used so
        # that every instance receives its own fresh list, and the tuple order
        # fixes the order in which attributes are created.
        option_defaults = (
            ('channels_imu_acc', list),
            ('channels_imu_gyr', list),
            ('channels_joints', list),
            ('channels_emg', list),
            ('seed', lambda: 42),
            ('data_folder_name', lambda: 'default_data_folder_name'),
            ('dataset_root', lambda: 'default_dataset_root'),
            ('imu_transforms', list),
            ('joint_transforms', list),
            ('emg_transforms', list),
            ('input_format', lambda: 'csv'),
        )
        for option_name, make_default in option_defaults:
            if option_name in kwargs:
                setattr(self, option_name, kwargs[option_name])
            else:
                setattr(self, option_name, make_default())


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment configuration. Options not listed here keep the defaults defined in Config
# (seed=42, empty transform lists).
config = Config(
    # Despite the name, this is the path of the HDF5 file itself: DataSharder opens it with h5py.
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data.h5',
    # Root directory under which the windowed CSV shards are written and read back.
    dataset_root='/content/datasets',
    # Format of the shards read by the dataset class; "csv" is the only format it accepts.
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

# Seed the torch and numpy global RNGs from the config so runs are repeatable.
# NOTE(review): Python's built-in `random` module is not seeded here.
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [23]:
class DataSharder:
    """Slices per-session recordings from an HDF5 file into fixed-length CSV windows.

    The HDF5 file is walked as subject -> session -> speed. For every leaf
    dataset the configured IMU, EMG and joint channels are extracted, cut into
    windows and written as one CSV per window under
    ``<dataset_root>/<dataset_name>/<split>/``. Each written window is also
    logged in ``<dataset_root>/<dataset_name>/<split>_info.csv``.
    """

    def __init__(self, config, split):
        """Store the config and split name; nothing is read from disk here."""
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        """Shard the given subjects into windows of ``window_length`` samples.

        Args:
            subjects: iterable of top-level keys of the HDF5 file.
            window_length: number of samples per window.
            window_overlap: number of samples shared by consecutive windows.
            dataset_name: folder name of the dataset under ``config.dataset_root``.

        Raises:
            ValueError: if ``window_overlap`` is not smaller than ``window_length``.
        """
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Fail before anything is created on disk: an empty dataset folder left
        # behind by a failed run would make the next run skip sharding.
        self._step_size()

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _step_size(self):
        """Return the stride between window starts, rejecting non-positive strides."""
        step_size = self.window_length - self.window_overlap
        if step_size <= 0:
            raise ValueError(
                f"window_overlap ({self.window_overlap}) must be smaller than "
                f"window_length ({self.window_length})"
            )
        return step_size

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        """Walk the HDF5 hierarchy and write windows for every subject/session/speed."""
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            # An existing folder is treated as an already-sharded dataset.
            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]

                for session_id in list(subject_data.keys()):
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays (channels x samples)
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id, sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        """Write one CSV per window and append each file to the split's info CSV.

        The three data arrays are indexed as (channel, sample). Windows are cut
        over the number of samples common to all three arrays.

        Raises:
            ValueError: if ``window_overlap`` is not smaller than ``window_length``.
        """
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = self._step_size()

        # The info CSV sits next to the split folder, one level above the shards.
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Only the samples present in all three sources can be windowed.
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Recordings longer than 4000 samples are windowed from sample 2000 on.
            # NOTE(review): presumably this drops a warm-up segment — confirm.
            start = 2000 if total_data_length > 4000 else 0

            # Every start index yielded here leaves room for a full window in all
            # three arrays, so no per-window length check is needed.
            for i in range(start, total_data_length - window_size + 1, step_size):
                stop = i + window_size

                # DataFrames are (sample, channel), hence the transposes.
                combined_df = pd.concat(
                    [
                        pd.DataFrame(imu_data[:, i:stop].T, columns=imu_columns),
                        pd.DataFrame(emg_data[:, i:stop].T, columns=emg_columns),
                        pd.DataFrame(joint_data[:, i:stop].T, columns=joint_columns),
                    ],
                    axis=1,
                )

                # Save the combined windowed data as a CSV file
                file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                file_path = os.path.join(dataset_folder, file_name)
                combined_df.to_csv(file_path, index=False)

                # Log the file name and path in the CSV (in the correct columns)
                writer.writerow([file_name, file_path])

    @staticmethod
    def _clean_channel(channel_data):
        """Coerce one channel to numeric and linearly interpolate its NaNs (both directions)."""
        numeric = pd.to_numeric(channel_data, errors='coerce')
        filled = pd.DataFrame(numeric).interpolate(method='linear', axis=0, limit_direction='both')
        return filled.to_numpy().flatten()

    def _extract_channel_data(self, session_data, channels):
        """
        Extracts data for the given channels from the dataset (whether it's a compound dataset or simple dataset),
        and interpolates missing values (NaNs) in each channel data.

        Returns:
            (data, columns): ``data`` is an array with one row per channel that was
            found, ``columns`` lists the names of those channels in the same order.
            Requested channels that are missing are reported and left out.

        Raises:
            TypeError: if ``session_data`` is not an ``h5py.Dataset``.
        """
        if not isinstance(session_data, h5py.Dataset):
            raise TypeError(f"Expected an h5py.Dataset, got {type(session_data).__name__}")

        extracted_data = []
        # Names of the channels actually extracted; filled in by both branches so
        # the return value is always defined.
        found_columns = []

        if session_data.dtype.names:
            # Compound dataset, use the named fields
            available = session_data.dtype.names
            for channel in channels:
                if channel in available:
                    found_columns.append(channel)
                    extracted_data.append(self._clean_channel(session_data[channel][:]))  # Access by field name
                else:
                    print(f"Channel {channel} not found in compound dataset.")
        else:
            # Simple dataset: column names come from the 'column_names' attribute;
            # cast to a list to allow 'index' lookup.
            available = list(session_data.attrs.get('column_names', []))

            assert len(available) > 0, "column_names not found in dataset attributes"
            for channel in channels:
                if channel in available:
                    col_idx = available.index(channel)
                    found_columns.append(channel)
                    extracted_data.append(self._clean_channel(session_data[:, col_idx]))  # Access by column index
                else:
                    print(f"Channel {channel} not found in session data.")

        return np.array(extracted_data), found_columns

In [24]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    """Dataset over windowed CSV shards holding IMU, joint-angle and EMG channels.

    On construction the shard folder for the given subjects/window settings is
    located (and created through ``DataSharder`` when it is missing), then the
    split's info CSV is loaded as the index of available windows.
    Each item is a 4-tuple of float32 tensors:
    ``(imu_acc, imu_gyr, joints, emg)``.
    """

    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        is_train_split = split == 'train'

        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        # Any split other than 'train' is always sharded without overlap.
        self.window_overlap = window_overlap if is_train_split else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Path-safe tag built from the subject list.
        subjects_str = "_".join(str(subject) for subject in subjects).replace('subject', '').replace('__', '_')

        # The folder name encodes window settings, split and subjects.
        self.dataset_name = (
            f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
            if is_train_split
            else f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"
        )

        # Root directory of this dataset's shards
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Create the shards first when they do not exist yet.
        requested_name = dataset_train_name if is_train_split else dataset_test_name
        self.ensure_resharded(subjects, requested_name)

        # The info CSV lists one shard file per row.
        self.data = pd.read_csv(os.path.join(self.root_dir, f"{split}_info.csv"))

    def ensure_resharded(self, subjects, dataset_name):
        """Shard the raw data unless ``self.root_dir`` already exists.

        Note that ``dataset_name`` is accepted but the sharder is always given
        ``self.dataset_name``.
        """
        if os.path.exists(self.root_dir):
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")
            return

        print(f"Sharded data not found at {self.root_dir}. Resharding...")
        sharder = DataSharder(self.config,self.split)
        sharder.load_data(
            subjects,
            window_length=self.window_length,
            window_overlap=self.window_overlap,
            dataset_name=self.dataset_name,
        )

    def __len__(self):
        """Number of windows listed in the info CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load window ``idx`` from disk and return its four channel-group tensors."""
        shard_name = self.data.iloc[idx, 0]
        file_path = os.path.join(self.root_dir,self.split, shard_name)

        if self.input_format != "csv":
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        window_frame = pd.read_csv(file_path)
        return self._extract_and_transform(window_frame)

    def _extract_and_transform(self, combined_data):
        """Split a window into channel groups, then run each group's transforms."""
        channel_groups = (
            self.channels_imu_acc,
            self.channels_imu_gyr,
            self.channels_joints,
            self.channels_emg,
        )
        # All groups are extracted before any transform is applied.
        raw_groups = [self._extract_channels(combined_data, group) for group in channel_groups]

        pipelines = (
            self.config.imu_transforms,
            self.config.imu_transforms,
            self.config.joint_transforms,
            self.config.emg_transforms,
        )
        return tuple(
            self.apply_transforms(values, pipeline)
            for values, pipeline in zip(raw_groups, pipelines)
        )

    def _extract_channels(self, combined_data, channels):
        """Select ``channels`` from a window (column lookup for CSV input)."""
        if self.input_format == "csv":
            return combined_data[channels].values
        return combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        """Apply ``transforms`` in order and return the result as a float32 tensor."""
        result = data
        for step in transforms:
            result = step(result)
        return torch.tensor(result, dtype=torch.float32)

class ImuJointPairSubjectDataset(ImuJointPairDataset):
    """Variant of ``ImuJointPairDataset`` that also returns a subject class index.

    The subject is recovered from the shard file name, which is expected to
    contain the tokens ``subject`` and the subject id separated by ``_``.
    """

    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Subject strings (e.g. 'subject_1') -> class index, in sorted order.
        ordered_subjects = sorted(subjects)
        self.subject_mapping = dict(zip(ordered_subjects, range(len(ordered_subjects))))

    def __getitem__(self, idx):
        """Return ``(imu_acc, imu_gyr, joints, emg, subject_class)`` for window ``idx``.

        Raises:
            ValueError: if the file name carries no subject token, no id after
                it, or a subject that is not in ``subject_mapping``.
        """
        # Tensors come from the parent implementation.
        imu_data_acc, imu_data_gyr, joint_data, emg_data = super().__getitem__(idx)

        filename = self.data.iloc[idx, 0]

        # File name without directory and extension, split into '_' tokens.
        stem, _extension = os.path.splitext(os.path.basename(filename))
        tokens = stem.split('_')

        # The token right after 'subject' is the subject id.
        try:
            marker_position = tokens.index('subject')
            subject_str = f"subject_{tokens[marker_position + 1]}"
        except ValueError:
            raise ValueError(f"'subject' not found in filename: {filename}")
        except IndexError:
            raise ValueError(f"Subject ID not found after 'subject' in filename: {filename}")

        if subject_str not in self.subject_mapping:
            raise ValueError(f"Subject ID {subject_str} not found in training set.")

        # A plain class index is returned, not a one-hot vector.
        return imu_data_acc, imu_data_gyr, joint_data, emg_data, self.subject_mapping[subject_str]


def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    """Build the train, validation and test data loaders.

    The training subjects are sharded with ``window_overlap``; the test
    subjects are sharded without overlap. 90% of the training windows (rounded
    down) are used for training and the remainder for validation, chosen by a
    random split. Only the training loader shuffles.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    shared_settings = dict(config=config, window_length=window_length)

    all_train_windows = ImuJointPairSubjectDataset(
        subjects=train_subjects,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name,
        **shared_settings,
    )
    test_windows = ImuJointPairSubjectDataset(
        subjects=test_subjects,
        window_overlap=0,  # Typically no overlap in test set
        split='test',
        dataset_test_name=dataset_test_name,
        **shared_settings,
    )

    # 90/10 train/validation split of the training windows
    total_windows = len(all_train_windows)
    fit_count = int(0.9 * total_windows)
    fit_windows, val_windows = random_split(all_train_windows, [fit_count, total_windows - fit_count])

    train_loader = DataLoader(fit_windows, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_windows, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_windows, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader




In [25]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    """Two stacked bidirectional LSTMs (128 then 64 hidden units) with dropout after each.

    ``forward`` returns the dropped-out output sequence of the second LSTM and
    the final hidden states ``(h_1, h_2)`` of both LSTMs. ``flatten`` and ``fc``
    are created but not used in ``forward``.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        recurrent_options = dict(bidirectional=True, batch_first=True, dropout=0)
        self.lstm_1 = nn.LSTM(input_dim, 128, **recurrent_options)
        # 256 inputs: both directions of lstm_1 concatenated.
        self.lstm_2 = nn.LSTM(256, 64, **recurrent_options)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        first_sequence, first_state = self.lstm_1(x)
        h_1 = first_state[0]  # hidden state only; the cell state is discarded
        second_sequence, second_state = self.lstm_2(self.dropout_1(first_sequence))
        h_2 = second_state[0]
        return self.dropout_2(second_sequence), (h_1, h_2)

class Encoder_2(nn.Module):
    """Two stacked bidirectional GRUs (128 then 64 hidden units) with dropout after each.

    ``forward`` returns the dropped-out output sequence of the second GRU and
    the final hidden states ``(h_1, h_2)`` of both GRUs. ``flatten`` and ``fc``
    are created but not used in ``forward``.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        recurrent_options = dict(bidirectional=True, batch_first=True, dropout=0)
        self.gru_1 = nn.GRU(input_dim, 128, **recurrent_options)
        # 256 inputs: both directions of gru_1 concatenated.
        self.gru_2 = nn.GRU(256, 64, **recurrent_options)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        first_sequence, h_1 = self.gru_1(x)
        second_sequence, h_2 = self.gru_2(self.dropout_1(first_sequence))
        return self.dropout_2(second_sequence), (h_1, h_2)


class GatingModule(nn.Module):
    """Blends two equally sized feature tensors with a learned sigmoid gate.

    The gate is computed from the concatenation of both inputs; the output is
    ``input1 * gate + input2 * (1 - gate)``, element-wise.
    """

    def __init__(self, input_size):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(2*input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # Gate values lie in (0, 1) and have the same feature size as each input.
        mix = self.gate(torch.cat((input1, input2), dim=-1))

        # Convex combination of the two inputs.
        return input1 * mix + input2 * (1 - mix)
# TODO: the value of `w` (the window length argument of `teacher` below) is a stand-in and still needs to be verified.


from torch.autograd import Function

class GradientReversalFunction(Function):
    """Autograd function that is the identity forwards and flips gradients backwards.

    In the backward pass the incoming gradient is negated and multiplied by
    ``lambda_grl``; no gradient is produced for ``lambda_grl`` itself.
    """

    @staticmethod
    def forward(ctx, x, lambda_grl):
        # Keep the scale for the backward pass.
        ctx.lambda_grl = lambda_grl
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        lambda_grl = ctx.lambda_grl
        grad_input = grad_output.neg() * lambda_grl
        # One return value per forward() input; lambda_grl gets no gradient.
        return grad_input, None

def GradientReversalLayer(lambda_grl):
    """Return a callable that applies gradient reversal scaled by ``lambda_grl``.

    The returned callable takes the tensor to pass through and, optionally, a
    second argument overriding the scale for that call. The optional argument
    keeps existing two-argument call sites (``grl(x, lam)``) working, while
    ``grl(x)`` now uses the scale given here instead of failing.
    """
    def reverse(x, lambda_override=None):
        scale = lambda_grl if lambda_override is None else lambda_override
        return GradientReversalFunction.apply(x, scale)

    return reverse

class teacher(nn.Module):
    """Multi-modal sequence model with a regression head and a gradient-reversed subject classifier.

    Accelerometer, gyroscope and EMG windows are each batch-normalised, encoded
    by an LSTM encoder and a GRU encoder, and the two encodings are blended by a
    gating module. The three blended streams are concatenated and fused in three
    ways (self-attention, element-wise gating, weighted sum); the concatenated
    fusion result feeds a per-time-step regression layer with three outputs and,
    through a gradient reversal layer and temporal average pooling, a subject
    classifier.

    Args:
        input_acc, input_gyr, input_emg: number of channels of each modality.
        drop_prob: dropout probability used inside the encoders.
        w: nominal window length. It is stored on the module but no longer used
            for reshaping: ``forward`` takes the time length from its inputs, so
            windows of any length are handled and the result is unchanged when
            the input length equals ``w``.
        num_subjects: number of classes of the subject classifier.
        lambda_grl: gradient reversal scale.
    """

    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100,num_subjects=12,lambda_grl=1.0):
        super(teacher, self).__init__()

        self.lambda_grl = lambda_grl
        self.w=w
        self.encoder_1_acc=Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr=Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg=Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc=Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr=Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg=Encoder_2(input_emg, drop_prob)

        self.BN_acc= nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr= nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg= nn.BatchNorm1d(input_emg, affine=False)

        # 2*3*128+128 = attention output (384) + gated features (384) + weighted sum (128)
        self.fc = nn.Linear(2*3*128+128,3)
        self.fc_classifier = nn.Linear(2*3*128+128,num_subjects)
        self.dropout=nn.Dropout(p=.05)

        self.gate_1=GatingModule(128)
        self.gate_2=GatingModule(128)
        self.gate_3=GatingModule(128)

        self.fc_kd = nn.Linear(3*128, 2*128)

        # Scalar weight per modality stream (shared across the three streams)
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention=nn.MultiheadAttention(3*128,4,batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128*3, 3*128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2*3*128+128, 2*3*128+128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)

    @staticmethod
    def _normalise_over_time(x, norm):
        """Apply ``norm`` per channel over all (batch, time) positions of ``x``.

        ``x`` is flattened to (batch*time, channels) for the normalisation and
        restored to its own (batch, time, channels) shape afterwards.
        Assumes ``x`` is 3-D — TODO confirm for all callers.
        """
        batch, steps, channels = x.size(0), x.size(1), x.size(-1)
        flat = x.reshape(batch * steps, channels)
        return norm(flat).reshape(batch, steps, channels)

    def forward(self, x_acc, x_gyr, x_emg):
        """Run the model on one batch.

        Returns:
            out_task: regression output, three values per time step.
            out_classifier: subject logits, one row per window.
            hidden states: ``(h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)``,
                the first-layer hidden states of the LSTM (``_1``) and GRU
                (``_2``) encoders of each modality.
        """
        acc_in = self._normalise_over_time(x_acc, self.BN_acc)
        gyr_in = self._normalise_over_time(x_gyr, self.BN_gyr)
        emg_in = self._normalise_over_time(x_emg, self.BN_emg)

        # Pass through Encoder 1 for each modality and capture hidden states
        acc_lstm, (h_acc_1, _) = self.encoder_1_acc(acc_in)
        gyr_lstm, (h_gyr_1, _) = self.encoder_1_gyr(gyr_in)
        emg_lstm, (h_emg_1, _) = self.encoder_1_emg(emg_in)

        # Pass through Encoder 2 for each modality and capture hidden states
        acc_gru, (h_acc_2, _) = self.encoder_2_acc(acc_in)
        gyr_gru, (h_gyr_2, _) = self.encoder_2_gyr(gyr_in)
        emg_gru, (h_emg_2, _) = self.encoder_2_emg(emg_in)

        # Blend the LSTM and GRU encodings of each modality
        x_acc = self.gate_1(acc_lstm, acc_gru)
        x_gyr = self.gate_2(gyr_lstm, gyr_gru)
        x_emg = self.gate_3(emg_lstm, emg_gru)

        x = torch.cat((x_acc, x_gyr, x_emg), dim=-1)
        # Computed but not part of any returned value; kept so the forward pass
        # executes the same layers as before.
        x_kd = self.fc_kd(x)

        # Fusion 1: self-attention over time
        out_1, attn_output_weights = self.attention(x, x, x)

        # Fusion 2: element-wise gating of the concatenated features
        out_2 = self.gating_net(x) * x

        # Fusion 3: weighted sum of the three modality streams
        stream_acc = x[:, :, 0:128]
        stream_gyr = x[:, :, 128:2*128]
        stream_emg = x[:, :, 2*128:3*128]
        out_3 = (
            self.weighted_feat(stream_acc) * stream_acc
            + self.weighted_feat(stream_gyr) * stream_gyr
            + self.weighted_feat(stream_emg) * stream_emg
        )

        out = torch.cat((out_1, out_2, out_3), dim=-1)
        out = self.gating_net_1(out) * out

        out_task = self.fc(out)

        # Gradients flowing back from the classifier are reversed here.
        grl = GradientReversalLayer(self.lambda_grl)
        out_rev = grl(out, self.lambda_grl)

        # Average over time so the classifier sees one feature vector per window.
        out_rev = out_rev.permute(0, 2, 1)  # [batch_size, features, time]
        out_rev = torch.nn.functional.adaptive_avg_pool1d(out_rev, 1)
        out_rev = out_rev.squeeze(-1)  # [batch_size, features]

        out_classifier = self.fc_classifier(out_rev)

        return out_task, out_classifier, (h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)




In [26]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    """Root-mean-square error over all elements of ``output - target``."""

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        residual = output - target
        return residual.pow(2).mean().sqrt()

#prediction function
def RMSE_prediction(yhat_4, test_y, output_dim, print_losses=True):
    """Per-channel RMSE and Pearson correlation between predictions and targets.

    Both arrays are flattened to ``(shape[0] * shape[1], output_dim)`` (sizes
    taken from ``yhat_4``) and every one of the ``output_dim`` channels is
    scored separately.

    Args:
        yhat_4: predicted values, at least 2-D, last axis of size ``output_dim``.
        test_y: target values with the same number of elements.
        output_dim: number of output channels.
        print_losses: when true, print every RMSE, every correlation and their
            mean and sample standard deviation.

    Returns:
        ``(rmse, p, Z_1, Z_2, Z_3)``: ``rmse`` and ``p`` are arrays of length
        ``output_dim``; ``Z_1..Z_3`` are the flattened predictions of the first
        three channels (``None`` for a channel that does not exist).

    Note:
        The RMSE is computed with numpy, so NaN inputs yield NaN scores rather
        than an input-validation error.
    """
    n_samples = yhat_4.shape[0] * yhat_4.shape[1]

    targets = test_y.reshape((n_samples, output_dim))
    predictions = yhat_4.reshape((n_samples, output_dim))

    channels = range(output_dim)
    rmse = np.array([
        np.sqrt(np.mean((targets[:, k] - predictions[:, k]) ** 2)) for k in channels
    ])
    p = np.array([
        np.corrcoef(predictions[:, k], targets[:, k])[0, 1] for k in channels
    ])

    if print_losses:
        for value in rmse:
            print(value)
        print("\n")
        for value in p:
            print(value)
        # Sample standard deviation (ddof=1) across channels.
        print('Mean: %.3f' % np.mean(rmse), '+/- %.3f' % np.std(rmse, ddof=1))
        print('Mean: %.3f' % np.mean(p), '+/- %.3f' % np.std(p, ddof=1))

    # The return shape stays a 5-tuple for existing callers.
    first_three = [predictions[:, k] if k < output_dim else None for k in range(3)]
    return rmse, p, first_three[0], first_three[1], first_three[2]

def compute_biomechanical_loss(predicted_angles):
    """
    Penalty for predicted joint angles that leave their allowed range.

    The last axis of ``predicted_angles`` holds the three joints, in the order
    of the limits below. The result is the mean, over all elements, of the
    distance by which a value falls below its lower or above its upper limit
    (zero for values inside the range).
    """
    target_device = predicted_angles.device

    # Joint limits for each channel (in degrees)
    lower_bounds = torch.tensor([0, -90, -120], device=target_device)
    upper_bounds = torch.tensor([150, 180, 90], device=target_device)

    # Prepend singleton axes so the limits broadcast against the predictions.
    missing_axes = predicted_angles.dim() - lower_bounds.dim()
    if missing_axes > 0:
        broadcast_shape = (1,) * missing_axes + (-1,)
        lower_bounds = lower_bounds.reshape(broadcast_shape)
        upper_bounds = upper_bounds.reshape(broadcast_shape)

    below_range = torch.relu(lower_bounds - predicted_angles)
    above_range = torch.relu(predicted_angles - upper_bounds)

    return torch.mean(below_range + above_range)


class BoundRmseLoss(nn.Module):
    """RMSE loss plus a weighted joint-limit penalty.

    The loss is ``RMSE(output, target) + lambda_bio * L_bio(output)``, where
    ``L_bio`` is ``compute_biomechanical_loss``.
    """

    def __init__(self, lambda_bio=0.1):
        super().__init__()
        self.lambda_bio = lambda_bio
        self.rmse_loss = RMSELoss()

    def forward(self, output, target):
        data_term = self.rmse_loss(output, target)
        # The penalty depends on the predictions only, not on the target.
        limit_term = compute_biomechanical_loss(output)
        return data_term + self.lambda_bio * limit_term

In [27]:
# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Each batch is expected to unpack into five items, of which the first four
    are used as (accelerometer, gyroscope, target, EMG). For a ``teacher``
    model only the first of its three outputs is scored.

    Returns:
        (avg_loss, avg_pcc, avg_rmse): the loss, per-channel correlation and
        per-channel RMSE, each averaged over the number of batches. The number
        of channels is read from the global ``config.channels_joints``.
    """
    model.eval()

    n_channels = len(config.channels_joints)
    loss_sum = 0.0
    pcc_sum = np.zeros(n_channels)
    rmse_sum = np.zeros(n_channels)

    with torch.no_grad():
        for data_acc, data_gyr, target, data_EMG, _ in loader:
            prediction = model(
                data_acc.to(device).float(),
                data_gyr.to(device).float(),
                data_EMG.to(device).float(),
            )

            # The teacher returns (regression output, classifier output, hidden states).
            if isinstance(model, teacher):
                prediction, _classifier_out, _hidden = prediction

            loss = criterion(prediction, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(
                prediction.detach().cpu().numpy(),
                target.detach().cpu().numpy(),
                n_channels,
                print_losses=False,
            )
            loss_sum += loss.item()
            pcc_sum += batch_pcc
            rmse_sum += batch_rmse

    n_batches = len(loader)
    return loss_sum / n_batches, pcc_sum / n_batches, rmse_sum / n_batches



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):
    """Write a training checkpoint to ``filename`` with ``torch.save``.

    The checkpoint always holds the epoch, the model and optimizer state dicts,
    the train/validation losses and the train/validation channel-wise metrics.
    Test loss and test metrics are added when ``test_loss`` is given; history
    and curriculum schedule are added when they are non-empty.

    Args:
        channelwise_metrics: optional mapping with ``'train'``, ``'val'`` and
            ``'test'`` entries. When it is omitted, or an entry is missing, the
            corresponding checkpoint field is stored as ``None``.
    """
    # The parameter defaults to None, so it cannot be subscripted directly.
    metrics = channelwise_metrics if channelwise_metrics is not None else {}

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': metrics.get('train'),
        'val_channelwise_metrics': metrics.get('val'),
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = metrics.get('test')

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    # `epoch` is printed one-based.
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(
    device,
    train_loader,
    val_loader,
    test_loader,
    learn_rate,
    epochs,
    model,
    filename,
    loss_function,
    optimizer=None,
    l1_lambda=None,
    train_from_last_epoch=False,
    curriculum_loader=None,
    num_classes=12,  # Accepted for caller compatibility; not read inside this function
    alpha=1.0,       # Weighting factor for classification loss
):
    """Train ``model`` on a combined regression + subject-classification loss.

    Each epoch optimises ``loss_function(output, target) + alpha * CrossEntropy``
    over ``train_loader``, evaluates on the validation and test loaders with
    ``evaluate_model``, writes a checkpoint under
    ``/content/MyDrive/MyDrive/models/<filename>/`` and, whenever the validation
    loss improves, saves the model weights to the path ``filename``. Training
    stops early after 10 consecutive epochs without validation improvement.

    Args:
        device: torch device the model and batches are moved to.
        train_loader, val_loader, test_loader: iterables yielding
            ``(data_acc, data_gyr, target, data_EMG, subject_id)`` batches.
            Replaced every epoch by ``curriculum_loader.get_loaders()`` when a
            curriculum loader is given.
        learn_rate: learning rate for the default Adam optimizer.
        epochs: total number of epochs (upper bound of the epoch index).
        model: network whose forward pass returns three values; the first is
            the regression output and the second the classifier logits.
        filename: run name; used for the checkpoint directory, the checkpoint
            file names and the best-weights file.
        loss_function: regression criterion.
        optimizer: optional optimizer; Adam over ``model.parameters()`` if None.
        l1_lambda: optional L1 penalty weight over all model parameters.
        train_from_last_epoch: resume from the newest ``.pth`` checkpoint in the
            checkpoint directory when one exists.
        curriculum_loader: optional object providing ``update_epoch``,
            ``get_loaders`` and ``curriculum_schedule``.
        num_classes: unused here; kept so existing callers keep working.
        alpha: weight of the classification loss in the total loss.

    Returns:
        ``(model, train_losses, val_losses, test_losses, train_pccs, val_pccs,
        test_pccs, train_rmses, val_rmses, test_rmses)`` where every list holds
        one entry per completed epoch (including epochs restored on resume).
    """
    model.to(device)
    criterion_regression = loss_function
    criterion_classification = nn.CrossEntropyLoss(ignore_index=-1)

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    # Per-epoch metric lists, keyed exactly as they are stored in checkpoints.
    history_keys = [
        'train_losses', 'val_losses', 'test_losses',
        'train_pccs', 'val_pccs', 'test_pccs',
        'train_rmses', 'val_rmses', 'test_rmses',
        'train_pccs_channelwise', 'val_pccs_channelwise', 'test_pccs_channelwise',
        'train_rmses_channelwise', 'val_rmses_channelwise', 'test_rmses_channelwise',
    ]
    history = {key: [] for key in history_keys}

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        # Scan for the latest saved checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch number
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            # map_location lets a checkpoint written on GPU be resumed on CPU (and
            # vice versa). weights_only=False is needed because these checkpoints
            # store NumPy arrays and plain Python objects in addition to tensors;
            # only load checkpoints produced by this notebook.
            checkpoint = torch.load(
                os.path.join(checkpoint_path, latest_checkpoint),
                map_location=device,
                weights_only=False,
            )
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # The checkpoint stores the zero-based index of the epoch that already
            # finished, so resume at the following one instead of repeating it.
            last_epoch = checkpoint['epoch'] + 1

            # Load the history from checkpoint
            for key in history_keys:
                history[key] = checkpoint['history'][key]

            # Only restore the schedule when there is a curriculum loader to receive it.
            if curriculum_loader is not None and 'curriculum_schedule' in checkpoint:
                curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(last_epoch, epochs):
        model.train()

        if curriculum_loader:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()

        # Running sums over the batches of this epoch
        epoch_train_total_loss = 0.0
        epoch_train_regression_loss = 0.0
        epoch_train_classification_loss = 0.0

        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        for i, (data_acc, data_gyr, target, data_EMG, subject_id) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            # Move data to device
            data_acc = data_acc.to(device).float()
            data_gyr = data_gyr.to(device).float()
            data_EMG = data_EMG.to(device).float()
            target = target.to(device).float()
            subject_id = subject_id.to(device).long() - 1  # Adjust subject IDs to start from 0

            # Forward pass
            output, out_classifier, _ = model(data_acc, data_gyr, data_EMG)

            # Regression and subject-classification losses
            loss_regression = criterion_regression(output, target)
            loss_classification = criterion_classification(out_classifier, subject_id)

            # Combine losses
            total_loss = loss_regression + alpha * loss_classification

            # Apply L1 regularization if specified
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss += l1_lambda * l1_norm

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()

            # Metrics are computed on detached CPU copies so they never touch the graph
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(
                output.detach().cpu().numpy(),
                target.detach().cpu().numpy(),
                len(config.channels_joints),
                print_losses=False
            )

            # Update metrics
            epoch_train_total_loss += total_loss.item()
            epoch_train_regression_loss += loss_regression.item()
            epoch_train_classification_loss += loss_classification.item()

            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        # Calculate average losses
        num_train_batches = len(train_loader)
        avg_train_total_loss = epoch_train_total_loss / num_train_batches
        avg_train_regression_loss = epoch_train_regression_loss / num_train_batches
        avg_train_classification_loss = epoch_train_classification_loss / num_train_batches

        avg_train_pcc = epoch_train_pcc / num_train_batches
        avg_train_rmse = epoch_train_rmse / num_train_batches

        history['train_losses'].append(avg_train_total_loss)
        history['train_pccs'].append(np.mean(avg_train_pcc))  # Overall average PCC
        history['train_rmses'].append(np.mean(avg_train_rmse))  # Overall average RMSE
        history['train_pccs_channelwise'].append(avg_train_pcc)  # Per channel
        history['train_rmses_channelwise'].append(avg_train_rmse)  # Per channel

        # Evaluate on the validation and test sets (regression criterion only)
        eval_results = {}
        for split, loader in (('val', val_loader), ('test', test_loader)):
            split_loss, split_pcc, split_rmse = evaluate_model(
                device,
                model,
                loader,
                criterion_regression
            )
            history[f'{split}_losses'].append(split_loss)
            history[f'{split}_pccs'].append(np.mean(split_pcc))  # Overall average PCC
            history[f'{split}_rmses'].append(np.mean(split_rmse))  # Overall average RMSE
            history[f'{split}_pccs_channelwise'].append(split_pcc)  # Per channel
            history[f'{split}_rmses_channelwise'].append(split_rmse)  # Per channel
            eval_results[split] = (split_loss, split_pcc, split_rmse)

        avg_val_total_loss, avg_val_pcc, avg_val_rmse = eval_results['val']
        avg_test_total_loss, avg_test_pcc, avg_test_rmse = eval_results['test']

        print(f"Epoch: {epoch + 1}, Training Total Loss: {avg_train_total_loss:.4f}, Validation Total Loss: {avg_val_total_loss:.4f}, Test Total Loss: {avg_test_total_loss:.4f}")
        print(f"Training Regression Loss: {avg_train_regression_loss:.4f}, Validation Regression Loss: {avg_val_total_loss:.4f}, Test Regression Loss: {avg_test_total_loss:.4f}")
        print(f"Training Classification Loss: {avg_train_classification_loss:.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse):.4f}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc):.4f}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Make sure the checkpoint directory exists (no-op when it already does)
        os.makedirs(checkpoint_path, exist_ok=True)

        # The classification loss is stored as the latest epoch's scalar, as before
        history['train_classification_loss'] = avg_train_classification_loss

        save_checkpoint(
            model,
            optimizer,
            epoch,
            os.path.join(checkpoint_path, f"{filename}_epoch_{epoch + 1}.pth"),
            train_loss=avg_train_total_loss,
            val_loss=avg_val_total_loss,
            test_loss=avg_test_total_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=history,  # Save history in the checkpoint
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader else None  # Save curriculum schedule
        )

        # Early stopping logic: keep the best-validation weights in `filename`
        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    return (
        model,
        history['train_losses'], history['val_losses'], history['test_losses'],
        history['train_pccs'], history['val_pccs'], history['test_pccs'],
        history['train_rmses'], history['val_rmses'], history['test_rmses'],
    )







In [28]:
# @title Helper Functions


# Build a teacher model (hyperparameter defaults come from the signature) and optionally load saved base weights
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100,num_subjects=None,lambda_grl=None):
    """Instantiate a ``teacher`` network, optionally initialised from a saved state dict.

    The three ``input_*`` values are passed positionally to ``teacher``;
    ``drop_prob``, ``w``, ``num_subjects`` and ``lambda_grl`` are forwarded as
    keyword arguments. When ``base_weights_path`` is truthy, the state dict at
    that path is loaded into the new model before it is returned.
    """
    teacher_kwargs = {
        'drop_prob': drop_prob,
        'w': w,
        'num_subjects': num_subjects,
        'lambda_grl': lambda_grl,
    }
    net = teacher(input_acc, input_gyr, input_emg, **teacher_kwargs)

    # Nothing to restore: hand back the freshly initialised network.
    if not base_weights_path:
        return net

    # Start from the shared base weights instead of the random initialisation.
    base_state = torch.load(base_weights_path)
    net.load_state_dict(base_state)
    return net


In [29]:

import os
import h5py
import csv
from tqdm.notebook import tqdm
import pandas as pd
def create_curriculum_schedule(all_subjects, num_subjects, epochs):
    """Return one ``(epoch, subjects)`` pair per epoch in ``range(epochs)``.

    For every epoch, ``num_subjects`` entries are drawn without replacement
    from ``all_subjects`` using ``random.sample`` and stored in sorted order.
    """
    schedule = []
    for epoch_idx in range(epochs):
        drawn = random.sample(all_subjects, num_subjects)
        drawn.sort()
        schedule.append((epoch_idx, drawn))
    return schedule

# Define the list of all subjects
all_subjects = ['subject_1','subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6',
                'subject_7', 'subject_8', 'subject_9', 'subject_10',
                'subject_11', 'subject_12', 'subject_13']

# Define the base model and save its weights
# NOTE(review): the saved base weights are not passed to create_teacher_model below,
# so every configuration starts from its own fresh initialisation — confirm this is intended.
input_acc, input_gyr, input_emg = 18, 18, 3  # Example inputs based on your setup
base_model = teacher(input_acc, input_gyr, input_emg)
base_weights_path = 'base_teacher_weights.pth'
torch.save(base_model.state_dict(), base_weights_path)
window_overlap = 0
batch_size = 64
curriculum_epochs = 20

# Model configurations for the alpha x lambda sweep with one held-out test subject
model_configs = {}

test_subject = 'subject_1'
# Compare names for equality: `subject not in test_subject` is a substring test on a
# string and would also drop e.g. 'subject_1' when the test subject is 'subject_10'.
train_subjects = [subject for subject in all_subjects if subject != test_subject]

# Define the number of subjects/classes
num_classes = 12

# Sweep over the classification-loss weight (alpha) and the value forwarded to the
# model as lambda_grl (lamda); one configuration per combination.
alpha_values = [0.1, 0.2, 0.3, 0.4, 0.5]
lambda_values = [0.5, 0.75, 1, 1.25, 1.5, 2]
for lamda in lambda_values:
    for alpha in alpha_values:
        model_name = f'TeacherModel_DomainInvariant_alpha_{alpha}_lambda_{lamda}_wl{100}_ol{75}'
        model_configs[model_name] = {
            'model': create_teacher_model(
                input_acc=input_acc,
                input_gyr=input_gyr,
                input_emg=input_emg,
                w=100,
                num_subjects=num_classes,  # Ensure the model includes the classification head
                lambda_grl=lamda
            ),
            'loss': RMSELoss(),
            'loaders': create_base_data_loaders(
                config=config,
                train_subjects=train_subjects,
                test_subjects=[test_subject],
                window_length=100,
                window_overlap=75,
                batch_size=batch_size
            ),
            'epochs': curriculum_epochs,
            'use_curriculum': False,
            'alpha': alpha,  # Weighting factor for classification loss
            'num_classes': num_classes,  # Number of classes for the classification task
        }

# Output the model configurations for debugging (optional)
for model_name, model_conf in model_configs.items():
    print(f"Model: {model_name}")


Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping re

In [None]:
 # @title run models

# Ask PyTorch to release its cached GPU memory before the training runs below start
torch.cuda.empty_cache()

def ask_run():
    """Prompt on stdin until the user answers yes or no.

    Returns True for 'yes'/'y' and False for 'no'/'n' (surrounding whitespace
    and letter case are ignored); any other answer prints a hint and asks again.
    """
    answers = {'yes': True, 'y': True, 'no': False, 'n': False}
    while True:
        reply = input("Do you want to run models? (yes/no): ").strip().lower()
        if reply in answers:
            return answers[reply]
        # Unrecognised answer: explain and prompt again.
        print("Invalid input. Please enter 'yes' or 'no'.")

run = ask_run()

if run:
    for model_name, model_config in model_configs.items():
        # Per-model settings; anything missing from the config falls back to a default.
        model = model_config['model']
        loss_function = model_config['loss']  # Regression criterion handed to train_teacher
        epochs = model_config.get("epochs", 100)
        device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        learn_rate = model_config.get("learn_rate", 0.001)
        use_curriculum = model_config.get("use_curriculum", False)

        optimizer = model_config.get("optimizer", None)
        l1_lambda = model_config.get("l1_lambda", None)
        alpha = model_config.get('alpha', 1.0)
        num_classes = model_config.get('num_classes', 12)  # Default to 12 if not specified

        print(f"Running model: {model_name} with alpha: {alpha}")

        if use_curriculum:
            # The curriculum loader supplies fresh loaders each epoch, so the
            # static loaders are only placeholders here.
            curriculum_loader = model_config['loader']
            train_loader, val_loader, test_loader = None, None, None
        else:
            # Static (train_loader, val_loader, test_loader) tuple
            train_loader, val_loader, test_loader = model_config['loaders']

        # Single call for both modes; without a curriculum the loader argument is None,
        # which is also train_teacher's default.
        (
            model,
            train_losses, val_losses, test_losses,
            train_pccs, val_pccs, test_pccs,
            train_rmses, val_rmses, test_rmses,
        ) = train_teacher(
            device=device,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            learn_rate=learn_rate,
            epochs=epochs,
            model=model,
            filename=model_name,
            loss_function=loss_function,
            curriculum_loader=curriculum_loader if use_curriculum else None,
            optimizer=optimizer,
            l1_lambda=l1_lambda,
            train_from_last_epoch=model_config.get("train_from_last_epoch", False),
            alpha=alpha,
            num_classes=num_classes,
        )

        print(f"Finished training for {model_name}.")


Do you want to run models? (yes/no): yes
Running model: TeacherModel_DomainInvariant_alpha_0.1_lambda_0.5_wl100_ol75 with alpha: 0.1
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.2548, Validation Total Loss: 11.1596, Test Total Loss: 21.4153
Training Regression Loss: 19.0231, Validation Regression Loss: 11.1596, Test Regression Loss: 21.4153
Training Classification Loss: 2.3179
Training RMSE: 18.5470, Validation RMSE: 10.7235, Test RMSE: 20.1029
Training PCC: 0.7913, Validation PCC: 0.9461, Test PCC: 0.6091
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 10.9110, Validation Total Loss: 9.5638, Test Total Loss: 20.7037
Training Regression Loss: 10.7365, Validation Regression Loss: 9.5638, Test Regression Loss: 20.7037
Training Classification Loss: 1.7448
Training RMSE: 10.3365, Validation RMSE: 9.1727, Test RMSE: 20.1311
Training PCC: 0.9532, Validation PCC: 0.9665, Test PCC: 0.6586
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.5587, Validation Total Loss: 8.7143, Test Total Loss: 19.5647
Training Regression Loss: 9.4149, Validation Regression Loss: 8.7143, Test Regression Loss: 19.5647
Training Classification Loss: 1.4382
Training RMSE: 8.9902, Validation RMSE: 8.3504, Test RMSE: 18.2970
Training PCC: 0.9654, Validation PCC: 0.9726, Test PCC: 0.6619
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 8.6593, Validation Total Loss: 8.2110, Test Total Loss: 21.4112
Training Regression Loss: 8.5351, Validation Regression Loss: 8.2110, Test Regression Loss: 21.4112
Training Classification Loss: 1.2419
Training RMSE: 8.1680, Validation RMSE: 7.6747, Test RMSE: 19.6551
Training PCC: 0.9710, Validation PCC: 0.9761, Test PCC: 0.6799
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.1354, Validation Total Loss: 7.8292, Test Total Loss: 20.6012
Training Regression Loss: 8.0207, Validation Regression Loss: 7.8292, Test Regression Loss: 20.6012
Training Classification Loss: 1.1466
Training RMSE: 7.6647, Validation RMSE: 7.5019, Test RMSE: 19.5454
Training PCC: 0.9741, Validation PCC: 0.9795, Test PCC: 0.6981
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 7.7122, Validation Total Loss: 7.2895, Test Total Loss: 22.0343
Training Regression Loss: 7.6042, Validation Regression Loss: 7.2895, Test Regression Loss: 22.0343
Training Classification Loss: 1.0797
Training RMSE: 7.2672, Validation RMSE: 6.8230, Test RMSE: 20.4226
Training PCC: 0.9767, Validation PCC: 0.9820, Test PCC: 0.6895
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.2717, Validation Total Loss: 7.2425, Test Total Loss: 18.8983
Training Regression Loss: 7.1702, Validation Regression Loss: 7.2425, Test Regression Loss: 18.8983
Training Classification Loss: 1.0149
Training RMSE: 6.8625, Validation RMSE: 6.9871, Test RMSE: 17.4494
Training PCC: 0.9791, Validation PCC: 0.9830, Test PCC: 0.6952
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.1834, Validation Total Loss: 6.2860, Test Total Loss: 18.3986
Training Regression Loss: 7.0820, Validation Regression Loss: 6.2860, Test Regression Loss: 18.3986
Training Classification Loss: 1.0135
Training RMSE: 6.7571, Validation RMSE: 6.0052, Test RMSE: 17.0931
Training PCC: 0.9796, Validation PCC: 0.9846, Test PCC: 0.7275
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 6.9615, Validation Total Loss: 6.5750, Test Total Loss: 19.6813
Training Regression Loss: 6.8583, Validation Regression Loss: 6.5750, Test Regression Loss: 19.6813
Training Classification Loss: 1.0318
Training RMSE: 6.5536, Validation RMSE: 6.2024, Test RMSE: 18.3333
Training PCC: 0.9808, Validation PCC: 0.9841, Test PCC: 0.6982
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6.7279, Validation Total Loss: 6.0416, Test Total Loss: 18.5861
Training Regression Loss: 6.6212, Validation Regression Loss: 6.0416, Test Regression Loss: 18.5861
Training Classification Loss: 1.0671
Training RMSE: 6.3183, Validation RMSE: 5.7313, Test RMSE: 17.2479
Training PCC: 0.9819, Validation PCC: 0.9852, Test PCC: 0.7282
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.4678, Validation Total Loss: 5.8992, Test Total Loss: 20.0637
Training Regression Loss: 6.3559, Validation Regression Loss: 5.8992, Test Regression Loss: 20.0637
Training Classification Loss: 1.1195
Training RMSE: 6.0732, Validation RMSE: 5.6577, Test RMSE: 18.9161
Training PCC: 0.9833, Validation PCC: 0.9870, Test PCC: 0.6871
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.3139, Validation Total Loss: 5.7316, Test Total Loss: 19.7269
Training Regression Loss: 6.1932, Validation Regression Loss: 5.7316, Test Regression Loss: 19.7269
Training Classification Loss: 1.2080
Training RMSE: 5.9272, Validation RMSE: 5.3913, Test RMSE: 18.3521
Training PCC: 0.9839, Validation PCC: 0.9877, Test PCC: 0.6948
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.1619, Validation Total Loss: 5.6555, Test Total Loss: 19.1208
Training Regression Loss: 6.0299, Validation Regression Loss: 5.6555, Test Regression Loss: 19.1208
Training Classification Loss: 1.3202
Training RMSE: 5.7714, Validation RMSE: 5.4083, Test RMSE: 17.8443
Training PCC: 0.9847, Validation PCC: 0.9871, Test PCC: 0.7055
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.0493, Validation Total Loss: 5.7763, Test Total Loss: 18.9780
Training Regression Loss: 5.9073, Validation Regression Loss: 5.7763, Test Regression Loss: 18.9780
Training Classification Loss: 1.4193
Training RMSE: 5.6597, Validation RMSE: 5.4596, Test RMSE: 17.9765
Training PCC: 0.9853, Validation PCC: 0.9887, Test PCC: 0.6995
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.0820, Validation Total Loss: 5.5179, Test Total Loss: 20.0041
Training Regression Loss: 5.9324, Validation Regression Loss: 5.5179, Test Regression Loss: 20.0041
Training Classification Loss: 1.4968
Training RMSE: 5.6736, Validation RMSE: 5.2748, Test RMSE: 18.7806
Training PCC: 0.9853, Validation PCC: 0.9884, Test PCC: 0.6934
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 5.9522, Validation Total Loss: 5.5142, Test Total Loss: 19.7751
Training Regression Loss: 5.7962, Validation Regression Loss: 5.5142, Test Regression Loss: 19.7751
Training Classification Loss: 1.5594
Training RMSE: 5.5470, Validation RMSE: 5.2462, Test RMSE: 18.3099
Training PCC: 0.9860, Validation PCC: 0.9890, Test PCC: 0.6887
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.8670, Validation Total Loss: 5.2720, Test Total Loss: 20.1709
Training Regression Loss: 5.6990, Validation Regression Loss: 5.2720, Test Regression Loss: 20.1709
Training Classification Loss: 1.6803
Training RMSE: 5.4578, Validation RMSE: 5.1096, Test RMSE: 18.9064
Training PCC: 0.9862, Validation PCC: 0.9890, Test PCC: 0.6882
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.6775, Validation Total Loss: 4.9815, Test Total Loss: 19.5559
Training Regression Loss: 5.4979, Validation Regression Loss: 4.9815, Test Regression Loss: 19.5559
Training Classification Loss: 1.7963
Training RMSE: 5.2679, Validation RMSE: 4.7457, Test RMSE: 18.2358
Training PCC: 0.9870, Validation PCC: 0.9901, Test PCC: 0.6976
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.5947, Validation Total Loss: 5.1500, Test Total Loss: 19.9615
Training Regression Loss: 5.4071, Validation Regression Loss: 5.1500, Test Regression Loss: 19.9615
Training Classification Loss: 1.8761
Training RMSE: 5.1995, Validation RMSE: 4.9922, Test RMSE: 18.7827
Training PCC: 0.9873, Validation PCC: 0.9899, Test PCC: 0.6924
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.5815, Validation Total Loss: 5.1490, Test Total Loss: 19.7409
Training Regression Loss: 5.3883, Validation Regression Loss: 5.1490, Test Regression Loss: 19.7409
Training Classification Loss: 1.9320
Training RMSE: 5.1707, Validation RMSE: 4.9207, Test RMSE: 18.3617
Training PCC: 0.9875, Validation PCC: 0.9893, Test PCC: 0.6898
Checkpoint saved for epoch 20
Total training time: 2951.99 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.1_lambda_0.5_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.2_lambda_0.5_wl100_ol75 with alpha: 0.2
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.3581, Validation Total Loss: 11.7214, Test Total Loss: 21.7441
Training Regression Loss: 18.8768, Validation Regression Loss: 11.7214, Test Regression Loss: 21.7441
Training Classification Loss: 2.4065
Training RMSE: 18.3709, Validation RMSE: 11.2669, Test RMSE: 20.1916
Training PCC: 0.8014, Validation PCC: 0.9422, Test PCC: 0.6079
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 10.8870, Validation Total Loss: 9.8004, Test Total Loss: 21.1227
Training Regression Loss: 10.5080, Validation Regression Loss: 9.8004, Test Regression Loss: 21.1227
Training Classification Loss: 1.8948
Training RMSE: 10.0905, Validation RMSE: 9.2799, Test RMSE: 19.7118
Training PCC: 0.9553, Validation PCC: 0.9638, Test PCC: 0.6411
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.5547, Validation Total Loss: 8.6757, Test Total Loss: 21.5598
Training Regression Loss: 9.2091, Validation Regression Loss: 8.6757, Test Regression Loss: 21.5598
Training Classification Loss: 1.7278
Training RMSE: 8.7969, Validation RMSE: 8.2455, Test RMSE: 19.7524
Training PCC: 0.9668, Validation PCC: 0.9701, Test PCC: 0.7234
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 8.9213, Validation Total Loss: 8.7706, Test Total Loss: 19.3615
Training Regression Loss: 8.5921, Validation Regression Loss: 8.7706, Test Regression Loss: 19.3615
Training Classification Loss: 1.6462
Training RMSE: 8.2018, Validation RMSE: 8.1896, Test RMSE: 17.8233
Training PCC: 0.9712, Validation PCC: 0.9703, Test PCC: 0.7166
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.3807, Validation Total Loss: 7.7193, Test Total Loss: 21.5770
Training Regression Loss: 8.0468, Validation Regression Loss: 7.7193, Test Regression Loss: 21.5770
Training Classification Loss: 1.6695
Training RMSE: 7.6908, Validation RMSE: 7.2523, Test RMSE: 19.7975
Training PCC: 0.9744, Validation PCC: 0.9764, Test PCC: 0.6795
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 7.8661, Validation Total Loss: 8.5684, Test Total Loss: 19.0140
Training Regression Loss: 7.5074, Validation Regression Loss: 8.5684, Test Regression Loss: 19.0140
Training Classification Loss: 1.7934
Training RMSE: 7.1733, Validation RMSE: 7.7847, Test RMSE: 17.4927
Training PCC: 0.9774, Validation PCC: 0.9748, Test PCC: 0.7204
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.6500, Validation Total Loss: 7.1996, Test Total Loss: 21.6464
Training Regression Loss: 7.2757, Validation Regression Loss: 7.1996, Test Regression Loss: 21.6464
Training Classification Loss: 1.8714
Training RMSE: 6.9390, Validation RMSE: 6.8045, Test RMSE: 20.1576
Training PCC: 0.9786, Validation PCC: 0.9795, Test PCC: 0.6656
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.3277, Validation Total Loss: 7.1124, Test Total Loss: 20.2711
Training Regression Loss: 6.9311, Validation Regression Loss: 7.1124, Test Regression Loss: 20.2711
Training Classification Loss: 1.9828
Training RMSE: 6.6235, Validation RMSE: 6.6946, Test RMSE: 18.7286
Training PCC: 0.9805, Validation PCC: 0.9812, Test PCC: 0.7249
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.0682, Validation Total Loss: 6.6690, Test Total Loss: 20.7669
Training Regression Loss: 6.6554, Validation Regression Loss: 6.6690, Test Regression Loss: 20.7669
Training Classification Loss: 2.0643
Training RMSE: 6.3745, Validation RMSE: 6.3614, Test RMSE: 19.2399
Training PCC: 0.9819, Validation PCC: 0.9811, Test PCC: 0.7191
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 6.9226, Validation Total Loss: 6.1860, Test Total Loss: 20.1182
Training Regression Loss: 6.4886, Validation Regression Loss: 6.1860, Test Regression Loss: 20.1182
Training Classification Loss: 2.1698
Training RMSE: 6.2089, Validation RMSE: 5.8392, Test RMSE: 18.2380
Training PCC: 0.9827, Validation PCC: 0.9842, Test PCC: 0.7239
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.9620, Validation Total Loss: 6.7681, Test Total Loss: 19.4529
Training Regression Loss: 6.4956, Validation Regression Loss: 6.7681, Test Regression Loss: 19.4529
Training Classification Loss: 2.3318
Training RMSE: 6.1820, Validation RMSE: 6.3029, Test RMSE: 18.1017
Training PCC: 0.9827, Validation PCC: 0.9829, Test PCC: 0.7131
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.6846, Validation Total Loss: 6.3209, Test Total Loss: 19.9649
Training Regression Loss: 6.2258, Validation Regression Loss: 6.3209, Test Regression Loss: 19.9649
Training Classification Loss: 2.2941
Training RMSE: 5.9552, Validation RMSE: 5.9724, Test RMSE: 18.3486
Training PCC: 0.9838, Validation PCC: 0.9830, Test PCC: 0.7105
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.5461, Validation Total Loss: 5.6667, Test Total Loss: 20.2122
Training Regression Loss: 6.0877, Validation Regression Loss: 5.6667, Test Regression Loss: 20.2122
Training Classification Loss: 2.2918
Training RMSE: 5.8302, Validation RMSE: 5.3964, Test RMSE: 18.6091
Training PCC: 0.9845, Validation PCC: 0.9861, Test PCC: 0.7313
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.2906, Validation Total Loss: 5.7950, Test Total Loss: 19.6710
Training Regression Loss: 5.8308, Validation Regression Loss: 5.7950, Test Regression Loss: 19.6710
Training Classification Loss: 2.2989
Training RMSE: 5.5884, Validation RMSE: 5.4562, Test RMSE: 18.2384
Training PCC: 0.9856, Validation PCC: 0.9860, Test PCC: 0.6988
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.2102, Validation Total Loss: 5.6558, Test Total Loss: 20.4706
Training Regression Loss: 5.7477, Validation Regression Loss: 5.6558, Test Regression Loss: 20.4706
Training Classification Loss: 2.3121
Training RMSE: 5.5195, Validation RMSE: 5.3392, Test RMSE: 18.8733
Training PCC: 0.9859, Validation PCC: 0.9862, Test PCC: 0.7034
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.1068, Validation Total Loss: 5.7101, Test Total Loss: 19.7621
Training Regression Loss: 5.6247, Validation Regression Loss: 5.7101, Test Regression Loss: 19.7621
Training Classification Loss: 2.4101
Training RMSE: 5.3977, Validation RMSE: 5.4147, Test RMSE: 18.2543
Training PCC: 0.9865, Validation PCC: 0.9853, Test PCC: 0.7330
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 5.9608, Validation Total Loss: 6.3962, Test Total Loss: 20.0403
Training Regression Loss: 5.4742, Validation Regression Loss: 6.3962, Test Regression Loss: 20.0403
Training Classification Loss: 2.4331
Training RMSE: 5.2658, Validation RMSE: 5.9296, Test RMSE: 18.8165
Training PCC: 0.9872, Validation PCC: 0.9837, Test PCC: 0.7216
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 5.9920, Validation Total Loss: 5.6384, Test Total Loss: 20.9221
Training Regression Loss: 5.4933, Validation Regression Loss: 5.6384, Test Regression Loss: 20.9221
Training Classification Loss: 2.4937
Training RMSE: 5.2766, Validation RMSE: 5.3655, Test RMSE: 19.3609
Training PCC: 0.9871, Validation PCC: 0.9867, Test PCC: 0.6885
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.8645, Validation Total Loss: 5.4355, Test Total Loss: 20.4392
Training Regression Loss: 5.3582, Validation Regression Loss: 5.4355, Test Regression Loss: 20.4392
Training Classification Loss: 2.5316
Training RMSE: 5.1585, Validation RMSE: 5.1587, Test RMSE: 18.8135
Training PCC: 0.9877, Validation PCC: 0.9872, Test PCC: 0.7154
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.7238, Validation Total Loss: 5.7168, Test Total Loss: 19.6131
Training Regression Loss: 5.2293, Validation Regression Loss: 5.7168, Test Regression Loss: 19.6131
Training Classification Loss: 2.4726
Training RMSE: 5.0403, Validation RMSE: 5.3488, Test RMSE: 18.1621
Training PCC: 0.9883, Validation PCC: 0.9876, Test PCC: 0.7182
Checkpoint saved for epoch 20
Total training time: 2954.47 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.2_lambda_0.5_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.3_lambda_0.5_wl100_ol75 with alpha: 0.3
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.4974, Validation Total Loss: 11.9989, Test Total Loss: 22.6528
Training Regression Loss: 18.7725, Validation Regression Loss: 11.9989, Test Regression Loss: 22.6528
Training Classification Loss: 2.4162
Training RMSE: 18.2955, Validation RMSE: 11.6559, Test RMSE: 21.6277
Training PCC: 0.8003, Validation PCC: 0.9426, Test PCC: 0.6204
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.2727, Validation Total Loss: 9.9340, Test Total Loss: 22.3846
Training Regression Loss: 10.6552, Validation Regression Loss: 9.9340, Test Regression Loss: 22.3846
Training Classification Loss: 2.0584
Training RMSE: 10.2473, Validation RMSE: 9.3297, Test RMSE: 20.3499
Training PCC: 0.9538, Validation PCC: 0.9644, Test PCC: 0.6816
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.8583, Validation Total Loss: 8.8306, Test Total Loss: 21.2816
Training Regression Loss: 9.2741, Validation Regression Loss: 8.8306, Test Regression Loss: 21.2816
Training Classification Loss: 1.9474
Training RMSE: 8.8849, Validation RMSE: 8.4940, Test RMSE: 20.1483
Training PCC: 0.9655, Validation PCC: 0.9715, Test PCC: 0.6736
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.2415, Validation Total Loss: 7.9168, Test Total Loss: 20.1562
Training Regression Loss: 8.6509, Validation Regression Loss: 7.9168, Test Regression Loss: 20.1562
Training Classification Loss: 1.9688
Training RMSE: 8.2579, Validation RMSE: 7.4257, Test RMSE: 18.4717
Training PCC: 0.9704, Validation PCC: 0.9752, Test PCC: 0.7234
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.5766, Validation Total Loss: 7.8068, Test Total Loss: 20.7099
Training Regression Loss: 7.9582, Validation Regression Loss: 7.8068, Test Regression Loss: 20.7099
Training Classification Loss: 2.0614
Training RMSE: 7.6086, Validation RMSE: 7.4495, Test RMSE: 19.4060
Training PCC: 0.9745, Validation PCC: 0.9744, Test PCC: 0.6359
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.3347, Validation Total Loss: 7.2321, Test Total Loss: 20.5091
Training Regression Loss: 7.6791, Validation Regression Loss: 7.2321, Test Regression Loss: 20.5091
Training Classification Loss: 2.1852
Training RMSE: 7.3485, Validation RMSE: 6.7882, Test RMSE: 19.3080
Training PCC: 0.9762, Validation PCC: 0.9801, Test PCC: 0.6775
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.9413, Validation Total Loss: 7.2581, Test Total Loss: 19.9308
Training Regression Loss: 7.2653, Validation Regression Loss: 7.2581, Test Regression Loss: 19.9308
Training Classification Loss: 2.2532
Training RMSE: 6.9582, Validation RMSE: 6.7870, Test RMSE: 18.4463
Training PCC: 0.9785, Validation PCC: 0.9801, Test PCC: 0.6600
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.6017, Validation Total Loss: 6.7410, Test Total Loss: 19.9454
Training Regression Loss: 6.9016, Validation Regression Loss: 6.7410, Test Regression Loss: 19.9454
Training Classification Loss: 2.3338
Training RMSE: 6.6088, Validation RMSE: 6.3767, Test RMSE: 18.3371
Training PCC: 0.9803, Validation PCC: 0.9821, Test PCC: 0.6801
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.3888, Validation Total Loss: 6.5781, Test Total Loss: 21.1345
Training Regression Loss: 6.6776, Validation Regression Loss: 6.5781, Test Regression Loss: 21.1345
Training Classification Loss: 2.3707
Training RMSE: 6.3926, Validation RMSE: 6.2229, Test RMSE: 19.3450
Training PCC: 0.9816, Validation PCC: 0.9825, Test PCC: 0.6799
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.2005, Validation Total Loss: 6.2459, Test Total Loss: 20.5183
Training Regression Loss: 6.4713, Validation Regression Loss: 6.2459, Test Regression Loss: 20.5183
Training Classification Loss: 2.4305
Training RMSE: 6.1907, Validation RMSE: 5.8817, Test RMSE: 19.0495
Training PCC: 0.9827, Validation PCC: 0.9839, Test PCC: 0.7037
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.9894, Validation Total Loss: 6.1648, Test Total Loss: 20.3222
Training Regression Loss: 6.2737, Validation Regression Loss: 6.1648, Test Regression Loss: 20.3222
Training Classification Loss: 2.3857
Training RMSE: 6.0130, Validation RMSE: 5.7878, Test RMSE: 18.8481
Training PCC: 0.9835, Validation PCC: 0.9851, Test PCC: 0.6896
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.9654, Validation Total Loss: 6.0444, Test Total Loss: 20.8423
Training Regression Loss: 6.2217, Validation Regression Loss: 6.0444, Test Regression Loss: 20.8423
Training Classification Loss: 2.4791
Training RMSE: 5.9445, Validation RMSE: 5.7394, Test RMSE: 19.4359
Training PCC: 0.9839, Validation PCC: 0.9847, Test PCC: 0.6755
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.7785, Validation Total Loss: 5.8084, Test Total Loss: 19.1412
Training Regression Loss: 6.0198, Validation Regression Loss: 5.8084, Test Regression Loss: 19.1412
Training Classification Loss: 2.5292
Training RMSE: 5.7758, Validation RMSE: 5.4482, Test RMSE: 17.9631
Training PCC: 0.9846, Validation PCC: 0.9860, Test PCC: 0.6865
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.6880, Validation Total Loss: 5.6456, Test Total Loss: 20.9452
Training Regression Loss: 5.9308, Validation Regression Loss: 5.6456, Test Regression Loss: 20.9452
Training Classification Loss: 2.5239
Training RMSE: 5.6857, Validation RMSE: 5.3292, Test RMSE: 19.5890
Training PCC: 0.9853, Validation PCC: 0.9869, Test PCC: 0.6725
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.5134, Validation Total Loss: 5.6185, Test Total Loss: 20.6333
Training Regression Loss: 5.7339, Validation Regression Loss: 5.6185, Test Regression Loss: 20.6333
Training Classification Loss: 2.5984
Training RMSE: 5.4989, Validation RMSE: 5.2928, Test RMSE: 19.2529
Training PCC: 0.9861, Validation PCC: 0.9866, Test PCC: 0.6676
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.4695, Validation Total Loss: 5.6822, Test Total Loss: 21.3140
Training Regression Loss: 5.6931, Validation Regression Loss: 5.6822, Test Regression Loss: 21.3140
Training Classification Loss: 2.5879
Training RMSE: 5.4548, Validation RMSE: 5.3157, Test RMSE: 19.6991
Training PCC: 0.9863, Validation PCC: 0.9872, Test PCC: 0.6963
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 6.4678, Validation Total Loss: 5.7614, Test Total Loss: 21.2983
Training Regression Loss: 5.6693, Validation Regression Loss: 5.7614, Test Regression Loss: 21.2983
Training Classification Loss: 2.6616
Training RMSE: 5.4195, Validation RMSE: 5.4116, Test RMSE: 19.9555
Training PCC: 0.9864, Validation PCC: 0.9873, Test PCC: 0.6859
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 6.2996, Validation Total Loss: 5.3153, Test Total Loss: 20.6331
Training Regression Loss: 5.4879, Validation Regression Loss: 5.3153, Test Regression Loss: 20.6331
Training Classification Loss: 2.7058
Training RMSE: 5.2736, Validation RMSE: 5.0053, Test RMSE: 19.1658
Training PCC: 0.9871, Validation PCC: 0.9883, Test PCC: 0.6821
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 6.2856, Validation Total Loss: 5.1894, Test Total Loss: 19.9974
Training Regression Loss: 5.4109, Validation Regression Loss: 5.1894, Test Regression Loss: 19.9974
Training Classification Loss: 2.9158
Training RMSE: 5.2082, Validation RMSE: 4.9362, Test RMSE: 18.3611
Training PCC: 0.9874, Validation PCC: 0.9891, Test PCC: 0.6956
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 6.0793, Validation Total Loss: 5.8521, Test Total Loss: 21.2990
Training Regression Loss: 5.2638, Validation Regression Loss: 5.8521, Test Regression Loss: 21.2990
Training Classification Loss: 2.7182
Training RMSE: 5.0652, Validation RMSE: 5.4947, Test RMSE: 20.1914
Training PCC: 0.9879, Validation PCC: 0.9880, Test PCC: 0.6699
Checkpoint saved for epoch 20
Total training time: 2967.14 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.3_lambda_0.5_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.4_lambda_0.5_wl100_ol75 with alpha: 0.4
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.8737, Validation Total Loss: 11.4263, Test Total Loss: 23.7191
Training Regression Loss: 18.9012, Validation Regression Loss: 11.4263, Test Regression Loss: 23.7191
Training Classification Loss: 2.4311
Training RMSE: 18.3998, Validation RMSE: 11.0143, Test RMSE: 21.9055
Training PCC: 0.7997, Validation PCC: 0.9473, Test PCC: 0.5909
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.7595, Validation Total Loss: 9.4692, Test Total Loss: 21.6289
Training Regression Loss: 10.8737, Validation Regression Loss: 9.4692, Test Regression Loss: 21.6289
Training Classification Loss: 2.2145
Training RMSE: 10.4239, Validation RMSE: 8.9701, Test RMSE: 20.5100
Training PCC: 0.9532, Validation PCC: 0.9676, Test PCC: 0.6428
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.4612, Validation Total Loss: 8.5094, Test Total Loss: 22.0241
Training Regression Loss: 9.5881, Validation Regression Loss: 8.5094, Test Regression Loss: 22.0241
Training Classification Loss: 2.1826
Training RMSE: 9.1617, Validation RMSE: 8.1027, Test RMSE: 21.1266
Training PCC: 0.9644, Validation PCC: 0.9735, Test PCC: 0.6574
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.4869, Validation Total Loss: 7.8763, Test Total Loss: 22.1770
Training Regression Loss: 8.6228, Validation Regression Loss: 7.8763, Test Regression Loss: 22.1770
Training Classification Loss: 2.1603
Training RMSE: 8.2281, Validation RMSE: 7.5199, Test RMSE: 21.2501
Training PCC: 0.9707, Validation PCC: 0.9777, Test PCC: 0.6611
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.9530, Validation Total Loss: 6.9427, Test Total Loss: 21.0568
Training Regression Loss: 8.0730, Validation Regression Loss: 6.9427, Test Regression Loss: 21.0568
Training Classification Loss: 2.2000
Training RMSE: 7.7185, Validation RMSE: 6.6262, Test RMSE: 19.9776
Training PCC: 0.9738, Validation PCC: 0.9804, Test PCC: 0.6441
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.6137, Validation Total Loss: 7.7057, Test Total Loss: 21.7352
Training Regression Loss: 7.6954, Validation Regression Loss: 7.7057, Test Regression Loss: 21.7352
Training Classification Loss: 2.2956
Training RMSE: 7.3535, Validation RMSE: 7.2201, Test RMSE: 20.4441
Training PCC: 0.9762, Validation PCC: 0.9793, Test PCC: 0.6768
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.2874, Validation Total Loss: 6.4218, Test Total Loss: 22.0742
Training Regression Loss: 7.3399, Validation Regression Loss: 6.4218, Test Regression Loss: 22.0742
Training Classification Loss: 2.3687
Training RMSE: 7.0177, Validation RMSE: 6.1371, Test RMSE: 20.7259
Training PCC: 0.9783, Validation PCC: 0.9838, Test PCC: 0.6603
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 8.0509, Validation Total Loss: 6.1823, Test Total Loss: 21.1125
Training Regression Loss: 7.1040, Validation Regression Loss: 6.1823, Test Regression Loss: 21.1125
Training Classification Loss: 2.3674
Training RMSE: 6.7915, Validation RMSE: 5.8803, Test RMSE: 19.8112
Training PCC: 0.9795, Validation PCC: 0.9841, Test PCC: 0.7018
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.7547, Validation Total Loss: 6.2297, Test Total Loss: 21.9579
Training Regression Loss: 6.7705, Validation Regression Loss: 6.2297, Test Regression Loss: 21.9579
Training Classification Loss: 2.4604
Training RMSE: 6.4719, Validation RMSE: 5.9484, Test RMSE: 20.8152
Training PCC: 0.9811, Validation PCC: 0.9847, Test PCC: 0.6838
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.5338, Validation Total Loss: 5.8154, Test Total Loss: 21.0025
Training Regression Loss: 6.5349, Validation Regression Loss: 5.8154, Test Regression Loss: 21.0025
Training Classification Loss: 2.4971
Training RMSE: 6.2505, Validation RMSE: 5.5730, Test RMSE: 19.7834
Training PCC: 0.9822, Validation PCC: 0.9864, Test PCC: 0.7075
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.3615, Validation Total Loss: 6.3575, Test Total Loss: 19.8829
Training Regression Loss: 6.3504, Validation Regression Loss: 6.3575, Test Regression Loss: 19.8829
Training Classification Loss: 2.5277
Training RMSE: 6.0826, Validation RMSE: 5.9555, Test RMSE: 18.8716
Training PCC: 0.9832, Validation PCC: 0.9860, Test PCC: 0.7135
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 7.3772, Validation Total Loss: 5.6103, Test Total Loss: 21.8821
Training Regression Loss: 6.3102, Validation Regression Loss: 5.6103, Test Regression Loss: 21.8821
Training Classification Loss: 2.6676
Training RMSE: 6.0370, Validation RMSE: 5.3767, Test RMSE: 20.5780
Training PCC: 0.9835, Validation PCC: 0.9868, Test PCC: 0.6814
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 7.0873, Validation Total Loss: 5.5166, Test Total Loss: 21.9693
Training Regression Loss: 6.0592, Validation Regression Loss: 5.5166, Test Regression Loss: 21.9693
Training Classification Loss: 2.5704
Training RMSE: 5.8005, Validation RMSE: 5.3005, Test RMSE: 20.4849
Training PCC: 0.9846, Validation PCC: 0.9872, Test PCC: 0.6765
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 7.0510, Validation Total Loss: 5.6907, Test Total Loss: 21.1802
Training Regression Loss: 5.9702, Validation Regression Loss: 5.6907, Test Regression Loss: 21.1802
Training Classification Loss: 2.7018
Training RMSE: 5.7239, Validation RMSE: 5.4846, Test RMSE: 19.9484
Training PCC: 0.9850, Validation PCC: 0.9882, Test PCC: 0.7126
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.8463, Validation Total Loss: 5.4504, Test Total Loss: 21.1398
Training Regression Loss: 5.7250, Validation Regression Loss: 5.4504, Test Regression Loss: 21.1398
Training Classification Loss: 2.8033
Training RMSE: 5.4983, Validation RMSE: 5.2103, Test RMSE: 19.6979
Training PCC: 0.9860, Validation PCC: 0.9878, Test PCC: 0.7071
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.9262, Validation Total Loss: 5.3914, Test Total Loss: 19.5014
Training Regression Loss: 5.8030, Validation Regression Loss: 5.3914, Test Regression Loss: 19.5014
Training Classification Loss: 2.8080
Training RMSE: 5.5578, Validation RMSE: 5.1188, Test RMSE: 18.3206
Training PCC: 0.9859, Validation PCC: 0.9885, Test PCC: 0.7127
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 6.7605, Validation Total Loss: 5.9581, Test Total Loss: 19.6603
Training Regression Loss: 5.7075, Validation Regression Loss: 5.9581, Test Regression Loss: 19.6603
Training Classification Loss: 2.6324
Training RMSE: 5.4772, Validation RMSE: 5.5754, Test RMSE: 18.5530
Training PCC: 0.9863, Validation PCC: 0.9874, Test PCC: 0.7257
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 6.6582, Validation Total Loss: 5.0801, Test Total Loss: 19.6176
Training Regression Loss: 5.5796, Validation Regression Loss: 5.0801, Test Regression Loss: 19.6176
Training Classification Loss: 2.6967
Training RMSE: 5.3456, Validation RMSE: 4.8315, Test RMSE: 18.4292
Training PCC: 0.9867, Validation PCC: 0.9896, Test PCC: 0.7336
Checkpoint saved for epoch 18


Epoch 19/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 6.5320, Validation Total Loss: 4.8935, Test Total Loss: 19.5748
Training Regression Loss: 5.4292, Validation Regression Loss: 4.8935, Test Regression Loss: 19.5748
Training Classification Loss: 2.7570
Training RMSE: 5.2154, Validation RMSE: 4.6899, Test RMSE: 18.4759
Training PCC: 0.9874, Validation PCC: 0.9900, Test PCC: 0.6966
Checkpoint saved for epoch 19


Epoch 20/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 6.4799, Validation Total Loss: 5.3486, Test Total Loss: 20.4537
Training Regression Loss: 5.3548, Validation Regression Loss: 5.3486, Test Regression Loss: 20.4537
Training Classification Loss: 2.8127
Training RMSE: 5.1502, Validation RMSE: 5.1327, Test RMSE: 19.0365
Training PCC: 0.9878, Validation PCC: 0.9882, Test PCC: 0.7156
Checkpoint saved for epoch 20
Total training time: 2991.49 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.4_lambda_0.5_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.5_lambda_0.5_wl100_ol75 with alpha: 0.5
Starting from scratch.


Epoch 1/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.1191, Validation Total Loss: 12.1867, Test Total Loss: 24.8539
Training Regression Loss: 18.8682, Validation Regression Loss: 12.1867, Test Regression Loss: 24.8539
Training Classification Loss: 2.5020
Training RMSE: 18.3804, Validation RMSE: 11.6103, Test RMSE: 22.6748
Training PCC: 0.7958, Validation PCC: 0.9420, Test PCC: 0.6096
Checkpoint saved for epoch 1


Epoch 2/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.9513, Validation Total Loss: 9.6040, Test Total Loss: 23.1413
Training Regression Loss: 10.8021, Validation Regression Loss: 9.6040, Test Regression Loss: 23.1413
Training Classification Loss: 2.2983
Training RMSE: 10.3821, Validation RMSE: 9.1671, Test RMSE: 21.0311
Training PCC: 0.9539, Validation PCC: 0.9607, Test PCC: 0.6046
Checkpoint saved for epoch 2


Epoch 3/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.6002, Validation Total Loss: 9.1153, Test Total Loss: 21.9270
Training Regression Loss: 9.4942, Validation Regression Loss: 9.1153, Test Regression Loss: 21.9270
Training Classification Loss: 2.2118
Training RMSE: 9.0807, Validation RMSE: 8.6584, Test RMSE: 20.7743
Training PCC: 0.9652, Validation PCC: 0.9672, Test PCC: 0.6246
Checkpoint saved for epoch 3


Epoch 4/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.7131, Validation Total Loss: 8.3710, Test Total Loss: 23.0439
Training Regression Loss: 8.5966, Validation Regression Loss: 8.3710, Test Regression Loss: 23.0439
Training Classification Loss: 2.2331
Training RMSE: 8.2107, Validation RMSE: 7.9757, Test RMSE: 21.1001
Training PCC: 0.9710, Validation PCC: 0.9715, Test PCC: 0.6452
Checkpoint saved for epoch 4


Epoch 5/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.2959, Validation Total Loss: 7.4672, Test Total Loss: 22.5863
Training Regression Loss: 8.1474, Validation Regression Loss: 7.4672, Test Regression Loss: 22.5863
Training Classification Loss: 2.2971
Training RMSE: 7.7759, Validation RMSE: 7.0634, Test RMSE: 20.8960
Training PCC: 0.9744, Validation PCC: 0.9764, Test PCC: 0.6495
Checkpoint saved for epoch 5


Epoch 6/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.7807, Validation Total Loss: 7.7715, Test Total Loss: 21.5275
Training Regression Loss: 7.5894, Validation Regression Loss: 7.7715, Test Regression Loss: 21.5275
Training Classification Loss: 2.3825
Training RMSE: 7.2409, Validation RMSE: 7.3279, Test RMSE: 20.0897
Training PCC: 0.9771, Validation PCC: 0.9770, Test PCC: 0.6663
Checkpoint saved for epoch 6


Epoch 7/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.4878, Validation Total Loss: 6.8726, Test Total Loss: 22.4831
Training Regression Loss: 7.2613, Validation Regression Loss: 6.8726, Test Regression Loss: 22.4831
Training Classification Loss: 2.4530
Training RMSE: 6.9497, Validation RMSE: 6.5092, Test RMSE: 20.5300
Training PCC: 0.9788, Validation PCC: 0.9800, Test PCC: 0.6598
Checkpoint saved for epoch 7


Epoch 8/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 8.3373, Validation Total Loss: 6.4829, Test Total Loss: 22.0129
Training Regression Loss: 7.0481, Validation Regression Loss: 6.4829, Test Regression Loss: 22.0129
Training Classification Loss: 2.5785
Training RMSE: 6.7356, Validation RMSE: 6.1342, Test RMSE: 20.5314
Training PCC: 0.9801, Validation PCC: 0.9816, Test PCC: 0.6395
Checkpoint saved for epoch 8


Epoch 9/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 8.1143, Validation Total Loss: 6.6537, Test Total Loss: 21.5036
Training Regression Loss: 6.8001, Validation Regression Loss: 6.6537, Test Regression Loss: 21.5036
Training Classification Loss: 2.6285
Training RMSE: 6.4966, Validation RMSE: 6.3619, Test RMSE: 20.2358
Training PCC: 0.9813, Validation PCC: 0.9824, Test PCC: 0.6657
Checkpoint saved for epoch 9


Epoch 10/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.8649, Validation Total Loss: 6.4354, Test Total Loss: 21.7026
Training Regression Loss: 6.5340, Validation Regression Loss: 6.4354, Test Regression Loss: 21.7026
Training Classification Loss: 2.6618
Training RMSE: 6.2491, Validation RMSE: 6.1394, Test RMSE: 20.4271
Training PCC: 0.9825, Validation PCC: 0.9829, Test PCC: 0.6659
Checkpoint saved for epoch 10


Epoch 11/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.7705, Validation Total Loss: 6.4413, Test Total Loss: 20.5564
Training Regression Loss: 6.3819, Validation Regression Loss: 6.4413, Test Regression Loss: 20.5564
Training Classification Loss: 2.7772
Training RMSE: 6.1044, Validation RMSE: 6.1334, Test RMSE: 19.2304
Training PCC: 0.9832, Validation PCC: 0.9816, Test PCC: 0.6788
Checkpoint saved for epoch 11


Epoch 12/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 7.6741, Validation Total Loss: 6.2405, Test Total Loss: 20.2514
Training Regression Loss: 6.2727, Validation Regression Loss: 6.2405, Test Regression Loss: 20.2514
Training Classification Loss: 2.8028
Training RMSE: 6.0129, Validation RMSE: 5.8764, Test RMSE: 19.0353
Training PCC: 0.9837, Validation PCC: 0.9827, Test PCC: 0.6764
Checkpoint saved for epoch 12


Epoch 13/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 7.7807, Validation Total Loss: 6.3840, Test Total Loss: 22.0633
Training Regression Loss: 6.2156, Validation Regression Loss: 6.3840, Test Regression Loss: 22.0633
Training Classification Loss: 3.1302
Training RMSE: 5.9472, Validation RMSE: 6.0940, Test RMSE: 20.5629
Training PCC: 0.9842, Validation PCC: 0.9829, Test PCC: 0.6526
Checkpoint saved for epoch 13


Epoch 14/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 7.5477, Validation Total Loss: 6.3204, Test Total Loss: 20.0095
Training Regression Loss: 6.0396, Validation Regression Loss: 6.3204, Test Regression Loss: 20.0095
Training Classification Loss: 3.0161
Training RMSE: 5.8057, Validation RMSE: 6.0499, Test RMSE: 18.6209
Training PCC: 0.9847, Validation PCC: 0.9835, Test PCC: 0.6773
Checkpoint saved for epoch 14


Epoch 15/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 7.3235, Validation Total Loss: 5.8865, Test Total Loss: 21.3507
Training Regression Loss: 5.8811, Validation Regression Loss: 5.8865, Test Regression Loss: 21.3507
Training Classification Loss: 2.8848
Training RMSE: 5.6516, Validation RMSE: 5.6391, Test RMSE: 19.9900
Training PCC: 0.9855, Validation PCC: 0.9840, Test PCC: 0.6673
Checkpoint saved for epoch 15


Epoch 16/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 7.2373, Validation Total Loss: 6.1955, Test Total Loss: 20.6976
Training Regression Loss: 5.8229, Validation Regression Loss: 6.1955, Test Regression Loss: 20.6976
Training Classification Loss: 2.8288
Training RMSE: 5.5930, Validation RMSE: 5.7556, Test RMSE: 19.5769
Training PCC: 0.9859, Validation PCC: 0.9859, Test PCC: 0.6613
Checkpoint saved for epoch 16


Epoch 17/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 7.0292, Validation Total Loss: 5.7048, Test Total Loss: 21.4241
Training Regression Loss: 5.6344, Validation Regression Loss: 5.7048, Test Regression Loss: 21.4241
Training Classification Loss: 2.7897
Training RMSE: 5.4238, Validation RMSE: 5.4526, Test RMSE: 20.0082
Training PCC: 0.9868, Validation PCC: 0.9863, Test PCC: 0.6795
Checkpoint saved for epoch 17


Epoch 18/20 Training:   0%|          | 0/327 [00:00<?, ?it/s]

In [None]:
# @title Plotting
import matplotlib.pyplot as plt
def load_most_recent_checkpoint(folder):
    """
    Loads the most recent checkpoint from the given folder and returns it.

    "Most recent" means the highest epoch number parsed from a filename of the
    form ``<prefix>_epoch_<N>.pth``. ``.pth`` files without a parsable epoch
    number are skipped instead of aborting the lookup.

    Args:
        folder: Directory to search for ``.pth`` checkpoint files.

    Returns:
        The object produced by ``torch.load`` for the newest checkpoint, or
        ``None`` when the folder is missing or holds no usable checkpoint.
    """
    def _epoch_of(filename):
        # Same parsing rule as before; None signals "no integer epoch in this name".
        try:
            return int(filename.split('_epoch_')[-1].split('.')[0])
        except ValueError:
            return None

    # A folder that does not exist is reported the same way as an empty one,
    # so a caller looping over several folders is not aborted by one missing entry.
    if not os.path.isdir(folder):
        print(f"No checkpoints found in {folder}")
        return None

    # Pair every usable checkpoint file with its epoch number.
    epoch_files = []
    for filename in os.listdir(folder):
        if not filename.endswith('.pth'):
            continue
        epoch = _epoch_of(filename)
        if epoch is not None:
            epoch_files.append((epoch, filename))

    if not epoch_files:
        print(f"No checkpoints found in {folder}")
        return None

    # Stable sort by epoch only, then take the last entry (highest epoch).
    epoch_files.sort(key=lambda pair: pair[0])
    latest_checkpoint = epoch_files[-1][1]
    checkpoint_path = os.path.join(folder, latest_checkpoint)

    # Without CUDA, tensors saved on a GPU must be remapped to the CPU or
    # torch.load fails; with CUDA available the default placement is kept.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    print(f"Loaded checkpoint: {latest_checkpoint} from {folder}")

    # Returned as loaded; the caller in this cell reads checkpoint['history'].
    return checkpoint

def _plot_split_curves(ax, epochs, aggregated_metrics, suffix, label_template, splits):
    """Draw one solid line per split (train/val/test) for an already-averaged metric."""
    for key, display_name, color in splits:
        ax.plot(epochs, aggregated_metrics[f'{key}_{suffix}'],
                label=label_template.format(display_name), color=color, linestyle='-')


def _plot_channelwise_curves(ax, epochs, aggregated_metrics, suffix, metric_name, splits, line_styles):
    """Draw one line per (channel, split); the line style identifies the channel."""
    train_history = aggregated_metrics[f'train_{suffix}']
    # Channel count is taken from the first recorded epoch of the train history.
    n_channels = len(train_history[0]) if len(train_history) else 0
    for ch in range(n_channels):
        # Cycle through the styles so channels beyond len(line_styles) are still drawn.
        style = line_styles[ch % len(line_styles)]
        for key, display_name, color in splits:
            ax.plot(epochs, [per_epoch[ch] for per_epoch in aggregated_metrics[f'{key}_{suffix}']],
                    label=f'{display_name} {metric_name} (Ch {ch+1})', color=color, linestyle=style)


def _finish_axes(ax, title, ylabel):
    """Apply title and axis labels, and add a legend when there is something to list."""
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    handles, _ = ax.get_legend_handles_labels()
    if handles:  # avoids matplotlib's "no artists with labels" warning on empty panels
        ax.legend()


def plot_metrics_subplots(aggregated_metrics, folder):
    """Plots train/val/test losses, PCCs, and RMSEs in subplots.

    Args:
        aggregated_metrics: Mapping of per-epoch histories. Keys read here are
            ``{train,val,test}_losses``, ``_pccs``, ``_rmses`` (one value per
            epoch) and ``_pccs_channelwise``, ``_rmses_channelwise`` (one
            sequence per epoch, indexed by channel).
        folder: Label shown in the figure title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    epochs = list(range(1, len(aggregated_metrics['train_losses']) + 1))
    line_styles = ['-', '--', '-.', ':']  # One style per channel, reused cyclically

    # (history key prefix, legend name, colour) for each data split
    splits = (
        ('train', 'Train', '#17becf'),  # Teal blue
        ('val', 'Val', '#bcbd22'),      # Mustard yellow
        ('test', 'Test', '#e377c2'),    # Pastel magenta
    )

    fig, axes = plt.subplots(3, 2, figsize=(12, 14))
    fig.suptitle(f'Metrics over Epochs for {folder}', fontsize=16)

    # Row 1, left: total losses
    _plot_split_curves(axes[0, 0], epochs, aggregated_metrics, 'losses', '{} Loss', splits)
    _finish_axes(axes[0, 0], 'Losses over Epochs', 'Loss')

    # Row 1, right: PCC (histories are already averaged over channels)
    _plot_split_curves(axes[0, 1], epochs, aggregated_metrics, 'pccs', 'Avg {} PCC', splits)
    _finish_axes(axes[0, 1], 'Average PCC over Epochs', 'PCC')

    # Row 2, left: PCC per channel
    _plot_channelwise_curves(axes[1, 0], epochs, aggregated_metrics, 'pccs_channelwise', 'PCC', splits, line_styles)
    _finish_axes(axes[1, 0], 'PCCs over Epochs (Per Channel)', 'PCC')

    # Row 2, right: RMSE (histories are already averaged over channels)
    _plot_split_curves(axes[1, 1], epochs, aggregated_metrics, 'rmses', 'Avg {} RMSE', splits)
    _finish_axes(axes[1, 1], 'Average RMSE over Epochs', 'RMSE')

    # Row 3, left: RMSE per channel
    _plot_channelwise_curves(axes[2, 0], epochs, aggregated_metrics, 'rmses_channelwise', 'RMSE', splits, line_styles)
    _finish_axes(axes[2, 0], 'RMSEs over Epochs (Per Channel)', 'RMSE')

    # Only five panels are used; hide the sixth instead of showing an empty frame.
    axes[2, 1].set_visible(False)

    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to make room for the suptitle
    plt.show()

def plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict):
    """
    Plots two bar graphs comparing the best RMSE and PCC for each model, showing corresponding PCC/RMSE from the same epoch,
    and plots average test RMSEs across epochs for each model.

    Args:
        aggregated_metrics_dict: Mapping of model name to a dict with the keys
            ``'test_rmses'`` and ``'test_pccs'``, each a per-epoch sequence.
            Models whose histories are empty are skipped.

    Returns:
        None. Three figures are displayed with ``plt.show()``; nothing is drawn
        when no model has a non-empty history.
    """
    # Local import so the function does not depend on an `np` alias created in another cell.
    import numpy as np

    # Best test RMSE / PCC per model, each with the other metric from the same epoch
    best_metrics = {}
    for model_name, metrics in aggregated_metrics_dict.items():
        test_rmses = metrics['test_rmses']  # Pre-averaged RMSE values
        test_pccs = metrics['test_pccs']    # Pre-averaged PCC values

        # argmin/argmax raise on empty input, so a model without history is left out.
        if len(test_rmses) == 0 or len(test_pccs) == 0:
            continue

        best_rmse_epoch = np.argmin(test_rmses)  # lower RMSE is better
        best_pcc_epoch = np.argmax(test_pccs)    # higher PCC is better

        best_metrics[model_name] = {
            'best_rmse': test_rmses[best_rmse_epoch],
            'corresponding_pcc_for_best_rmse': test_pccs[best_rmse_epoch],
            'best_pcc': test_pccs[best_pcc_epoch],
            'corresponding_rmse_for_best_pcc': test_rmses[best_pcc_epoch],
        }

    # Unpacking zip(*...) below fails on an empty sequence, so stop early.
    if not best_metrics:
        print("No model metrics available for comparison")
        return

    models = list(best_metrics.keys())

    # Rank models: ascending by best RMSE, descending by best PCC
    sorted_by_rmse = sorted(
        ((m, best_metrics[m]['best_rmse'], best_metrics[m]['corresponding_pcc_for_best_rmse']) for m in models),
        key=lambda entry: entry[1],
    )
    sorted_by_pcc = sorted(
        ((m, best_metrics[m]['best_pcc'], best_metrics[m]['corresponding_rmse_for_best_pcc']) for m in models),
        key=lambda entry: entry[1],
        reverse=True,
    )

    sorted_models_rmse, sorted_rmses, sorted_pccs_for_rmse = zip(*sorted_by_rmse)
    sorted_models_pcc, sorted_pccs, sorted_rmses_for_pcc = zip(*sorted_by_pcc)

    # One colour per model, assigned by RMSE rank and reused in all three figures
    colors = sns.color_palette("husl", len(models))
    model_colors = {model: colors[i] for i, model in enumerate(sorted_models_rmse)}

    # Plot RMSE Bar Graph (With Corresponding PCC values)
    plt.figure(figsize=(10, 6))
    bars_rmse = plt.bar(sorted_models_rmse, sorted_rmses, color=[model_colors[model] for model in sorted_models_rmse])
    plt.title('Best RMSE per Model (Sorted)')
    plt.xlabel('Model')
    plt.ylabel('Best RMSE')
    # Keep the 0-25 range, but grow it when a bar would otherwise be cut off.
    plt.ylim(0, max(25, 1.1 * float(max(sorted_rmses))))
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    # RMSE value above each bar, PCC from the same epoch just below the bar top
    for i, bar in enumerate(bars_rmse):
        yval_rmse = bar.get_height()
        x_center = bar.get_x() + bar.get_width()/2
        plt.text(x_center, yval_rmse + 0.5, f'{yval_rmse:.2f}', ha='center', va='bottom')
        plt.text(x_center, yval_rmse - 2, f'PCC: {sorted_pccs_for_rmse[i]:.2f}', ha='center', va='top', color='black')

    plt.show()

    # Plot PCC Bar Graph (With Corresponding RMSE values)
    plt.figure(figsize=(10, 6))
    bars_pcc = plt.bar(sorted_models_pcc, sorted_pccs, color=[model_colors[model] for model in sorted_models_pcc])
    plt.title('Best PCC per Model (Sorted)')
    plt.xlabel('Model')
    plt.ylabel('Best PCC')
    plt.ylim(0, 1)  # Set an appropriate y-limit for PCCs
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    # PCC value above each bar, RMSE from the same epoch just below the bar top
    for i, bar in enumerate(bars_pcc):
        yval_pcc = bar.get_height()
        x_center = bar.get_x() + bar.get_width()/2
        plt.text(x_center, yval_pcc + 0.02, f'{yval_pcc:.2f}', ha='center', va='bottom')
        plt.text(x_center, yval_pcc - 0.08, f'RMSE: {sorted_rmses_for_pcc[i]:.2f}', ha='center', va='top', color='black')

    plt.show()

    # Plot the average test RMSEs across epochs for each model
    plt.figure(figsize=(10, 6))
    for model in models:
        avg_test_rmses = aggregated_metrics_dict[model]['test_rmses']  # Directly plot pre-averaged RMSEs
        # 1-based epoch axis; colour looked up by model so it matches the bar charts.
        epochs = range(1, len(avg_test_rmses) + 1)
        plt.plot(epochs, avg_test_rmses, label=f'{model} Average Test RMSE', color=model_colors[model])

    plt.title('Average Test RMSE per Epoch for All Models')
    plt.xlabel('Epoch')
    plt.ylabel('Average RMSE')
    plt.legend()
    plt.tight_layout()
    plt.show()


# Per-model test curves, keyed by model name, collected for the cross-model comparison
aggregated_metrics_dict = {}

# Visit every configured model and read the history stored in its newest checkpoint
for model_name, model_config in model_configs.items():
    # Each model's checkpoints live in a folder named after the model
    folder = f"/content/MyDrive/MyDrive/models/{model_name}"
    checkpoint = load_most_recent_checkpoint(folder)

    # Nothing to plot for a model whose checkpoint could not be loaded
    if not checkpoint:
        continue

    aggregated_metrics = checkpoint['history']

    # The comparison plots only need the averaged test RMSE and PCC curves
    aggregated_metrics_dict[model_name] = {
        metric_key: aggregated_metrics[metric_key]
        for metric_key in ('test_rmses', 'test_pccs')
    }

    # Training curves for this single model
    plot_metrics_subplots(aggregated_metrics, model_name)

# Draw the comparison figures only if at least one model contributed metrics
if len(aggregated_metrics_dict) > 0:
    plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict)
