In [14]:

#mount drive
from google.colab import drive
drive.mount('/content/MyDrive')
import seaborn as sns
sns.set_theme("paper")



Drive already mounted at /content/MyDrive; to attempt to forcibly remount, call drive.mount("/content/MyDrive", force_remount=True).


In [15]:
# @title Initialize Config

import torch
import numpy
class Config:
    def __init__(self, **kwargs):
        self.channels_imu_acc = kwargs.get('channels_imu_acc', [])
        self.channels_imu_gyr = kwargs.get('channels_imu_gyr', [])
        self.channels_joints = kwargs.get('channels_joints', [])
        self.channels_emg = kwargs.get('channels_emg', [])
        self.seed = kwargs.get('seed', 42)
        self.data_folder_name = kwargs.get('data_folder_name', 'default_data_folder_name')
        self.dataset_root = kwargs.get('dataset_root', 'default_dataset_root')
        self.imu_transforms = kwargs.get('imu_transforms', [])
        self.joint_transforms = kwargs.get('joint_transforms', [])
        self.emg_transforms = kwargs.get('emg_transforms', [])
        self.input_format = kwargs.get('input_format', 'csv')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection_v4/all_subjects_data.h5',
    dataset_root='/content/datasets',
    input_format="csv",
    channels_imu_acc=['ACCX1', 'ACCY1', 'ACCZ1','ACCX2', 'ACCY2', 'ACCZ2', 'ACCX3', 'ACCY3', 'ACCZ3', 'ACCX4', 'ACCY4', 'ACCZ4', 'ACCX5', 'ACCY5', 'ACCZ5', 'ACCX6', 'ACCY6', 'ACCZ6'],
    channels_imu_gyr=['GYROX1', 'GYROY1', 'GYROZ1', 'GYROX2', 'GYROY2', 'GYROZ2', 'GYROX3', 'GYROY3', 'GYROZ3', 'GYROX4', 'GYROY4', 'GYROZ4', 'GYROX5', 'GYROY5', 'GYROZ5', 'GYROX6', 'GYROY6', 'GYROZ6'],
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
)

#set seeds
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)


In [16]:
class DataSharder:
    def __init__(self, config, split):
        self.config = config
        self.h5_file_path = config.data_folder_name  # Path to the HDF5 file
        self.split = split

    def load_data(self, subjects, window_length, window_overlap, dataset_name):
        print(f"Processing subjects: {subjects} with window length: {window_length}, overlap: {window_overlap}")

        self.window_length = window_length
        self.window_overlap = window_overlap

        # Process the data from the HDF5 file
        self._process_and_save_patients_h5(subjects, dataset_name)

    def _process_and_save_patients_h5(self, subjects, dataset_name):
        # Open the HDF5 file
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            dataset_folder = os.path.join(self.config.dataset_root, dataset_name, self.split).replace("subject", "").replace("__", "_")
            print("Dataset folder:", dataset_folder)

            if os.path.exists(dataset_folder):
                print("Dataset Exists, Skipping...")
                return

            os.makedirs(dataset_folder, exist_ok=True)
            print("Dataset folder created: ", dataset_folder)

            for subject_id in tqdm(subjects, desc="Processing subjects"):
                subject_key = subject_id
                if subject_key not in h5_file:
                    print(f"Subject {subject_key} not found in the HDF5 file. Skipping.")
                    continue

                subject_data = h5_file[subject_key]
                session_keys = list(subject_data.keys())  # Sessions for this subject

                for session_id in session_keys:
                    session_data_group = subject_data[session_id]

                    for sessions_speed in session_data_group.keys():
                        session_data = session_data_group[sessions_speed]

                        # Extract IMU, EMG, and Joint data as numpy arrays
                        imu_data, imu_columns = self._extract_channel_data(session_data, self.config.channels_imu_acc + self.config.channels_imu_gyr)
                        emg_data, emg_columns = self._extract_channel_data(session_data, self.config.channels_emg)
                        joint_data, joint_columns = self._extract_channel_data(session_data, self.config.channels_joints)

                        # Shard the data into windows and save each window
                        self._save_windowed_data(imu_data, emg_data, joint_data, subject_key, session_id,sessions_speed, dataset_folder, imu_columns, emg_columns, joint_columns)

    def _save_windowed_data(self, imu_data, emg_data, joint_data, subject_key, session_id, session_speed, dataset_folder, imu_columns, emg_columns, joint_columns):
        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap

        # Path to the CSV log file
        csv_file_path = os.path.join(dataset_folder, '..', f"{self.split}_info.csv")

        # Ensure the folder exists
        os.makedirs(dataset_folder, exist_ok=True)

        # Prepare CSV log headers (ensure the columns are 'file_name' and 'file_path')
        csv_headers = ['file_name', 'file_path']

        # Create or append to the CSV log file
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Write the headers only if the file is new
            if not file_exists:
                writer.writerow(csv_headers)

            # Determine the total data length based on the minimum length across the data sources
            total_data_length = min(imu_data.shape[1], emg_data.shape[1], joint_data.shape[1])

            # Adjust the starting point for windows based on total data length
            start = 2000 if total_data_length > 4000 else 0

            # Ensure that each window across imu_data, emg_data, and joint_data has the same shape before concatenation
            for i in range(start, total_data_length - window_size + 1, step_size):
                imu_window = imu_data[:, i:i + window_size]
                emg_window = emg_data[:, i:i + window_size]
                joint_window = joint_data[:, i:i + window_size]

                # Check if the window sizes are valid
                if imu_window.shape[1] == window_size and emg_window.shape[1] == window_size and joint_window.shape[1] == window_size:
                    # Convert windowed data to pandas DataFrame



                    imu_df = pd.DataFrame(imu_window.T, columns=imu_columns)
                    emg_df = pd.DataFrame(emg_window.T, columns=emg_columns)
                    joint_df = pd.DataFrame(joint_window.T, columns=joint_columns)



                    # Concatenate the data along the column axis
                    combined_df = pd.concat([imu_df, emg_df, joint_df], axis=1)

                    # Save the combined windowed data as a CSV file
                    file_name = f"{subject_key}_{session_id}_{session_speed}_win_{i}_ws{window_size}_ol{overlap}.csv"
                    file_path = os.path.join(dataset_folder, file_name)
                    combined_df.to_csv(file_path, index=False)

                    # Log the file name and path in the CSV (in the correct columns)
                    writer.writerow([file_name, file_path])
                else:
                    print(f"Skipping window {i} due to mismatched window sizes.")

    def _extract_channel_data(self, session_data, channels):
        """
        Extracts data for the given channels from the dataset (whether it's a compound dataset or simple dataset),
        and interpolates missing values (NaNs) in each channel data.
        """
        extracted_data = []
        column_names = []

        if isinstance(session_data, h5py.Dataset):
            # Check if the dataset has named fields (compound dataset)
            if session_data.dtype.names:
                # Compound dataset, use the named fields
                column_names = session_data.dtype.names
                for channel in channels:
                    if channel in column_names:
                        channel_data = session_data[channel][:]  # Access by field name
                        # Convert the data to a numeric type (float), if necessary
                        channel_data = pd.to_numeric(channel_data, errors='coerce')
                        # Interpolate NaN values
                        df = pd.DataFrame(channel_data)
                        df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                        extracted_data.append(df_interpolated.to_numpy().flatten())
                    else:
                        print(f"Channel {channel} not found in compound dataset.")
            else:
                # Simple dataset, use index-based access (no named fields)
                column_names = session_data.attrs.get('column_names', [])

                # Cast column_names to a list to allow 'index' lookup
                column_names = list(column_names)
                new_column_names = []

                assert len(column_names) > 0, "column_names not found in dataset attributes"
                for channel in channels:
                    if channel in column_names:
                        col_idx = column_names.index(channel)
                        new_column_names.append(channel)
                        channel_data = session_data[:, col_idx]  # Access by column index

                        # Convert the data to a numeric type (float), if necessary
                        channel_data = pd.to_numeric(channel_data, errors='coerce')

                        # Interpolate NaN values
                        df = pd.DataFrame(channel_data)
                        df_interpolated = df.interpolate(method='linear', axis=0, limit_direction='both')
                        extracted_data.append(df_interpolated.to_numpy().flatten())
                    else:
                        print(f"Channel {channel} not found in session data.")

        return np.array(extracted_data), new_column_names

In [17]:
# @title Dataset creation
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.data import ConcatDataset
import random
from torch.utils.data import TensorDataset

class ImuJointPairDataset(Dataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        self.config = config
        self.split = split
        self.subjects = subjects
        self.window_length = window_length
        self.window_overlap = window_overlap if split == 'train' else 0
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        # Convert the list of subjects to a string that is path-safe
        subjects_str = "_".join(map(str, subjects)).replace('subject', '').replace('__', '_')

        # Use dataset_train_name or dataset_test_name based on split
        if split == 'train':
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_train{subjects_str}"
        else:
            dataset_name = f"dataset_wl{self.window_length}_ol{self.window_overlap}_test{subjects_str}"

        self.dataset_name = dataset_name

        # Define the root directory based on dataset name
        self.root_dir = os.path.join(self.config.dataset_root, self.dataset_name)

        # Ensure sharded data exists, if not, reshard
        self.ensure_resharded(subjects, dataset_train_name if split == 'train' else dataset_test_name)

        info_path = os.path.join(self.root_dir, f"{split}_info.csv")
        self.data = pd.read_csv(info_path)

    def ensure_resharded(self, subjects, dataset_name):
        if not os.path.exists(self.root_dir):
            print(f"Sharded data not found at {self.root_dir}. Resharding...")
            data_sharder = DataSharder(self.config,self.split)
            # Pass dynamic parameters to sharder
            data_sharder.load_data(subjects, window_length=self.window_length, window_overlap=self.window_overlap, dataset_name=self.dataset_name)
        else:
            print(f"Sharded data found at {self.root_dir}. Skipping resharding.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir,self.split, self.data.iloc[idx, 0])

        if self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_and_transform(self, combined_data):
        imu_data_acc = self._extract_channels(combined_data, self.channels_imu_acc)
        imu_data_gyr = self._extract_channels(combined_data, self.channels_imu_gyr)
        joint_data = self._extract_channels(combined_data, self.channels_joints)
        emg_data = self._extract_channels(combined_data, self.channels_emg)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def apply_transforms(self, data, transforms):
        for transform in transforms:
            data = transform(data)
        return torch.tensor(data, dtype=torch.float32)

class ImuJointPairSubjectDataset(ImuJointPairDataset):
    def __init__(self, config, subjects, window_length, window_overlap, split='train', dataset_train_name='train', dataset_test_name='test'):
        super().__init__(config, subjects, window_length, window_overlap, split, dataset_train_name, dataset_test_name)

        # Create a mapping from subject strings (e.g., 'subject_1') to class indices
        self.subject_mapping = {subject: i for i, subject in enumerate(sorted(subjects))}

    def __getitem__(self, idx):
        # Retrieve the original data from the parent class
        imu_data_acc, imu_data_gyr, joint_data, emg_data = super().__getitem__(idx)

        # Get the filename from the data index
        filename = self.data.iloc[idx, 0]

        # Extract the subject ID from the filename
        filename_base = os.path.basename(filename)
        filename_without_ext = os.path.splitext(filename_base)[0]
        parts = filename_without_ext.split('_')

        # Construct subject string in the format 'subject_x'
        try:
            subject_index = parts.index('subject')
            subject_str = f"subject_{parts[subject_index + 1]}"  # Create string like 'subject_1'
        except ValueError:
            raise ValueError(f"'subject' not found in filename: {filename}")
        except IndexError:
            raise ValueError(f"Subject ID not found after 'subject' in filename: {filename}")

        # Map subject_str to class index
        if subject_str not in self.subject_mapping:
            raise ValueError(f"Subject ID {subject_str} not found in training set.")

        mapped_class = self.subject_mapping[subject_str]

        # Return class index instead of one-hot encoding
        return imu_data_acc, imu_data_gyr, joint_data, emg_data, mapped_class


def create_base_data_loaders(
    config,
    train_subjects,
    test_subjects,
    window_length=100,
    window_overlap=75,
    batch_size=64,
    dataset_train_name='train',
    dataset_test_name='test'
):
    # Create datasets with explicit parameters
    train_dataset = ImuJointPairSubjectDataset(
        config=config,
        subjects=train_subjects,
        window_length=window_length,
        window_overlap=window_overlap,
        split='train',
        dataset_train_name=dataset_train_name
    )

    test_dataset = ImuJointPairSubjectDataset(
        config=config,
        subjects=test_subjects,
        window_length=window_length,
        window_overlap=0,  # Typically no overlap in test set
        split='test',
        dataset_test_name=dataset_test_name
    )

    # Split train dataset into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader




In [18]:
# @title Kinematicsnet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from scipy.signal import butter, filtfilt
from sklearn.metrics import mean_squared_error
import numpy as np
class Encoder_1(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_1, self).__init__()
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.lstm_2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, (h_1, _) = self.lstm_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, (h_2, _) = self.lstm_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)

class Encoder_2(nn.Module):
    def __init__(self, input_dim, dropout):
        super(Encoder_2, self).__init__()
        self.gru_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0)
        self.gru_2 = nn.GRU(256, 64, bidirectional=True, batch_first=True, dropout=0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        out_1, h_1 = self.gru_1(x)
        out_1 = self.dropout_1(out_1)
        out_2, h_2 = self.gru_2(out_1)
        out_2 = self.dropout_2(out_2)
        return out_2, (h_1, h_2)


class GatingModule(nn.Module):
    def __init__(self, input_size):
        super(GatingModule, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(2*input_size, input_size),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # Apply gating mechanism
        gate_output = self.gate(torch.cat((input1,input2),dim=-1))

        # Scale the inputs based on the gate output
        gated_input1 = input1 * gate_output
        gated_input2 = input2 * (1 - gate_output)

        # Combine the gated inputs
        output = gated_input1 + gated_input2
        return output
#variable w needs to be checked for correct value, stand-in value used


from torch.autograd import Function

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, lambda_grl):
        ctx.lambda_grl = lambda_grl
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        lambda_grl = ctx.lambda_grl
        grad_input = grad_output.neg() * lambda_grl
        return grad_input, None

def GradientReversalLayer(lambda_grl):
    return GradientReversalFunction.apply

class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100,num_subjects=12,lambda_grl=1.0):
        super(teacher, self).__init__()

        self.lambda_grl = lambda_grl
        self.w=w
        self.encoder_1_acc=Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr=Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg=Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc=Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr=Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg=Encoder_2(input_emg, drop_prob)

        self.BN_acc= nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr= nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg= nn.BatchNorm1d(input_emg, affine=False)


        self.fc = nn.Linear(2*3*128+128,3)
        self.fc_classifier = nn.Linear(2*3*128+128,num_subjects)
        self.dropout=nn.Dropout(p=.05)

        self.gate_1=GatingModule(128)
        self.gate_2=GatingModule(128)
        self.gate_3=GatingModule(128)

        self.fc_kd = nn.Linear(3*128, 2*128)

               # Define the gating network
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention=nn.MultiheadAttention(3*128,4,batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128*3, 3*128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2*3*128+128, 2*3*128+128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)


    def forward(self, x_acc, x_gyr, x_emg):

        x_acc_1=x_acc.view(x_acc.size(0)*x_acc.size(1),x_acc.size(-1))
        x_gyr_1=x_gyr.view(x_gyr.size(0)*x_gyr.size(1),x_gyr.size(-1))
        x_emg_1=x_emg.view(x_emg.size(0)*x_emg.size(1),x_emg.size(-1))

        x_acc_1=self.BN_acc(x_acc_1)
        x_gyr_1=self.BN_gyr(x_gyr_1)
        x_emg_1=self.BN_emg(x_emg_1)

        x_acc_2=x_acc_1.view(-1, self.w, x_acc_1.size(-1))
        x_gyr_2=x_gyr_1.view(-1, self.w, x_gyr_1.size(-1))
        x_emg_2=x_emg_1.view(-1, self.w, x_emg_1.size(-1))

        # Pass through Encoder 1 for each modality and capture hidden states
        x_acc_1, (h_acc_1, _) = self.encoder_1_acc(x_acc_2)
        x_gyr_1, (h_gyr_1, _) = self.encoder_1_gyr(x_gyr_2)
        x_emg_1, (h_emg_1, _) = self.encoder_1_emg(x_emg_2)

        # Pass through Encoder 2 for each modality and capture hidden states
        x_acc_2, (h_acc_2, _) = self.encoder_2_acc(x_acc_2)
        x_gyr_2, (h_gyr_2, _) = self.encoder_2_gyr(x_gyr_2)
        x_emg_2, (h_emg_2, _) = self.encoder_2_emg(x_emg_2)

        # x_acc=torch.cat((x_acc_1,x_acc_2),dim=-1)
        # x_gyr=torch.cat((x_gyr_1,x_gyr_2),dim=-1)
        # x_emg=torch.cat((x_emg_1,x_emg_2),dim=-1)

        x_acc=self.gate_1(x_acc_1,x_acc_2)
        x_gyr=self.gate_2(x_gyr_1,x_gyr_2)
        x_emg=self.gate_3(x_emg_1,x_emg_2)

        x=torch.cat((x_acc,x_gyr,x_emg),dim=-1)
        x_kd=self.fc_kd(x)


        out_1, attn_output_weights=self.attention(x,x,x)

        gating_weights = self.gating_net(x)
        out_2=gating_weights*x

        weights_1 = self.weighted_feat(x[:,:,0:128])
        weights_2 = self.weighted_feat(x[:,:,128:2*128])
        weights_3 = self.weighted_feat(x[:,:,2*128:3*128])
        x_1=weights_1*x[:,:,0:128]
        x_2=weights_2*x[:,:,128:2*128]
        x_3=weights_3*x[:,:,2*128:3*128]
        out_3=x_1+x_2+x_3

        out=torch.cat((out_1,out_2,out_3),dim=-1)

        gating_weights_1 = self.gating_net_1(out)
        out=gating_weights_1*out

        #print shape of out
        out_task=self.fc(out)

        grl = GradientReversalLayer(self.lambda_grl)
        out_rev = grl(out, self.lambda_grl)  # Apply GRL to reverse gradients

        # Domain classification with reversed gradients
        out_rev = out_rev.permute(0, 2, 1)  # Permute to [batch_size, features, time]
        out_rev = torch.nn.functional.adaptive_avg_pool1d(out_rev, 1)  # Pooling
        out_rev = out_rev.squeeze(-1)  # Shape: [batch_size, features]

        out_classifier = self.fc_classifier(out_rev)  # Domain classifier output


        #print(out.shape)
        return out_task, out_classifier, (h_acc_1, h_acc_2, h_gyr_1, h_gyr_2, h_emg_1, h_emg_2)




In [19]:
# @title Loss Functions
import statistics

class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
    def forward(self, output, target):
        loss = torch.sqrt(torch.mean((output - target) ** 2))
        return loss

#prediction function
def RMSE_prediction(yhat_4,test_y, output_dim,print_losses=True):

  s1=yhat_4.shape[0]*yhat_4.shape[1]

  test_o=test_y.reshape((s1,output_dim))
  yhat=yhat_4.reshape((s1,output_dim))




  y_1_no=yhat[:,0]
  y_2_no=yhat[:,1]
  y_3_no=yhat[:,2]

  y_1=y_1_no
  y_2=y_2_no
  y_3=y_3_no


  y_test_1=test_o[:,0]
  y_test_2=test_o[:,1]
  y_test_3=test_o[:,2]



  cutoff=6
  fs=200
  order=4

  nyq = 0.5 * fs
  ## filtering data ##
  def butter_lowpass_filter(data, cutoff, fs, order):
      normal_cutoff = cutoff / nyq
      # Get the filter coefficients
      b, a = butter(order, normal_cutoff, btype='low', analog=False)
      y = filtfilt(b, a, data)
      return y



  Z_1=y_1
  Z_2=y_2
  Z_3=y_3



  ###calculate RMSE

  rmse_1 =((np.sqrt(mean_squared_error(y_test_1,y_1))))
  rmse_2 =((np.sqrt(mean_squared_error(y_test_2,y_2))))
  rmse_3 =((np.sqrt(mean_squared_error(y_test_3,y_3))))





  p_1=np.corrcoef(y_1, y_test_1)[0, 1]
  p_2=np.corrcoef(y_2, y_test_2)[0, 1]
  p_3=np.corrcoef(y_3, y_test_3)[0, 1]




              ### Correlation ###
  p=np.array([p_1,p_2,p_3])
  #,p_4,p_5,p_6,p_7])




      #### Mean and standard deviation ####

  rmse=np.array([rmse_1,rmse_2,rmse_3])
  #,rmse_4,rmse_5,rmse_6,rmse_7])

      #### Mean and standard deviation ####
  m=statistics.mean(rmse)
  SD=statistics.stdev(rmse)


  m_c=statistics.mean(p)
  SD_c=statistics.stdev(p)


  if print_losses:
    print(rmse_1)
    print(rmse_2)
    print(rmse_3)
    print("\n")
    print(p_1)
    print(p_2)
    print(p_3)
    print('Mean: %.3f' % m,'+/- %.3f' %SD)
    print('Mean: %.3f' % m_c,'+/- %.3f' %SD_c)

  return rmse, p, Z_1,Z_2,Z_3
  #,Z_4,Z_5,Z_6,Z_7

def compute_biomechanical_loss(predicted_angles):
    """
    Compute the biomechanical loss (L_bio) to enforce joint limits
    for the three joints, using index-based access.
    """
    # Define joint limits for each channel (in degrees)
    min_limits = torch.tensor([0, -90, -120], device=predicted_angles.device)
    max_limits = torch.tensor([150, 180, 90], device=predicted_angles.device)

    while min_limits.dim() < predicted_angles.dim():
        min_limits = min_limits.unsqueeze(0)
    while max_limits.dim() < predicted_angles.dim():
        max_limits = max_limits.unsqueeze(0)

    # Now min_limits and max_limits have shape [1, 1, num_joints] and will broadcast correctly
    lower_violation = torch.relu(min_limits - predicted_angles)
    upper_violation = torch.relu(predicted_angles - max_limits)

    L_bio = torch.mean(lower_violation + upper_violation)

    return L_bio


class BoundRmseLoss(nn.Module):
    def __init__(self, lambda_bio=0.1):
        super(BoundRmseLoss, self).__init__()
        self.lambda_bio = lambda_bio
        self.rmse_loss = RMSELoss()  # Using your existing RMSELoss class

    def forward(self, output, target):
        # Compute RMSE loss
        L_data = self.rmse_loss(output, target)

        # Compute biomechanical loss
        L_bio = compute_biomechanical_loss(output)

        # Total loss
        total_loss = L_data + self.lambda_bio * L_bio

        return total_loss

In [20]:
# @title Model Utils

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Runs evaluation on the validation or test set."""
    model.eval()
    total_loss = 0.0
    total_pcc = np.zeros(len(config.channels_joints))
    total_rmse = np.zeros(len(config.channels_joints))

    with torch.no_grad():
        for i, (data_acc, data_gyr, target, data_EMG,_) in enumerate(loader):
            output= model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())

            if isinstance(model, teacher):
                output,knowledge_distillation,_ = output
                loss = criterion(output, target.to(device).float())
            else:
                loss = criterion(output, target.to(device).float())

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), len(config.channels_joints), print_losses=False)
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse

    avg_loss = total_loss / len(loader)
    avg_pcc = total_pcc / len(loader)
    avg_rmse = total_rmse / len(loader)

    return avg_loss, avg_pcc, avg_rmse



def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None,
                    channelwise_metrics=None, history=None, curriculum_schedule=None):

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': channelwise_metrics['train'],
        'val_channelwise_metrics': channelwise_metrics['val'],
    }

    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = channelwise_metrics['test']

    # Save the history (losses, PCCs, RMSEs, channel-wise metrics)
    if history:
        checkpoint['history'] = history

    # Save curriculum schedule
    if curriculum_schedule:
        checkpoint['curriculum_schedule'] = curriculum_schedule

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")



def train_teacher(
    device,
    train_loader,
    val_loader,
    test_loader,
    learn_rate,
    epochs,
    model,
    filename,
    loss_function,
    optimizer=None,
    l1_lambda=None,
    train_from_last_epoch=False,
    curriculum_loader=None,
    num_classes=12,  # Add num_classes parameter
    alpha=1.0,       # Weighting factor for classification loss
):
    model.to(device)
    criterion_regression = loss_function
    criterion_classification = nn.CrossEntropyLoss(ignore_index=-1)

    if optimizer is None:
        # Create a default Adam optimizer if none is passed
        optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    train_losses = []
    val_losses = []
    test_losses = []

    train_pccs = []
    val_pccs = []
    test_pccs = []

    train_rmses = []
    val_rmses = []
    test_rmses = []

    train_pccs_channelwise = []
    val_pccs_channelwise = []
    test_pccs_channelwise = []

    train_rmses_channelwise = []
    val_rmses_channelwise = []
    test_rmses_channelwise = []

    # Check for existing checkpoint to resume training
    last_epoch = 0
    checkpoint_path = f"/content/MyDrive/MyDrive/models/{filename}/"

    if train_from_last_epoch and os.path.exists(checkpoint_path):
        # Scan for the latest saved checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.pth')]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch number
            latest_checkpoint = checkpoints[-1]
            print(f"Loading model from checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(os.path.join(checkpoint_path, latest_checkpoint))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            last_epoch = checkpoint['epoch']  # Continue from the next epoch

            # Load the history from checkpoint
            train_losses = checkpoint['history']['train_losses']
            val_losses = checkpoint['history']['val_losses']
            test_losses = checkpoint['history']['test_losses']
            train_pccs = checkpoint['history']['train_pccs']
            val_pccs = checkpoint['history']['val_pccs']
            test_pccs = checkpoint['history']['test_pccs']
            train_rmses = checkpoint['history']['train_rmses']
            val_rmses = checkpoint['history']['val_rmses']
            test_rmses = checkpoint['history']['test_rmses']
            train_pccs_channelwise = checkpoint['history']['train_pccs_channelwise']
            val_pccs_channelwise = checkpoint['history']['val_pccs_channelwise']
            test_pccs_channelwise = checkpoint['history']['test_pccs_channelwise']
            train_rmses_channelwise = checkpoint['history']['train_rmses_channelwise']
            val_rmses_channelwise = checkpoint['history']['val_rmses_channelwise']
            test_rmses_channelwise = checkpoint['history']['test_rmses_channelwise']
            if 'curriculum_schedule' in checkpoint:
                curriculum_loader.curriculum_schedule = checkpoint['curriculum_schedule']  # Load saved curriculum schedule
        else:
            print("No checkpoints found, starting from scratch.")
    else:
        print("Starting from scratch.")

    start_time = time.time()
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    for epoch in range(last_epoch, epochs):
        epoch_start_time = time.time()
        model.train()

        if curriculum_loader:
            curriculum_loader.update_epoch(epoch)
            train_loader, val_loader, test_loader = curriculum_loader.get_loaders()

        # Track total loss
        epoch_train_total_loss = 0.0
        epoch_train_regression_loss = 0.0
        epoch_train_classification_loss = 0.0

        epoch_train_pcc = np.zeros(len(config.channels_joints))
        epoch_train_rmse = np.zeros(len(config.channels_joints))

        for i, (data_acc, data_gyr, target, data_EMG, subject_id) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training")):
            optimizer.zero_grad()

            #print dimensions of each and subject id
            # print("acc",data_acc.shape)
            # print("gyro",data_gyr.shape)
            # print("target",target.shape)
            # print("emg",data_EMG.shape)
            # print("subject ids",subject_id)

            # Move data to device
            data_acc = data_acc.to(device).float()
            data_gyr = data_gyr.to(device).float()
            data_EMG = data_EMG.to(device).float()
            target = target.to(device).float()
            subject_id = subject_id.to(device).long() - 1  # Adjust subject IDs to start from 0

            # Forward pass
            output, out_classifier, _ = model(data_acc, data_gyr, data_EMG)

            # print(f"Output shape: {output.shape}")
            # print(f"Target shape: {target.shape}")
            # Compute regression loss
            loss_regression = criterion_regression(output, target)

            # Compute classification loss
            # Ensure shapes are correct
            # print(f"Out Classifier shape: {out_classifier.shape}")
            # print(f"Subject ID shape: {subject_id.shape}")
            # print(f"Subject ID dtype: {subject_id.dtype}")
            # print(f"Out Classifier dtype: {out_classifier.dtype}")
            # print(f"subject ids",subject_id)
            # print(f"out classifier",out_classifier)
            loss_classification = criterion_classification(out_classifier, subject_id)

            # Combine losses
            total_loss = loss_regression + alpha * loss_classification

            # Apply L1 regularization if specified
            if l1_lambda is not None:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                total_loss += l1_lambda * l1_norm

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()



            # Detach tensors and move to CPU to prevent issues with gradient computation
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(
                output.detach().cpu().numpy(),
                target.detach().cpu().numpy(),
                len(config.channels_joints),
                print_losses=False
            )

            # Update metrics
            epoch_train_total_loss += total_loss.item()
            epoch_train_regression_loss += loss_regression.item()
            epoch_train_classification_loss += loss_classification.item()

            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        # Calculate average losses
        avg_train_total_loss = epoch_train_total_loss / len(train_loader)
        avg_train_regression_loss = epoch_train_regression_loss / len(train_loader)
        avg_train_classification_loss = epoch_train_classification_loss / len(train_loader)

        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_total_loss)
        train_pccs.append(np.mean(avg_train_pcc))  # Overall average PCC
        train_rmses.append(np.mean(avg_train_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        train_pccs_channelwise.append(avg_train_pcc)  # Per channel
        train_rmses_channelwise.append(avg_train_rmse)  # Per channel

        # Evaluate on validation set
        avg_val_total_loss, avg_val_pcc, avg_val_rmse = evaluate_model(
            device,
            model,
            val_loader,
            criterion_regression
        )

        val_losses.append(avg_val_total_loss)
        val_pccs.append(np.mean(avg_val_pcc))  # Overall average PCC
        val_rmses.append(np.mean(avg_val_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        val_pccs_channelwise.append(avg_val_pcc)  # Per channel
        val_rmses_channelwise.append(avg_val_rmse)  # Per channel

        # Evaluate on test set
        avg_test_total_loss, avg_test_pcc, avg_test_rmse = evaluate_model(
            device,
            model,
            test_loader,
            criterion_regression
        )

        test_losses.append(avg_test_total_loss)
        test_pccs.append(np.mean(avg_test_pcc))  # Overall average PCC
        test_rmses.append(np.mean(avg_test_rmse))  # Overall average RMSE

        # Save channel-wise metrics
        test_pccs_channelwise.append(avg_test_pcc)  # Per channel
        test_rmses_channelwise.append(avg_test_rmse)  # Per channel

        print(f"Epoch: {epoch + 1}, Training Total Loss: {avg_train_total_loss:.4f}, Validation Total Loss: {avg_val_total_loss:.4f}, Test Total Loss: {avg_test_total_loss:.4f}")
        print(f"Training Regression Loss: {avg_train_regression_loss:.4f}, Validation Regression Loss: {avg_val_total_loss:.4f}, Test Regression Loss: {avg_test_total_loss:.4f}")
        print(f"Training Classification Loss: {avg_train_classification_loss:.4f}, Validation Classification Loss: {avg_val_total_loss:.4f}, Test Classification Loss: {avg_test_total_loss:.4f}")
        print(f"Training RMSE: {np.mean(avg_train_rmse):.4f}, Validation RMSE: {np.mean(avg_val_rmse):.4f}, Test RMSE: {np.mean(avg_test_rmse):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc):.4f}, Validation PCC: {np.mean(avg_val_pcc):.4f}, Test PCC: {np.mean(avg_test_pcc):.4f}")

        # Save checkpoint, including curriculum schedule
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

        # Save checkpoint with the curriculum schedule
        history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'test_losses': test_losses,
            'train_pccs': train_pccs,
            'val_pccs': val_pccs,
            'test_pccs': test_pccs,
            'train_rmses': train_rmses,
            'val_rmses': val_rmses,
            'test_rmses': test_rmses,
            'train_pccs_channelwise': train_pccs_channelwise,
            'val_pccs_channelwise': val_pccs_channelwise,
            'test_pccs_channelwise': test_pccs_channelwise,
            'train_rmses_channelwise': train_rmses_channelwise,
            'val_rmses_channelwise': val_rmses_channelwise,
            'test_rmses_channelwise': test_rmses_channelwise
        }

        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"{checkpoint_path}/{filename}_epoch_{epoch + 1}.pth",
            train_loss=avg_train_total_loss,
            val_loss=avg_val_total_loss,
            test_loss=avg_test_total_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            },
            history=history,  # Save history in the checkpoint
            curriculum_schedule=curriculum_loader.curriculum_schedule if curriculum_loader else None  # Save curriculum schedule
        )

        # Early stopping logic
        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses







In [21]:
# @title Helper Functions


# Function to create the teacher model with defaults from config
def create_teacher_model(input_acc, input_gyr, input_emg, base_weights_path=None, drop_prob=0.25, w=100,num_subjects=None):
    model = teacher(input_acc, input_gyr, input_emg, drop_prob=drop_prob, w=w,num_subjects=num_subjects)

    if base_weights_path:
        # Load the initial weights from the base model
        model.load_state_dict(torch.load(base_weights_path))

    return model


In [22]:

import os
import h5py
import csv
from tqdm.notebook import tqdm
import pandas as pd
def create_curriculum_schedule(all_subjects, num_subjects, epochs):
    return [(epoch, sorted(random.sample(all_subjects, num_subjects))) for epoch in range(epochs)]

# Define the list of all subjects
all_subjects = ['subject_1','subject_2', 'subject_3', 'subject_4', 'subject_5', 'subject_6',
                'subject_7', 'subject_8', 'subject_9', 'subject_10',
                'subject_11', 'subject_12', 'subject_13']

# Define the base model and save its weights
input_acc, input_gyr, input_emg = 18, 18, 3  # Example inputs based on your setup
base_model = teacher(input_acc, input_gyr, input_emg)
base_weights_path = 'base_teacher_weights.pth'
torch.save(base_model.state_dict(), base_weights_path)
window_overlap = 0
batch_size = 64
curriculum_epochs = 50

# Create model configurations for each subject as the test subject
model_configs = {}

test_subject = 'subject_1'
train_subjects = [subject for subject in all_subjects if subject not in test_subject]

# Define the number of subjects/classes
num_classes = 12

# List of different weights for the classification loss
alpha_values = [.2,.4,.8,1,2,4,8]


for alpha in alpha_values:
    model_name = f'TeacherModel_DomainInvariant_alpha_{alpha}_wl{100}_ol{75}'
    model_configs[model_name] = {
        'model': create_teacher_model(
            input_acc=input_acc,
            input_gyr=input_gyr,
            input_emg=input_emg,
            w=100,
            num_subjects=num_classes  # Ensure the model includes the classification head
        ),
        'loss': RMSELoss(),
        'loaders': create_base_data_loaders(
            config=config,
            train_subjects=train_subjects,
            test_subjects=[test_subject],
            window_length=100,
            window_overlap=75,
            batch_size=batch_size
        ),
        'epochs': curriculum_epochs,
        'use_curriculum': False,
        'alpha': alpha,  # Weighting factor for classification loss
        'num_classes': num_classes,  # Number of classes for the classification task
    }

# Output the model configurations for debugging (optional)
for model_name, model_conf in model_configs.items():
    print(f"Model: {model_name}")


Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol75_train_2_3_4_5_6_7_8_9_10_11_12_13. Skipping resharding.
Sharded data found at /content/datasets/dataset_wl100_ol0_test_1. Skipping re

In [23]:
 # @title run models

#clear gpu memory
torch.cuda.empty_cache()

def ask_run():
    response = input("Do you want to run models? (yes/no): ").strip().lower()
    if response in ['yes', 'y']:
        return True
    elif response in ['no', 'n']:
        return False
    else:
        print("Invalid input. Please enter 'yes' or 'no'.")
        return ask_run()  # Recursively ask again until valid input is given

run = ask_run()

if run:
    for model_name, model_config in model_configs.items():
        model = model_config['model']
        loss_function = model_config['loss']  # May be None, as loss functions are defined within train_teacher
        epochs = model_config.get("epochs", 100)
        device = model_config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        learn_rate = model_config.get("learn_rate", 0.001)
        use_curriculum = model_config.get("use_curriculum", False)

        optimizer = model_config.get("optimizer", None)
        l1_lambda = model_config.get("l1_lambda", None)
        alpha = model_config.get('alpha', 1.0)
        num_classes = model_config.get('num_classes', 12)  # Default to 12 if not specified

        print(f"Running model: {model_name} with alpha: {alpha}")

        if use_curriculum:
            curriculum_loader = model_config['loader']  # Get the CurriculumDataLoader
            train_loader, val_loader, test_loader = None, None, None  # Curriculum will handle loading per epoch

            model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, \
            train_rmses, val_rmses, test_rmses = train_teacher(
                device=device,
                train_loader=train_loader,  # Placeholders; curriculum will manage loaders dynamically
                val_loader=val_loader,
                test_loader=test_loader,
                learn_rate=learn_rate,
                epochs=epochs,
                model=model,
                filename=model_name,
                loss_function=loss_function,
                curriculum_loader=curriculum_loader,  # Pass the curriculum loader here
                optimizer=optimizer,
                l1_lambda=l1_lambda,
                train_from_last_epoch=model_config.get("train_from_last_epoch", False),
                alpha=alpha,
                num_classes=num_classes,
            )
        else:
            # Unpack the static loaders tuple (train_loader, val_loader, test_loader)
            train_loader, val_loader, test_loader = model_config['loaders']

            model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, \
            train_rmses, val_rmses, test_rmses = train_teacher(
                device=device,
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
                learn_rate=learn_rate,
                epochs=epochs,
                model=model,
                filename=model_name,
                loss_function=loss_function,
                optimizer=optimizer,
                l1_lambda=l1_lambda,
                train_from_last_epoch=model_config.get("train_from_last_epoch", False),
                alpha=alpha,
                num_classes=num_classes,
            )

        print(f"Finished training for {model_name}.")


Do you want to run models? (yes/no): yes
Running model: TeacherModel_DomainInvariant_alpha_0.2_wl100_ol75 with alpha: 0.2
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 19.3082, Validation Total Loss: 11.3202, Test Total Loss: 22.5309
Training Regression Loss: 18.8219, Validation Regression Loss: 11.3202, Test Regression Loss: 22.5309
Training Classification Loss: 2.4314, Validation Classification Loss: 11.3202, Test Classification Loss: 22.5309
Training RMSE: 18.3245, Validation RMSE: 10.9216, Test RMSE: 21.3135
Training PCC: 0.8002, Validation PCC: 0.9467, Test PCC: 0.6469
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.2280, Validation Total Loss: 9.6751, Test Total Loss: 20.6907
Training Regression Loss: 10.7971, Validation Regression Loss: 9.6751, Test Regression Loss: 20.6907
Training Classification Loss: 2.1549, Validation Classification Loss: 9.6751, Test Classification Loss: 20.6907
Training RMSE: 10.3655, Validation RMSE: 9.0129, Test RMSE: 19.3641
Training PCC: 0.9529, Validation PCC: 0.9647, Test PCC: 0.6568
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 9.9656, Validation Total Loss: 8.2933, Test Total Loss: 20.0151
Training Regression Loss: 9.5309, Validation Regression Loss: 8.2933, Test Regression Loss: 20.0151
Training Classification Loss: 2.1735, Validation Classification Loss: 8.2933, Test Classification Loss: 20.0151
Training RMSE: 9.1281, Validation RMSE: 7.8857, Test RMSE: 19.0956
Training PCC: 0.9641, Validation PCC: 0.9736, Test PCC: 0.6310
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.2465, Validation Total Loss: 7.7685, Test Total Loss: 20.4872
Training Regression Loss: 8.8043, Validation Regression Loss: 7.7685, Test Regression Loss: 20.4872
Training Classification Loss: 2.2112, Validation Classification Loss: 7.7685, Test Classification Loss: 20.4872
Training RMSE: 8.4044, Validation RMSE: 7.4083, Test RMSE: 19.5231
Training PCC: 0.9691, Validation PCC: 0.9775, Test PCC: 0.6613
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 8.6265, Validation Total Loss: 7.3212, Test Total Loss: 20.1194
Training Regression Loss: 8.1743, Validation Regression Loss: 7.3212, Test Regression Loss: 20.1194
Training Classification Loss: 2.2611, Validation Classification Loss: 7.3212, Test Classification Loss: 20.1194
Training RMSE: 7.8215, Validation RMSE: 6.9475, Test RMSE: 18.9858
Training PCC: 0.9733, Validation PCC: 0.9797, Test PCC: 0.6626
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.1608, Validation Total Loss: 7.3152, Test Total Loss: 19.1807
Training Regression Loss: 7.6970, Validation Regression Loss: 7.3152, Test Regression Loss: 19.1807
Training Classification Loss: 2.3187, Validation Classification Loss: 7.3152, Test Classification Loss: 19.1807
Training RMSE: 7.3621, Validation RMSE: 6.8882, Test RMSE: 17.5726
Training PCC: 0.9760, Validation PCC: 0.9802, Test PCC: 0.7058
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 7.7806, Validation Total Loss: 6.5594, Test Total Loss: 20.2031
Training Regression Loss: 7.2963, Validation Regression Loss: 6.5594, Test Regression Loss: 20.2031
Training Classification Loss: 2.4219, Validation Classification Loss: 6.5594, Test Classification Loss: 20.2031
Training RMSE: 6.9786, Validation RMSE: 6.2478, Test RMSE: 18.8456
Training PCC: 0.9783, Validation PCC: 0.9836, Test PCC: 0.6922
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 7.4474, Validation Total Loss: 6.9724, Test Total Loss: 20.3948
Training Regression Loss: 6.9540, Validation Regression Loss: 6.9724, Test Regression Loss: 20.3948
Training Classification Loss: 2.4670, Validation Classification Loss: 6.9724, Test Classification Loss: 20.3948
Training RMSE: 6.6607, Validation RMSE: 6.5201, Test RMSE: 19.1570
Training PCC: 0.9801, Validation PCC: 0.9827, Test PCC: 0.6826
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 7.4807, Validation Total Loss: 6.9970, Test Total Loss: 18.4829
Training Regression Loss: 6.9570, Validation Regression Loss: 6.9970, Test Regression Loss: 18.4829
Training Classification Loss: 2.6184, Validation Classification Loss: 6.9970, Test Classification Loss: 18.4829
Training RMSE: 6.6418, Validation RMSE: 6.4856, Test RMSE: 17.6680
Training PCC: 0.9801, Validation PCC: 0.9847, Test PCC: 0.6802
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.1313, Validation Total Loss: 6.1164, Test Total Loss: 19.4563
Training Regression Loss: 6.6340, Validation Regression Loss: 6.1164, Test Regression Loss: 19.4563
Training Classification Loss: 2.4864, Validation Classification Loss: 6.1164, Test Classification Loss: 19.4563
Training RMSE: 6.3459, Validation RMSE: 5.8024, Test RMSE: 18.2010
Training PCC: 0.9820, Validation PCC: 0.9857, Test PCC: 0.6883
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 6.8795, Validation Total Loss: 6.0071, Test Total Loss: 19.5400
Training Regression Loss: 6.3683, Validation Regression Loss: 6.0071, Test Regression Loss: 19.5400
Training Classification Loss: 2.5557, Validation Classification Loss: 6.0071, Test Classification Loss: 19.5400
Training RMSE: 6.0956, Validation RMSE: 5.7404, Test RMSE: 18.4411
Training PCC: 0.9830, Validation PCC: 0.9863, Test PCC: 0.6894
Checkpoint saved for epoch 11


Epoch 12/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 6.8335, Validation Total Loss: 5.7070, Test Total Loss: 18.8959
Training Regression Loss: 6.2853, Validation Regression Loss: 5.7070, Test Regression Loss: 18.8959
Training Classification Loss: 2.7407, Validation Classification Loss: 5.7070, Test Classification Loss: 18.8959
Training RMSE: 6.0321, Validation RMSE: 5.4668, Test RMSE: 17.9663
Training PCC: 0.9835, Validation PCC: 0.9863, Test PCC: 0.6887
Checkpoint saved for epoch 12


Epoch 13/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 6.6803, Validation Total Loss: 5.8897, Test Total Loss: 20.8904
Training Regression Loss: 6.1455, Validation Regression Loss: 5.8897, Test Regression Loss: 20.8904
Training Classification Loss: 2.6742, Validation Classification Loss: 5.8897, Test Classification Loss: 20.8904
Training RMSE: 5.8982, Validation RMSE: 5.5108, Test RMSE: 19.7289
Training PCC: 0.9842, Validation PCC: 0.9871, Test PCC: 0.6708
Checkpoint saved for epoch 13


Epoch 14/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 6.5857, Validation Total Loss: 5.8313, Test Total Loss: 20.4258
Training Regression Loss: 6.0340, Validation Regression Loss: 5.8313, Test Regression Loss: 20.4258
Training Classification Loss: 2.7589, Validation Classification Loss: 5.8313, Test Classification Loss: 20.4258
Training RMSE: 5.7904, Validation RMSE: 5.6225, Test RMSE: 19.4735
Training PCC: 0.9848, Validation PCC: 0.9874, Test PCC: 0.6838
Checkpoint saved for epoch 14


Epoch 15/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 6.6634, Validation Total Loss: 5.8997, Test Total Loss: 19.2758
Training Regression Loss: 6.1041, Validation Regression Loss: 5.8997, Test Regression Loss: 19.2758
Training Classification Loss: 2.7961, Validation Classification Loss: 5.8997, Test Classification Loss: 19.2758
Training RMSE: 5.8515, Validation RMSE: 5.5590, Test RMSE: 18.2788
Training PCC: 0.9841, Validation PCC: 0.9875, Test PCC: 0.6479
Checkpoint saved for epoch 15


Epoch 16/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 6.3565, Validation Total Loss: 5.5379, Test Total Loss: 19.1445
Training Regression Loss: 5.8062, Validation Regression Loss: 5.5379, Test Regression Loss: 19.1445
Training Classification Loss: 2.7516, Validation Classification Loss: 5.5379, Test Classification Loss: 19.1445
Training RMSE: 5.5626, Validation RMSE: 5.2500, Test RMSE: 18.2069
Training PCC: 0.9857, Validation PCC: 0.9884, Test PCC: 0.6767
Checkpoint saved for epoch 16


Epoch 17/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 6.3022, Validation Total Loss: 5.6777, Test Total Loss: 19.8426
Training Regression Loss: 5.7442, Validation Regression Loss: 5.6777, Test Regression Loss: 19.8426
Training Classification Loss: 2.7903, Validation Classification Loss: 5.6777, Test Classification Loss: 19.8426
Training RMSE: 5.5193, Validation RMSE: 5.3687, Test RMSE: 18.8226
Training PCC: 0.9859, Validation PCC: 0.9879, Test PCC: 0.6631
Checkpoint saved for epoch 17


Epoch 18/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 6.1849, Validation Total Loss: 5.6781, Test Total Loss: 19.7791
Training Regression Loss: 5.6521, Validation Regression Loss: 5.6781, Test Regression Loss: 19.7791
Training Classification Loss: 2.6639, Validation Classification Loss: 5.6781, Test Classification Loss: 19.7791
Training RMSE: 5.4337, Validation RMSE: 5.4348, Test RMSE: 18.5829
Training PCC: 0.9863, Validation PCC: 0.9892, Test PCC: 0.7105
Checkpoint saved for epoch 18


Epoch 19/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 5.9955, Validation Total Loss: 5.1330, Test Total Loss: 18.7286
Training Regression Loss: 5.4502, Validation Regression Loss: 5.1330, Test Regression Loss: 18.7286
Training Classification Loss: 2.7264, Validation Classification Loss: 5.1330, Test Classification Loss: 18.7286
Training RMSE: 5.2494, Validation RMSE: 4.8608, Test RMSE: 17.5710
Training PCC: 0.9872, Validation PCC: 0.9897, Test PCC: 0.7030
Checkpoint saved for epoch 19


Epoch 20/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 5.9248, Validation Total Loss: 5.1575, Test Total Loss: 19.9298
Training Regression Loss: 5.3787, Validation Regression Loss: 5.1575, Test Regression Loss: 19.9298
Training Classification Loss: 2.7305, Validation Classification Loss: 5.1575, Test Classification Loss: 19.9298
Training RMSE: 5.1831, Validation RMSE: 4.9033, Test RMSE: 19.0182
Training PCC: 0.9876, Validation PCC: 0.9898, Test PCC: 0.6755
Checkpoint saved for epoch 20


Epoch 21/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 21, Training Total Loss: 6.0069, Validation Total Loss: 5.2439, Test Total Loss: 18.3703
Training Regression Loss: 5.4378, Validation Regression Loss: 5.2439, Test Regression Loss: 18.3703
Training Classification Loss: 2.8454, Validation Classification Loss: 5.2439, Test Classification Loss: 18.3703
Training RMSE: 5.2234, Validation RMSE: 4.9328, Test RMSE: 17.2023
Training PCC: 0.9873, Validation PCC: 0.9894, Test PCC: 0.7185
Checkpoint saved for epoch 21


Epoch 22/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 22, Training Total Loss: 5.8789, Validation Total Loss: 6.0594, Test Total Loss: 19.4141
Training Regression Loss: 5.3314, Validation Regression Loss: 6.0594, Test Regression Loss: 19.4141
Training Classification Loss: 2.7370, Validation Classification Loss: 6.0594, Test Classification Loss: 19.4141
Training RMSE: 5.1254, Validation RMSE: 5.6104, Test RMSE: 18.3158
Training PCC: 0.9878, Validation PCC: 0.9871, Test PCC: 0.7027
Checkpoint saved for epoch 22


Epoch 23/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 23, Training Total Loss: 6.1458, Validation Total Loss: 5.6081, Test Total Loss: 19.6233
Training Regression Loss: 5.5057, Validation Regression Loss: 5.6081, Test Regression Loss: 19.6233
Training Classification Loss: 3.2004, Validation Classification Loss: 5.6081, Test Classification Loss: 19.6233
Training RMSE: 5.2618, Validation RMSE: 5.2726, Test RMSE: 18.6601
Training PCC: 0.9873, Validation PCC: 0.9879, Test PCC: 0.6897
Checkpoint saved for epoch 23


Epoch 24/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 24, Training Total Loss: 5.8652, Validation Total Loss: 5.3305, Test Total Loss: 20.0553
Training Regression Loss: 5.2217, Validation Regression Loss: 5.3305, Test Regression Loss: 20.0553
Training Classification Loss: 3.2175, Validation Classification Loss: 5.3305, Test Classification Loss: 20.0553
Training RMSE: 5.0426, Validation RMSE: 5.1527, Test RMSE: 19.0645
Training PCC: 0.9883, Validation PCC: 0.9902, Test PCC: 0.6982
Checkpoint saved for epoch 24


Epoch 25/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 25, Training Total Loss: 5.6634, Validation Total Loss: 5.0703, Test Total Loss: 19.5189
Training Regression Loss: 5.0966, Validation Regression Loss: 5.0703, Test Regression Loss: 19.5189
Training Classification Loss: 2.8338, Validation Classification Loss: 5.0703, Test Classification Loss: 19.5189
Training RMSE: 4.9136, Validation RMSE: 4.8365, Test RMSE: 18.4136
Training PCC: 0.9887, Validation PCC: 0.9892, Test PCC: 0.6802
Checkpoint saved for epoch 25


Epoch 26/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 26, Training Total Loss: 5.4893, Validation Total Loss: 4.7755, Test Total Loss: 18.7924
Training Regression Loss: 4.9064, Validation Regression Loss: 4.7755, Test Regression Loss: 18.7924
Training Classification Loss: 2.9145, Validation Classification Loss: 4.7755, Test Classification Loss: 18.7924
Training RMSE: 4.7451, Validation RMSE: 4.5704, Test RMSE: 17.7777
Training PCC: 0.9895, Validation PCC: 0.9909, Test PCC: 0.6967
Checkpoint saved for epoch 26


Epoch 27/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 27, Training Total Loss: 5.5457, Validation Total Loss: 4.8479, Test Total Loss: 18.6396
Training Regression Loss: 4.9785, Validation Regression Loss: 4.8479, Test Regression Loss: 18.6396
Training Classification Loss: 2.8362, Validation Classification Loss: 4.8479, Test Classification Loss: 18.6396
Training RMSE: 4.8080, Validation RMSE: 4.6035, Test RMSE: 17.5326
Training PCC: 0.9893, Validation PCC: 0.9907, Test PCC: 0.7093
Checkpoint saved for epoch 27


Epoch 28/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 28, Training Total Loss: 5.8201, Validation Total Loss: 4.8518, Test Total Loss: 18.9668
Training Regression Loss: 5.2172, Validation Regression Loss: 4.8518, Test Regression Loss: 18.9668
Training Classification Loss: 3.0142, Validation Classification Loss: 4.8518, Test Classification Loss: 18.9668
Training RMSE: 5.0243, Validation RMSE: 4.6441, Test RMSE: 17.8504
Training PCC: 0.9884, Validation PCC: 0.9907, Test PCC: 0.7050
Checkpoint saved for epoch 28


Epoch 29/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 29, Training Total Loss: 5.4439, Validation Total Loss: 4.9243, Test Total Loss: 18.7668
Training Regression Loss: 4.9072, Validation Regression Loss: 4.9243, Test Regression Loss: 18.7668
Training Classification Loss: 2.6837, Validation Classification Loss: 4.9243, Test Classification Loss: 18.7668
Training RMSE: 4.7461, Validation RMSE: 4.6835, Test RMSE: 17.6228
Training PCC: 0.9895, Validation PCC: 0.9909, Test PCC: 0.7347
Checkpoint saved for epoch 29


Epoch 30/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 30, Training Total Loss: 5.3346, Validation Total Loss: 4.9917, Test Total Loss: 20.0297
Training Regression Loss: 4.7702, Validation Regression Loss: 4.9917, Test Regression Loss: 20.0297
Training Classification Loss: 2.8218, Validation Classification Loss: 4.9917, Test Classification Loss: 20.0297
Training RMSE: 4.6140, Validation RMSE: 4.8601, Test RMSE: 18.8296
Training PCC: 0.9902, Validation PCC: 0.9906, Test PCC: 0.7284
Checkpoint saved for epoch 30


Epoch 31/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 31, Training Total Loss: 5.2774, Validation Total Loss: 4.8141, Test Total Loss: 20.5143
Training Regression Loss: 4.7012, Validation Regression Loss: 4.8141, Test Regression Loss: 20.5143
Training Classification Loss: 2.8807, Validation Classification Loss: 4.8141, Test Classification Loss: 20.5143
Training RMSE: 4.5595, Validation RMSE: 4.5863, Test RMSE: 19.0690
Training PCC: 0.9905, Validation PCC: 0.9914, Test PCC: 0.6928
Checkpoint saved for epoch 31


Epoch 32/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 32, Training Total Loss: 5.0799, Validation Total Loss: 4.8005, Test Total Loss: 19.5580
Training Regression Loss: 4.5141, Validation Regression Loss: 4.8005, Test Regression Loss: 19.5580
Training Classification Loss: 2.8291, Validation Classification Loss: 4.8005, Test Classification Loss: 19.5580
Training RMSE: 4.3871, Validation RMSE: 4.5585, Test RMSE: 18.3503
Training PCC: 0.9911, Validation PCC: 0.9909, Test PCC: 0.6875
Checkpoint saved for epoch 32


Epoch 33/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 33, Training Total Loss: 5.1849, Validation Total Loss: 5.0274, Test Total Loss: 20.0567
Training Regression Loss: 4.6056, Validation Regression Loss: 5.0274, Test Regression Loss: 20.0567
Training Classification Loss: 2.8963, Validation Classification Loss: 5.0274, Test Classification Loss: 20.0567
Training RMSE: 4.4650, Validation RMSE: 4.8437, Test RMSE: 18.7678
Training PCC: 0.9907, Validation PCC: 0.9913, Test PCC: 0.7053
Checkpoint saved for epoch 33


Epoch 34/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 34, Training Total Loss: 5.2603, Validation Total Loss: 4.7464, Test Total Loss: 19.1991
Training Regression Loss: 4.6941, Validation Regression Loss: 4.7464, Test Regression Loss: 19.1991
Training Classification Loss: 2.8311, Validation Classification Loss: 4.7464, Test Classification Loss: 19.1991
Training RMSE: 4.5286, Validation RMSE: 4.5830, Test RMSE: 18.1677
Training PCC: 0.9903, Validation PCC: 0.9904, Test PCC: 0.6986
Checkpoint saved for epoch 34


Epoch 35/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 35, Training Total Loss: 5.0935, Validation Total Loss: 4.6592, Test Total Loss: 19.4103
Training Regression Loss: 4.4875, Validation Regression Loss: 4.6592, Test Regression Loss: 19.4103
Training Classification Loss: 3.0300, Validation Classification Loss: 4.6592, Test Classification Loss: 19.4103
Training RMSE: 4.3611, Validation RMSE: 4.4708, Test RMSE: 18.3820
Training PCC: 0.9911, Validation PCC: 0.9916, Test PCC: 0.6767
Checkpoint saved for epoch 35


Epoch 36/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 36, Training Total Loss: 5.0948, Validation Total Loss: 4.4839, Test Total Loss: 19.4193
Training Regression Loss: 4.4929, Validation Regression Loss: 4.4839, Test Regression Loss: 19.4193
Training Classification Loss: 3.0097, Validation Classification Loss: 4.4839, Test Classification Loss: 19.4193
Training RMSE: 4.3708, Validation RMSE: 4.3010, Test RMSE: 18.1719
Training PCC: 0.9913, Validation PCC: 0.9916, Test PCC: 0.6994
Checkpoint saved for epoch 36


Epoch 37/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 37, Training Total Loss: 5.3389, Validation Total Loss: 4.7186, Test Total Loss: 20.0997
Training Regression Loss: 4.7041, Validation Regression Loss: 4.7186, Test Regression Loss: 20.0997
Training Classification Loss: 3.1738, Validation Classification Loss: 4.7186, Test Classification Loss: 20.0997
Training RMSE: 4.5616, Validation RMSE: 4.5645, Test RMSE: 19.2355
Training PCC: 0.9902, Validation PCC: 0.9913, Test PCC: 0.6561
Checkpoint saved for epoch 37


Epoch 38/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 38, Training Total Loss: 5.0526, Validation Total Loss: 4.7119, Test Total Loss: 19.0768
Training Regression Loss: 4.4466, Validation Regression Loss: 4.7119, Test Regression Loss: 19.0768
Training Classification Loss: 3.0302, Validation Classification Loss: 4.7119, Test Classification Loss: 19.0768
Training RMSE: 4.3211, Validation RMSE: 4.5202, Test RMSE: 18.2253
Training PCC: 0.9914, Validation PCC: 0.9916, Test PCC: 0.6786
Checkpoint saved for epoch 38


Epoch 39/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 39, Training Total Loss: 4.8873, Validation Total Loss: 4.7123, Test Total Loss: 19.5376
Training Regression Loss: 4.2848, Validation Regression Loss: 4.7123, Test Regression Loss: 19.5376
Training Classification Loss: 3.0126, Validation Classification Loss: 4.7123, Test Classification Loss: 19.5376
Training RMSE: 4.1809, Validation RMSE: 4.5388, Test RMSE: 18.5620
Training PCC: 0.9919, Validation PCC: 0.9921, Test PCC: 0.6884
Checkpoint saved for epoch 39


Epoch 40/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 40, Training Total Loss: 4.8422, Validation Total Loss: 5.0468, Test Total Loss: 19.7694
Training Regression Loss: 4.2439, Validation Regression Loss: 5.0468, Test Regression Loss: 19.7694
Training Classification Loss: 2.9913, Validation Classification Loss: 5.0468, Test Classification Loss: 19.7694
Training RMSE: 4.1386, Validation RMSE: 4.8914, Test RMSE: 18.6695
Training PCC: 0.9921, Validation PCC: 0.9913, Test PCC: 0.6915
Checkpoint saved for epoch 40


Epoch 41/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 41, Training Total Loss: 4.7663, Validation Total Loss: 4.5161, Test Total Loss: 19.4549
Training Regression Loss: 4.1890, Validation Regression Loss: 4.5161, Test Regression Loss: 19.4549
Training Classification Loss: 2.8864, Validation Classification Loss: 4.5161, Test Classification Loss: 19.4549
Training RMSE: 4.0851, Validation RMSE: 4.3331, Test RMSE: 18.2825
Training PCC: 0.9924, Validation PCC: 0.9922, Test PCC: 0.6834
Checkpoint saved for epoch 41


Epoch 42/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 42, Training Total Loss: 4.7184, Validation Total Loss: 4.6248, Test Total Loss: 18.7060
Training Regression Loss: 4.1174, Validation Regression Loss: 4.6248, Test Regression Loss: 18.7060
Training Classification Loss: 3.0049, Validation Classification Loss: 4.6248, Test Classification Loss: 18.7060
Training RMSE: 4.0129, Validation RMSE: 4.3439, Test RMSE: 17.6101
Training PCC: 0.9926, Validation PCC: 0.9918, Test PCC: 0.6918
Checkpoint saved for epoch 42


Epoch 43/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 43, Training Total Loss: 4.8827, Validation Total Loss: 5.2925, Test Total Loss: 19.8061
Training Regression Loss: 4.2486, Validation Regression Loss: 5.2925, Test Regression Loss: 19.8061
Training Classification Loss: 3.1702, Validation Classification Loss: 5.2925, Test Classification Loss: 19.8061
Training RMSE: 4.1320, Validation RMSE: 4.9736, Test RMSE: 18.7762
Training PCC: 0.9923, Validation PCC: 0.9893, Test PCC: 0.6847
Checkpoint saved for epoch 43


Epoch 44/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 44, Training Total Loss: 5.0220, Validation Total Loss: 4.8537, Test Total Loss: 19.7454
Training Regression Loss: 4.4255, Validation Regression Loss: 4.8537, Test Regression Loss: 19.7454
Training Classification Loss: 2.9823, Validation Classification Loss: 4.8537, Test Classification Loss: 19.7454
Training RMSE: 4.2866, Validation RMSE: 4.6386, Test RMSE: 18.6850
Training PCC: 0.9914, Validation PCC: 0.9914, Test PCC: 0.6958
Checkpoint saved for epoch 44


Epoch 45/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 45, Training Total Loss: 4.9205, Validation Total Loss: 4.5346, Test Total Loss: 19.5543
Training Regression Loss: 4.3172, Validation Regression Loss: 4.5346, Test Regression Loss: 19.5543
Training Classification Loss: 3.0165, Validation Classification Loss: 4.5346, Test Classification Loss: 19.5543
Training RMSE: 4.1838, Validation RMSE: 4.3268, Test RMSE: 18.5597
Training PCC: 0.9917, Validation PCC: 0.9925, Test PCC: 0.6711
Checkpoint saved for epoch 45


Epoch 46/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 46, Training Total Loss: 4.8995, Validation Total Loss: 4.5326, Test Total Loss: 20.2430
Training Regression Loss: 4.3016, Validation Regression Loss: 4.5326, Test Regression Loss: 20.2430
Training Classification Loss: 2.9894, Validation Classification Loss: 4.5326, Test Classification Loss: 20.2430
Training RMSE: 4.1705, Validation RMSE: 4.3141, Test RMSE: 18.9575
Training PCC: 0.9918, Validation PCC: 0.9921, Test PCC: 0.6570
Checkpoint saved for epoch 46
Stopping early after 46 epochs
Total training time: 6866.20 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.2_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.4_wl100_ol75 with alpha: 0.4
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 20.0237, Validation Total Loss: 12.0789, Test Total Loss: 22.5330
Training Regression Loss: 18.9956, Validation Regression Loss: 12.0789, Test Regression Loss: 22.5330
Training Classification Loss: 2.5703, Validation Classification Loss: 12.0789, Test Classification Loss: 22.5330
Training RMSE: 18.5153, Validation RMSE: 11.6816, Test RMSE: 21.3630
Training PCC: 0.7940, Validation PCC: 0.9433, Test PCC: 0.6442
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 11.6656, Validation Total Loss: 9.7342, Test Total Loss: 19.0850
Training Regression Loss: 10.6279, Validation Regression Loss: 9.7342, Test Regression Loss: 19.0850
Training Classification Loss: 2.5943, Validation Classification Loss: 9.7342, Test Classification Loss: 19.0850
Training RMSE: 10.2212, Validation RMSE: 9.3039, Test RMSE: 18.2137
Training PCC: 0.9541, Validation PCC: 0.9598, Test PCC: 0.6343
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 10.4938, Validation Total Loss: 9.3896, Test Total Loss: 21.5919
Training Regression Loss: 9.4707, Validation Regression Loss: 9.3896, Test Regression Loss: 21.5919
Training Classification Loss: 2.5578, Validation Classification Loss: 9.3896, Test Classification Loss: 21.5919
Training RMSE: 9.0702, Validation RMSE: 8.9797, Test RMSE: 20.6320
Training PCC: 0.9644, Validation PCC: 0.9695, Test PCC: 0.5997
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 9.8089, Validation Total Loss: 8.0219, Test Total Loss: 20.0607
Training Regression Loss: 8.7381, Validation Regression Loss: 8.0219, Test Regression Loss: 20.0607
Training Classification Loss: 2.6770, Validation Classification Loss: 8.0219, Test Classification Loss: 20.0607
Training RMSE: 8.3543, Validation RMSE: 7.6211, Test RMSE: 18.6381
Training PCC: 0.9702, Validation PCC: 0.9740, Test PCC: 0.6542
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 9.1650, Validation Total Loss: 7.8140, Test Total Loss: 19.7410
Training Regression Loss: 8.1163, Validation Regression Loss: 7.8140, Test Regression Loss: 19.7410
Training Classification Loss: 2.6216, Validation Classification Loss: 7.8140, Test Classification Loss: 19.7410
Training RMSE: 7.7586, Validation RMSE: 7.4069, Test RMSE: 18.5607
Training PCC: 0.9741, Validation PCC: 0.9748, Test PCC: 0.6772
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 8.8651, Validation Total Loss: 7.3179, Test Total Loss: 20.0495
Training Regression Loss: 7.7499, Validation Regression Loss: 7.3179, Test Regression Loss: 20.0495
Training Classification Loss: 2.7880, Validation Classification Loss: 7.3179, Test Classification Loss: 20.0495
Training RMSE: 7.3968, Validation RMSE: 6.9347, Test RMSE: 18.9343
Training PCC: 0.9761, Validation PCC: 0.9771, Test PCC: 0.6737
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 8.6420, Validation Total Loss: 8.9509, Test Total Loss: 21.8831
Training Regression Loss: 7.4136, Validation Regression Loss: 8.9509, Test Regression Loss: 21.8831
Training Classification Loss: 3.0711, Validation Classification Loss: 8.9509, Test Classification Loss: 21.8831
Training RMSE: 7.1004, Validation RMSE: 8.2122, Test RMSE: 20.9568
Training PCC: 0.9778, Validation PCC: 0.9751, Test PCC: 0.6604
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 8.4344, Validation Total Loss: 6.6578, Test Total Loss: 19.7040
Training Regression Loss: 7.2784, Validation Regression Loss: 6.6578, Test Regression Loss: 19.7040
Training Classification Loss: 2.8900, Validation Classification Loss: 6.6578, Test Classification Loss: 19.7040
Training RMSE: 6.9714, Validation RMSE: 6.2753, Test RMSE: 18.5350
Training PCC: 0.9787, Validation PCC: 0.9816, Test PCC: 0.6833
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 8.1344, Validation Total Loss: 6.5049, Test Total Loss: 19.8696
Training Regression Loss: 6.9684, Validation Regression Loss: 6.5049, Test Regression Loss: 19.8696
Training Classification Loss: 2.9152, Validation Classification Loss: 6.5049, Test Classification Loss: 19.8696
Training RMSE: 6.6639, Validation RMSE: 6.1791, Test RMSE: 18.6402
Training PCC: 0.9803, Validation PCC: 0.9821, Test PCC: 0.6825
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 7.8464, Validation Total Loss: 6.8115, Test Total Loss: 18.8183
Training Regression Loss: 6.6686, Validation Regression Loss: 6.8115, Test Regression Loss: 18.8183
Training Classification Loss: 2.9444, Validation Classification Loss: 6.8115, Test Classification Loss: 18.8183
Training RMSE: 6.3926, Validation RMSE: 6.4804, Test RMSE: 17.9988
Training PCC: 0.9818, Validation PCC: 0.9829, Test PCC: 0.6915
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 7.7809, Validation Total Loss: 6.4217, Test Total Loss: 19.5361
Training Regression Loss: 6.5059, Validation Regression Loss: 6.4217, Test Regression Loss: 19.5361
Training Classification Loss: 3.1874, Validation Classification Loss: 6.4217, Test Classification Loss: 19.5361
Training RMSE: 6.2291, Validation RMSE: 6.1102, Test RMSE: 18.3570
Training PCC: 0.9826, Validation PCC: 0.9828, Test PCC: 0.7069
Checkpoint saved for epoch 11


Epoch 12/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 7.9954, Validation Total Loss: 6.4960, Test Total Loss: 19.9252
Training Regression Loss: 6.5013, Validation Regression Loss: 6.4960, Test Regression Loss: 19.9252
Training Classification Loss: 3.7351, Validation Classification Loss: 6.4960, Test Classification Loss: 19.9252
Training RMSE: 6.2449, Validation RMSE: 6.2087, Test RMSE: 18.8883
Training PCC: 0.9826, Validation PCC: 0.9817, Test PCC: 0.6585
Checkpoint saved for epoch 12


Epoch 13/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 13, Training Total Loss: 7.8680, Validation Total Loss: 6.5537, Test Total Loss: 20.3866
Training Regression Loss: 6.4685, Validation Regression Loss: 6.5537, Test Regression Loss: 20.3866
Training Classification Loss: 3.4988, Validation Classification Loss: 6.5537, Test Classification Loss: 20.3866
Training RMSE: 6.2075, Validation RMSE: 6.2356, Test RMSE: 19.0173
Training PCC: 0.9829, Validation PCC: 0.9810, Test PCC: 0.6776
Checkpoint saved for epoch 13


Epoch 14/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 14, Training Total Loss: 7.8187, Validation Total Loss: 6.0749, Test Total Loss: 19.4800
Training Regression Loss: 6.3109, Validation Regression Loss: 6.0749, Test Regression Loss: 19.4800
Training Classification Loss: 3.7694, Validation Classification Loss: 6.0749, Test Classification Loss: 19.4800
Training RMSE: 6.0707, Validation RMSE: 5.7755, Test RMSE: 18.3349
Training PCC: 0.9839, Validation PCC: 0.9842, Test PCC: 0.6974
Checkpoint saved for epoch 14


Epoch 15/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 15, Training Total Loss: 7.9486, Validation Total Loss: 6.2964, Test Total Loss: 19.8689
Training Regression Loss: 6.3049, Validation Regression Loss: 6.2964, Test Regression Loss: 19.8689
Training Classification Loss: 4.1091, Validation Classification Loss: 6.2964, Test Classification Loss: 19.8689
Training RMSE: 6.0528, Validation RMSE: 6.0145, Test RMSE: 18.6370
Training PCC: 0.9841, Validation PCC: 0.9833, Test PCC: 0.6936
Checkpoint saved for epoch 15


Epoch 16/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 16, Training Total Loss: 7.9953, Validation Total Loss: 6.4670, Test Total Loss: 18.0509
Training Regression Loss: 6.2050, Validation Regression Loss: 6.4670, Test Regression Loss: 18.0509
Training Classification Loss: 4.4758, Validation Classification Loss: 6.4670, Test Classification Loss: 18.0509
Training RMSE: 5.9662, Validation RMSE: 6.0715, Test RMSE: 17.1069
Training PCC: 0.9844, Validation PCC: 0.9848, Test PCC: 0.6878
Checkpoint saved for epoch 16


Epoch 17/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 17, Training Total Loss: 2273.4743, Validation Total Loss: 3582.1708, Test Total Loss: 3585.2530
Training Regression Loss: 586.9020, Validation Regression Loss: 3582.1708, Test Regression Loss: 3585.2530
Training Classification Loss: 4216.4306, Validation Classification Loss: 3582.1708, Test Classification Loss: 3585.2530
Training RMSE: 475.9209, Validation RMSE: 2647.6454, Test RMSE: 2651.4308
Training PCC: 0.5843, Validation PCC: 0.0249, Test PCC: 0.0602
Checkpoint saved for epoch 17


Epoch 18/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 18, Training Total Loss: 9872.6397, Validation Total Loss: 1230.3680, Test Total Loss: 1226.7986
Training Regression Loss: 2411.6351, Validation Regression Loss: 1230.3680, Test Regression Loss: 1226.7986
Training Classification Loss: 18652.5110, Validation Classification Loss: 1230.3680, Test Classification Loss: 1226.7986
Training RMSE: 1948.2382, Validation RMSE: 853.6198, Test RMSE: 853.2411
Training PCC: 0.0030, Validation PCC: 0.0011, Test PCC: 0.0036
Checkpoint saved for epoch 18


Epoch 19/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 19, Training Total Loss: 13701.4298, Validation Total Loss: 2274.8476, Test Total Loss: 2279.1062
Training Regression Loss: 3188.6920, Validation Regression Loss: 2274.8476, Test Regression Loss: 2279.1062
Training Classification Loss: 26281.8440, Validation Classification Loss: 2274.8476, Test Classification Loss: 2279.1062
Training RMSE: 2234.3846, Validation RMSE: 1490.9387, Test RMSE: 1495.2921
Training PCC: -0.0008, Validation PCC: 0.0088, Test PCC: 0.0050
Checkpoint saved for epoch 19


Epoch 20/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 20, Training Total Loss: 17867.2487, Validation Total Loss: 1586.7503, Test Total Loss: 1581.8204
Training Regression Loss: 2953.1467, Validation Regression Loss: 1586.7503, Test Regression Loss: 1581.8204
Training Classification Loss: 37285.2545, Validation Classification Loss: 1586.7503, Test Classification Loss: 1581.8204
Training RMSE: 2053.4733, Validation RMSE: 1342.9810, Test RMSE: 1336.8764
Training PCC: 0.0008, Validation PCC: 0.0002, Test PCC: 0.0057
Checkpoint saved for epoch 20


Epoch 21/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 21, Training Total Loss: 21008.8492, Validation Total Loss: 3550.0829, Test Total Loss: 3554.8271
Training Regression Loss: 5419.3785, Validation Regression Loss: 3550.0829, Test Regression Loss: 3554.8271
Training Classification Loss: 38973.6759, Validation Classification Loss: 3550.0829, Test Classification Loss: 3554.8271
Training RMSE: 3573.8810, Validation RMSE: 2660.3111, Test RMSE: 2660.5736
Training PCC: -0.0016, Validation PCC: -0.0002, Test PCC: 0.0002
Checkpoint saved for epoch 21


Epoch 22/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 22, Training Total Loss: 23208.4414, Validation Total Loss: 3569.1727, Test Total Loss: 3568.3177
Training Regression Loss: 3997.5508, Validation Regression Loss: 3569.1727, Test Regression Loss: 3568.3177
Training Classification Loss: 48027.2259, Validation Classification Loss: 3569.1727, Test Classification Loss: 3568.3177
Training RMSE: 2573.5271, Validation RMSE: 2111.6727, Test RMSE: 2105.4470
Training PCC: -0.0020, Validation PCC: 0.0089, Test PCC: -0.0032
Checkpoint saved for epoch 22


Epoch 23/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 23, Training Total Loss: 28541.6996, Validation Total Loss: 9626.0687, Test Total Loss: 8256.6703
Training Regression Loss: 5754.0291, Validation Regression Loss: 9626.0687, Test Regression Loss: 8256.6703
Training Classification Loss: 56969.1755, Validation Classification Loss: 9626.0687, Test Classification Loss: 8256.6703
Training RMSE: 4457.5427, Validation RMSE: 9268.0867, Test RMSE: 7949.4885
Training PCC: 0.0013, Validation PCC: -0.0247, Test PCC: 0.0649
Checkpoint saved for epoch 23


Epoch 24/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 24, Training Total Loss: 29815.1789, Validation Total Loss: 24512.2969, Test Total Loss: 19619.2439
Training Regression Loss: 6195.2098, Validation Regression Loss: 24512.2969, Test Regression Loss: 19619.2439
Training Classification Loss: 59049.9215, Validation Classification Loss: 24512.2969, Test Classification Loss: 19619.2439
Training RMSE: 4663.7751, Validation RMSE: 19781.1990, Test RMSE: 15838.1809
Training PCC: 0.0035, Validation PCC: -0.0464, Test PCC: -0.0767
Checkpoint saved for epoch 24
Stopping early after 24 epochs
Total training time: 3550.57 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.4_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_0.8_wl100_ol75 with alpha: 0.8
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 21.5956, Validation Total Loss: 11.7919, Test Total Loss: 21.3284
Training Regression Loss: 19.2360, Validation Regression Loss: 11.7919, Test Regression Loss: 21.3284
Training Classification Loss: 2.9496, Validation Classification Loss: 11.7919, Test Classification Loss: 21.3284
Training RMSE: 18.6820, Validation RMSE: 11.4021, Test RMSE: 19.9204
Training PCC: 0.7965, Validation PCC: 0.9423, Test PCC: 0.6201
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 13.8877, Validation Total Loss: 11.0375, Test Total Loss: 19.1820
Training Regression Loss: 11.0617, Validation Regression Loss: 11.0375, Test Regression Loss: 19.1820
Training Classification Loss: 3.5324, Validation Classification Loss: 11.0375, Test Classification Loss: 19.1820
Training RMSE: 10.6692, Validation RMSE: 10.5843, Test RMSE: 18.0388
Training PCC: 0.9496, Validation PCC: 0.9518, Test PCC: 0.7159
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 1476.5458, Validation Total Loss: 484.5984, Test Total Loss: 484.0329
Training Regression Loss: 183.4820, Validation Regression Loss: 484.5984, Test Regression Loss: 484.0329
Training Classification Loss: 1616.3297, Validation Classification Loss: 484.5984, Test Classification Loss: 484.0329
Training RMSE: 159.5806, Validation RMSE: 383.5927, Test RMSE: 382.5917
Training PCC: 0.4066, Validation PCC: 0.1673, Test PCC: 0.2083
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 3551.9167, Validation Total Loss: 164.7041, Test Total Loss: 170.0298
Training Regression Loss: 434.2995, Validation Regression Loss: 164.7041, Test Regression Loss: 170.0298
Training Classification Loss: 3897.0214, Validation Classification Loss: 164.7041, Test Classification Loss: 170.0298
Training RMSE: 368.8171, Validation RMSE: 138.2272, Test RMSE: 141.6428
Training PCC: 0.1751, Validation PCC: 0.1066, Test PCC: 0.0990
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 5604.3941, Validation Total Loss: 777.7316, Test Total Loss: 783.1429
Training Regression Loss: 680.5434, Validation Regression Loss: 777.7316, Test Regression Loss: 783.1429
Training Classification Loss: 6154.8133, Validation Classification Loss: 777.7316, Test Classification Loss: 783.1429
Training RMSE: 564.8125, Validation RMSE: 660.6203, Test RMSE: 666.3455
Training PCC: 0.1039, Validation PCC: 0.0331, Test PCC: -0.0351
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 7326.0436, Validation Total Loss: 706.3219, Test Total Loss: 704.3396
Training Regression Loss: 913.9382, Validation Regression Loss: 706.3219, Test Regression Loss: 704.3396
Training Classification Loss: 8015.1317, Validation Classification Loss: 706.3219, Test Classification Loss: 704.3396
Training RMSE: 723.1296, Validation RMSE: 670.3413, Test RMSE: 669.0749
Training PCC: 0.0863, Validation PCC: 0.0915, Test PCC: 0.0261
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 9256.0235, Validation Total Loss: 1387.9614, Test Total Loss: 1386.0386
Training Regression Loss: 1199.7741, Validation Regression Loss: 1387.9614, Test Regression Loss: 1386.0386
Training Classification Loss: 10070.3117, Validation Classification Loss: 1387.9614, Test Classification Loss: 1386.0386
Training RMSE: 994.6549, Validation RMSE: 1132.3462, Test RMSE: 1125.7303
Training PCC: 0.0930, Validation PCC: 0.0987, Test PCC: 0.0786
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 11531.1153, Validation Total Loss: 1032.7619, Test Total Loss: 1033.6938
Training Regression Loss: 1437.8496, Validation Regression Loss: 1032.7619, Test Regression Loss: 1033.6938
Training Classification Loss: 12616.5819, Validation Classification Loss: 1032.7619, Test Classification Loss: 1033.6938
Training RMSE: 1196.4943, Validation RMSE: 848.1001, Test RMSE: 852.1799
Training PCC: 0.0321, Validation PCC: 0.0550, Test PCC: 0.0553
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 14224.9544, Validation Total Loss: 1997.0755, Test Total Loss: 1994.4629
Training Regression Loss: 1525.6552, Validation Regression Loss: 1997.0755, Test Regression Loss: 1994.4629
Training Classification Loss: 15874.1238, Validation Classification Loss: 1997.0755, Test Classification Loss: 1994.4629
Training RMSE: 1220.6073, Validation RMSE: 1902.3680, Test RMSE: 1900.7866
Training PCC: 0.0739, Validation PCC: 0.1348, Test PCC: 0.0662
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 16302.2687, Validation Total Loss: 1595.8947, Test Total Loss: 1589.2047
Training Regression Loss: 1886.6137, Validation Regression Loss: 1595.8947, Test Regression Loss: 1589.2047
Training Classification Loss: 18019.5685, Validation Classification Loss: 1595.8947, Test Classification Loss: 1589.2047
Training RMSE: 1278.6681, Validation RMSE: 964.2950, Test RMSE: 958.0045
Training PCC: 0.0711, Validation PCC: 0.0836, Test PCC: 0.0576
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 17533.0260, Validation Total Loss: 1958.0659, Test Total Loss: 1953.4934
Training Regression Loss: 2017.0032, Validation Regression Loss: 1958.0659, Test Regression Loss: 1953.4934
Training Classification Loss: 19395.0282, Validation Classification Loss: 1958.0659, Test Classification Loss: 1953.4934
Training RMSE: 1376.2200, Validation RMSE: 1189.4013, Test RMSE: 1182.5087
Training PCC: 0.0748, Validation PCC: 0.0873, Test PCC: 0.0556
Checkpoint saved for epoch 11


Epoch 12/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 12, Training Total Loss: 21452.8440, Validation Total Loss: 2011.5503, Test Total Loss: 2016.9110
Training Regression Loss: 2081.2955, Validation Regression Loss: 2011.5503, Test Regression Loss: 2016.9110
Training Classification Loss: 24214.4351, Validation Classification Loss: 2011.5503, Test Classification Loss: 2016.9110
Training RMSE: 1549.8575, Validation RMSE: 1617.7235, Test RMSE: 1618.5930
Training PCC: 0.0293, Validation PCC: -0.0265, Test PCC: -0.0328
Checkpoint saved for epoch 12
Stopping early after 12 epochs
Total training time: 1778.03 seconds
Finished training for TeacherModel_DomainInvariant_alpha_0.8_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_1_wl100_ol75 with alpha: 1
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 23.0853, Validation Total Loss: 11.6062, Test Total Loss: 21.5802
Training Regression Loss: 19.5380, Validation Regression Loss: 11.6062, Test Regression Loss: 21.5802
Training Classification Loss: 3.5472, Validation Classification Loss: 11.6062, Test Classification Loss: 21.5802
Training RMSE: 19.0150, Validation RMSE: 11.2712, Test RMSE: 20.1069
Training PCC: 0.7909, Validation PCC: 0.9412, Test PCC: 0.6303
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 1068.9351, Validation Total Loss: 300.6049, Test Total Loss: 302.5575
Training Regression Loss: 124.7082, Validation Regression Loss: 300.6049, Test Regression Loss: 302.5575
Training Classification Loss: 944.2269, Validation Classification Loss: 300.6049, Test Classification Loss: 302.5575
Training RMSE: 107.5233, Validation RMSE: 295.7907, Test RMSE: 296.1453
Training PCC: 0.6552, Validation PCC: 0.0102, Test PCC: 0.0282
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 8514.1969, Validation Total Loss: 1901.1268, Test Total Loss: 1899.7509
Training Regression Loss: 979.5924, Validation Regression Loss: 1901.1268, Test Regression Loss: 1899.7509
Training Classification Loss: 7534.6045, Validation Classification Loss: 1901.1268, Test Classification Loss: 1899.7509
Training RMSE: 791.0633, Validation RMSE: 1506.2502, Test RMSE: 1501.0613
Training PCC: -0.0026, Validation PCC: 0.0046, Test PCC: -0.0004
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 14748.0483, Validation Total Loss: 564.2820, Test Total Loss: 565.6429
Training Regression Loss: 1301.0997, Validation Regression Loss: 564.2820, Test Regression Loss: 565.6429
Training Classification Loss: 13446.9486, Validation Classification Loss: 564.2820, Test Classification Loss: 565.6429
Training RMSE: 1038.6325, Validation RMSE: 507.5677, Test RMSE: 506.0655
Training PCC: -0.0015, Validation PCC: 0.0027, Test PCC: 0.0046
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 21278.1431, Validation Total Loss: 3116.8538, Test Total Loss: 3119.2970
Training Regression Loss: 2168.4088, Validation Regression Loss: 3116.8538, Test Regression Loss: 3119.2970
Training Classification Loss: 19109.7343, Validation Classification Loss: 3116.8538, Test Classification Loss: 3119.2970
Training RMSE: 1766.1345, Validation RMSE: 2673.6135, Test RMSE: 2672.4147
Training PCC: -0.0024, Validation PCC: 0.0015, Test PCC: 0.0022
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 28836.0297, Validation Total Loss: 2059.7796, Test Total Loss: 2059.3760
Training Regression Loss: 3630.1534, Validation Regression Loss: 2059.7796, Test Regression Loss: 2059.3760
Training Classification Loss: 25205.8764, Validation Classification Loss: 2059.7796, Test Classification Loss: 2059.3760
Training RMSE: 2798.1509, Validation RMSE: 1246.8459, Test RMSE: 1246.9199
Training PCC: 0.0050, Validation PCC: 0.0018, Test PCC: -0.0076
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 36738.4995, Validation Total Loss: 5183.5795, Test Total Loss: 5188.9703
Training Regression Loss: 3199.2424, Validation Regression Loss: 5183.5795, Test Regression Loss: 5188.9703
Training Classification Loss: 33539.2571, Validation Classification Loss: 5183.5795, Test Classification Loss: 5188.9703
Training RMSE: 2336.5943, Validation RMSE: 3132.6539, Test RMSE: 3137.2624
Training PCC: -0.0014, Validation PCC: 0.0022, Test PCC: 0.0039
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 47414.7438, Validation Total Loss: 1742.2922, Test Total Loss: 1739.4447
Training Regression Loss: 5035.3607, Validation Regression Loss: 1742.2922, Test Regression Loss: 1739.4447
Training Classification Loss: 42379.3831, Validation Classification Loss: 1742.2922, Test Classification Loss: 1739.4447
Training RMSE: 3623.4820, Validation RMSE: 1582.2631, Test RMSE: 1577.1028
Training PCC: 0.0073, Validation PCC: -0.0019, Test PCC: -0.0114
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 49778.0186, Validation Total Loss: 3116.9664, Test Total Loss: 3118.2392
Training Regression Loss: 5307.7192, Validation Regression Loss: 3116.9664, Test Regression Loss: 3118.2392
Training Classification Loss: 44470.2996, Validation Classification Loss: 3116.9664, Test Classification Loss: 3118.2392
Training RMSE: 4000.6663, Validation RMSE: 2077.8206, Test RMSE: 2083.2292
Training PCC: 0.0024, Validation PCC: 0.0021, Test PCC: -0.0118
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 55043.4616, Validation Total Loss: 3469.2265, Test Total Loss: 3474.4665
Training Regression Loss: 5266.7366, Validation Regression Loss: 3469.2265, Test Regression Loss: 3474.4665
Training Classification Loss: 49776.7248, Validation Classification Loss: 3469.2265, Test Classification Loss: 3474.4665
Training RMSE: 3950.3561, Validation RMSE: 2044.3897, Test RMSE: 2044.5111
Training PCC: -0.0037, Validation PCC: 0.0004, Test PCC: -0.0040
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 69331.7399, Validation Total Loss: 5150.7257, Test Total Loss: 5150.2833
Training Regression Loss: 5380.0273, Validation Regression Loss: 5150.7257, Test Regression Loss: 5150.2833
Training Classification Loss: 63951.7126, Validation Classification Loss: 5150.7257, Test Classification Loss: 5150.2833
Training RMSE: 3707.0760, Validation RMSE: 3064.4469, Test RMSE: 3067.4579
Training PCC: 0.0020, Validation PCC: 0.0033, Test PCC: 0.0063
Checkpoint saved for epoch 11
Stopping early after 11 epochs
Total training time: 1648.96 seconds
Finished training for TeacherModel_DomainInvariant_alpha_1_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_2_wl100_ol75 with alpha: 2
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 386.4402, Validation Total Loss: 85.9289, Test Total Loss: 77.4090
Training Regression Loss: 42.9947, Validation Regression Loss: 85.9289, Test Regression Loss: 77.4090
Training Classification Loss: 171.7228, Validation Classification Loss: 85.9289, Test Classification Loss: 77.4090
Training RMSE: 41.2508, Validation RMSE: 80.0031, Test RMSE: 73.2522
Training PCC: 0.3762, Validation PCC: 0.2322, Test PCC: 0.3470
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 3209.7442, Validation Total Loss: 227.2665, Test Total Loss: 230.1127
Training Regression Loss: 188.8788, Validation Regression Loss: 227.2665, Test Regression Loss: 230.1127
Training Classification Loss: 1510.4327, Validation Classification Loss: 227.2665, Test Classification Loss: 230.1127
Training RMSE: 162.7107, Validation RMSE: 220.7257, Test RMSE: 220.8336
Training PCC: 0.0493, Validation PCC: 0.0239, Test PCC: 0.1126
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 8161.0840, Validation Total Loss: 1232.6169, Test Total Loss: 1237.0006
Training Regression Loss: 540.7884, Validation Regression Loss: 1232.6169, Test Regression Loss: 1237.0006
Training Classification Loss: 3810.1478, Validation Classification Loss: 1232.6169, Test Classification Loss: 1237.0006
Training RMSE: 432.1012, Validation RMSE: 744.4074, Test RMSE: 748.7883
Training PCC: 0.0024, Validation PCC: 0.0089, Test PCC: 0.0175
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 15585.8441, Validation Total Loss: 646.7623, Test Total Loss: 640.5292
Training Regression Loss: 949.2013, Validation Regression Loss: 646.7623, Test Regression Loss: 640.5292
Training Classification Loss: 7318.3214, Validation Classification Loss: 646.7623, Test Classification Loss: 640.5292
Training RMSE: 698.7584, Validation RMSE: 430.4098, Test RMSE: 422.3279
Training PCC: 0.0013, Validation PCC: 0.0070, Test PCC: 0.0125
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 25429.2041, Validation Total Loss: 585.6599, Test Total Loss: 585.3585
Training Regression Loss: 1462.5514, Validation Regression Loss: 585.6599, Test Regression Loss: 585.3585
Training Classification Loss: 11983.3263, Validation Classification Loss: 585.6599, Test Classification Loss: 585.3585
Training RMSE: 1161.1949, Validation RMSE: 453.0431, Test RMSE: 451.4570
Training PCC: -0.0002, Validation PCC: 0.0017, Test PCC: -0.0022
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 36438.3529, Validation Total Loss: 428.5841, Test Total Loss: 425.7408
Training Regression Loss: 2056.4791, Validation Regression Loss: 428.5841, Test Regression Loss: 425.7408
Training Classification Loss: 17190.9368, Validation Classification Loss: 428.5841, Test Classification Loss: 425.7408
Training RMSE: 1750.3207, Validation RMSE: 410.0835, Test RMSE: 408.8946
Training PCC: 0.0013, Validation PCC: 0.0030, Test PCC: 0.0056
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 41871.4770, Validation Total Loss: 1514.7720, Test Total Loss: 1516.7581
Training Regression Loss: 2416.3935, Validation Regression Loss: 1514.7720, Test Regression Loss: 1516.7581
Training Classification Loss: 19727.5417, Validation Classification Loss: 1514.7720, Test Classification Loss: 1516.7581
Training RMSE: 1953.1055, Validation RMSE: 1257.8616, Test RMSE: 1256.5155
Training PCC: 0.0013, Validation PCC: -0.0068, Test PCC: 0.0144
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 52850.9589, Validation Total Loss: 2157.8787, Test Total Loss: 2157.2458
Training Regression Loss: 3128.1851, Validation Regression Loss: 2157.8787, Test Regression Loss: 2157.2458
Training Classification Loss: 24861.3869, Validation Classification Loss: 2157.8787, Test Classification Loss: 2157.2458
Training RMSE: 2478.7075, Validation RMSE: 1346.1052, Test RMSE: 1345.8299
Training PCC: 0.0015, Validation PCC: 0.0014, Test PCC: 0.0024
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 57877.0032, Validation Total Loss: 1595.2377, Test Total Loss: 1601.1382
Training Regression Loss: 3085.5066, Validation Regression Loss: 1595.2377, Test Regression Loss: 1601.1382
Training Classification Loss: 27395.7483, Validation Classification Loss: 1595.2377, Test Classification Loss: 1601.1382
Training RMSE: 2316.3361, Validation RMSE: 955.4292, Test RMSE: 957.6393
Training PCC: 0.0076, Validation PCC: 0.0007, Test PCC: -0.0202
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 71679.6110, Validation Total Loss: 2302.5908, Test Total Loss: 2298.5816
Training Regression Loss: 3161.5238, Validation Regression Loss: 2302.5908, Test Regression Loss: 2298.5816
Training Classification Loss: 34259.0436, Validation Classification Loss: 2302.5908, Test Classification Loss: 2298.5816
Training RMSE: 2319.2610, Validation RMSE: 1524.5478, Test RMSE: 1524.9810
Training PCC: -0.0013, Validation PCC: 0.0056, Test PCC: 0.0329
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 79005.6261, Validation Total Loss: 2591.7015, Test Total Loss: 2594.9762
Training Regression Loss: 4140.6778, Validation Regression Loss: 2591.7015, Test Regression Loss: 2594.9762
Training Classification Loss: 37432.4741, Validation Classification Loss: 2591.7015, Test Classification Loss: 2594.9762
Training RMSE: 3215.4398, Validation RMSE: 2368.0419, Test RMSE: 2373.6078
Training PCC: 0.0020, Validation PCC: -0.0000, Test PCC: 0.0215
Checkpoint saved for epoch 11
Stopping early after 11 epochs
Total training time: 1646.66 seconds
Finished training for TeacherModel_DomainInvariant_alpha_2_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_4_wl100_ol75 with alpha: 4
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 562.5370, Validation Total Loss: 73.5681, Test Total Loss: 76.6259
Training Regression Loss: 46.6761, Validation Regression Loss: 73.5681, Test Regression Loss: 76.6259
Training Classification Loss: 128.9652, Validation Classification Loss: 73.5681, Test Classification Loss: 76.6259
Training RMSE: 43.9123, Validation RMSE: 71.4168, Test RMSE: 74.2408
Training PCC: 0.0297, Validation PCC: 0.0033, Test PCC: 0.0038
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 8497.0457, Validation Total Loss: 271.8379, Test Total Loss: 268.9716
Training Regression Loss: 233.3758, Validation Regression Loss: 271.8379, Test Regression Loss: 268.9716
Training Classification Loss: 2065.9175, Validation Classification Loss: 271.8379, Test Classification Loss: 268.9716
Training RMSE: 202.7562, Validation RMSE: 235.8578, Test RMSE: 235.1727
Training PCC: -0.0002, Validation PCC: -0.0018, Test PCC: -0.0018
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 25344.3676, Validation Total Loss: 702.7659, Test Total Loss: 703.7999
Training Regression Loss: 618.4338, Validation Regression Loss: 702.7659, Test Regression Loss: 703.7999
Training Classification Loss: 6181.4834, Validation Classification Loss: 702.7659, Test Classification Loss: 703.7999
Training RMSE: 498.5074, Validation RMSE: 503.0507, Test RMSE: 507.1367
Training PCC: -0.0049, Validation PCC: 0.0011, Test PCC: 0.0011
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 44389.9801, Validation Total Loss: 1463.7567, Test Total Loss: 1470.4822
Training Regression Loss: 1044.2476, Validation Regression Loss: 1463.7567, Test Regression Loss: 1470.4822
Training Classification Loss: 10836.4331, Validation Classification Loss: 1463.7567, Test Classification Loss: 1470.4822
Training RMSE: 845.3831, Validation RMSE: 1309.9773, Test RMSE: 1315.4745
Training PCC: 0.0003, Validation PCC: 0.0002, Test PCC: 0.0069
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 63108.8405, Validation Total Loss: 1334.1256, Test Total Loss: 1337.8837
Training Regression Loss: 1785.7624, Validation Regression Loss: 1334.1256, Test Regression Loss: 1337.8837
Training Classification Loss: 15330.7695, Validation Classification Loss: 1334.1256, Test Classification Loss: 1337.8837
Training RMSE: 1456.6407, Validation RMSE: 847.6042, Test RMSE: 846.0442
Training PCC: -0.0012, Validation PCC: -0.0008, Test PCC: -0.0053
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 89130.8239, Validation Total Loss: 1677.5506, Test Total Loss: 1682.5321
Training Regression Loss: 1781.7788, Validation Regression Loss: 1677.5506, Test Regression Loss: 1682.5321
Training Classification Loss: 21837.2613, Validation Classification Loss: 1677.5506, Test Classification Loss: 1682.5321
Training RMSE: 1240.0642, Validation RMSE: 1450.3433, Test RMSE: 1455.9930
Training PCC: -0.0042, Validation PCC: -0.0021, Test PCC: 0.0028
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 114593.7198, Validation Total Loss: 2165.9458, Test Total Loss: 2170.4799
Training Regression Loss: 3118.3447, Validation Regression Loss: 2165.9458, Test Regression Loss: 2170.4799
Training Classification Loss: 27868.8438, Validation Classification Loss: 2165.9458, Test Classification Loss: 2170.4799
Training RMSE: 1984.3996, Validation RMSE: 1794.4679, Test RMSE: 1795.2264
Training PCC: 0.0035, Validation PCC: 0.0148, Test PCC: -0.0018
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 137577.2685, Validation Total Loss: 4383.9097, Test Total Loss: 4385.1875
Training Regression Loss: 5043.9171, Validation Regression Loss: 4383.9097, Test Regression Loss: 4385.1875
Training Classification Loss: 33133.3379, Validation Classification Loss: 4383.9097, Test Classification Loss: 4385.1875
Training RMSE: 3804.9376, Validation RMSE: 4088.9325, Test RMSE: 4087.8753
Training PCC: -0.0001, Validation PCC: 0.0048, Test PCC: -0.0079
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 167373.3733, Validation Total Loss: 3681.1093, Test Total Loss: 3684.0292
Training Regression Loss: 3445.9341, Validation Regression Loss: 3681.1093, Test Regression Loss: 3684.0292
Training Classification Loss: 40981.8598, Validation Classification Loss: 3681.1093, Test Classification Loss: 3684.0292
Training RMSE: 2477.1043, Validation RMSE: 3239.7380, Test RMSE: 3245.0284
Training PCC: 0.0011, Validation PCC: -0.0088, Test PCC: -0.0004
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 183822.8261, Validation Total Loss: 1224.7477, Test Total Loss: 1228.3905
Training Regression Loss: 4648.0215, Validation Regression Loss: 1224.7477, Test Regression Loss: 1228.3905
Training Classification Loss: 44793.7013, Validation Classification Loss: 1224.7477, Test Classification Loss: 1228.3905
Training RMSE: 3558.4513, Validation RMSE: 795.5025, Test RMSE: 794.1025
Training PCC: 0.0043, Validation PCC: 0.0105, Test PCC: -0.0017
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 228578.5906, Validation Total Loss: 3874.9325, Test Total Loss: 3870.8957
Training Regression Loss: 4949.4876, Validation Regression Loss: 3874.9325, Test Regression Loss: 3870.8957
Training Classification Loss: 55907.2758, Validation Classification Loss: 3874.9325, Test Classification Loss: 3870.8957
Training RMSE: 3829.1210, Validation RMSE: 2313.9421, Test RMSE: 2313.2921
Training PCC: -0.0023, Validation PCC: 0.0067, Test PCC: -0.0067
Checkpoint saved for epoch 11
Stopping early after 11 epochs
Total training time: 1634.44 seconds
Finished training for TeacherModel_DomainInvariant_alpha_4_wl100_ol75.
Running model: TeacherModel_DomainInvariant_alpha_8_wl100_ol75 with alpha: 8
Starting from scratch.


Epoch 1/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 1, Training Total Loss: 1305.1802, Validation Total Loss: 43.0915, Test Total Loss: 47.2854
Training Regression Loss: 46.2849, Validation Regression Loss: 43.0915, Test Regression Loss: 47.2854
Training Classification Loss: 157.3619, Validation Classification Loss: 43.0915, Test Classification Loss: 47.2854
Training RMSE: 43.7896, Validation RMSE: 40.7862, Test RMSE: 43.5335
Training PCC: -0.0028, Validation PCC: 0.0006, Test PCC: -0.0007
Checkpoint saved for epoch 1


Epoch 2/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 2, Training Total Loss: 15844.2296, Validation Total Loss: 756.2087, Test Total Loss: 759.3465
Training Regression Loss: 219.0974, Validation Regression Loss: 756.2087, Test Regression Loss: 759.3465
Training Classification Loss: 1953.1415, Validation Classification Loss: 756.2087, Test Classification Loss: 759.3465
Training RMSE: 187.8279, Validation RMSE: 533.4611, Test RMSE: 530.7109
Training PCC: 0.0014, Validation PCC: -0.0019, Test PCC: 0.0054
Checkpoint saved for epoch 2


Epoch 3/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 3, Training Total Loss: 46255.5674, Validation Total Loss: 934.6620, Test Total Loss: 941.3916
Training Regression Loss: 759.4389, Validation Regression Loss: 934.6620, Test Regression Loss: 941.3916
Training Classification Loss: 5687.0161, Validation Classification Loss: 934.6620, Test Classification Loss: 941.3916
Training RMSE: 644.9594, Validation RMSE: 749.2873, Test RMSE: 751.7044
Training PCC: -0.0047, Validation PCC: 0.0016, Test PCC: 0.0002
Checkpoint saved for epoch 3


Epoch 4/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 4, Training Total Loss: 96519.7866, Validation Total Loss: 1800.2333, Test Total Loss: 1807.5625
Training Regression Loss: 1486.1266, Validation Regression Loss: 1800.2333, Test Regression Loss: 1807.5625
Training Classification Loss: 11879.2075, Validation Classification Loss: 1800.2333, Test Classification Loss: 1807.5625
Training RMSE: 1166.8902, Validation RMSE: 1490.9228, Test RMSE: 1495.8619
Training PCC: -0.0003, Validation PCC: -0.0102, Test PCC: 0.0099
Checkpoint saved for epoch 4


Epoch 5/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 5, Training Total Loss: 150126.7251, Validation Total Loss: 1294.8973, Test Total Loss: 1299.6316
Training Regression Loss: 1642.5486, Validation Regression Loss: 1294.8973, Test Regression Loss: 1299.6316
Training Classification Loss: 18560.5221, Validation Classification Loss: 1294.8973, Test Classification Loss: 1299.6316
Training RMSE: 1282.1694, Validation RMSE: 827.6532, Test RMSE: 831.8520
Training PCC: 0.0034, Validation PCC: 0.0138, Test PCC: 0.0364
Checkpoint saved for epoch 5


Epoch 6/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 6, Training Total Loss: 212302.9628, Validation Total Loss: 1053.6000, Test Total Loss: 1052.7287
Training Regression Loss: 1866.6746, Validation Regression Loss: 1053.6000, Test Regression Loss: 1052.7287
Training Classification Loss: 26304.5360, Validation Classification Loss: 1053.6000, Test Classification Loss: 1052.7287
Training RMSE: 1218.4608, Validation RMSE: 720.3582, Test RMSE: 714.0852
Training PCC: -0.0070, Validation PCC: -0.0250, Test PCC: -0.0244
Checkpoint saved for epoch 6


Epoch 7/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 7, Training Total Loss: 270182.2460, Validation Total Loss: 10603.5657, Test Total Loss: 10605.3601
Training Regression Loss: 3520.6821, Validation Regression Loss: 10603.5657, Test Regression Loss: 10605.3601
Training Classification Loss: 33332.6955, Validation Classification Loss: 10603.5657, Test Classification Loss: 10605.3601
Training RMSE: 2677.1559, Validation RMSE: 7644.2074, Test RMSE: 7650.1740
Training PCC: -0.0018, Validation PCC: -0.0221, Test PCC: -0.0339
Checkpoint saved for epoch 7


Epoch 8/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 8, Training Total Loss: 295564.0506, Validation Total Loss: 3268.4054, Test Total Loss: 3261.9047
Training Regression Loss: 4695.6750, Validation Regression Loss: 3268.4054, Test Regression Loss: 3261.9047
Training Classification Loss: 36358.5470, Validation Classification Loss: 3268.4054, Test Classification Loss: 3261.9047
Training RMSE: 3883.1851, Validation RMSE: 3075.4730, Test RMSE: 3069.5392
Training PCC: 0.0002, Validation PCC: -0.0102, Test PCC: -0.0096
Checkpoint saved for epoch 8


Epoch 9/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 9, Training Total Loss: 358775.9338, Validation Total Loss: 3192.0979, Test Total Loss: 3196.6396
Training Regression Loss: 3553.0644, Validation Regression Loss: 3192.0979, Test Regression Loss: 3196.6396
Training Classification Loss: 44402.8587, Validation Classification Loss: 3192.0979, Test Classification Loss: 3196.6396
Training RMSE: 2411.4980, Validation RMSE: 1983.6934, Test RMSE: 1988.9080
Training PCC: 0.0007, Validation PCC: -0.0076, Test PCC: 0.0247
Checkpoint saved for epoch 9


Epoch 10/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 10, Training Total Loss: 426810.1327, Validation Total Loss: 3667.6543, Test Total Loss: 3667.8897
Training Regression Loss: 4396.0631, Validation Regression Loss: 3667.6543, Test Regression Loss: 3667.8897
Training Classification Loss: 52801.7587, Validation Classification Loss: 3667.6543, Test Classification Loss: 3667.8897
Training RMSE: 3084.1037, Validation RMSE: 2261.4525, Test RMSE: 2260.3864
Training PCC: -0.0029, Validation PCC: 0.0220, Test PCC: 0.0110
Checkpoint saved for epoch 10


Epoch 11/50 Training:   0%|          | 0/327 [00:00<?, ?it/s]

Epoch: 11, Training Total Loss: 444482.0074, Validation Total Loss: 3297.5413, Test Total Loss: 3292.2595
Training Regression Loss: 4909.7258, Validation Regression Loss: 3297.5413, Test Regression Loss: 3292.2595
Training Classification Loss: 54946.5351, Validation Classification Loss: 3297.5413, Test Classification Loss: 3292.2595
Training RMSE: 3342.0212, Validation RMSE: 2362.8825, Test RMSE: 2361.7018
Training PCC: 0.0069, Validation PCC: 0.0047, Test PCC: 0.0291
Checkpoint saved for epoch 11
Stopping early after 11 epochs
Total training time: 1661.31 seconds
Finished training for TeacherModel_DomainInvariant_alpha_8_wl100_ol75.


In [24]:
# @title Plotting
import matplotlib.pyplot as plt
def load_most_recent_checkpoint(folder):
    """
    Loads the most recent checkpoint from the given folder and returns it.
    """
    checkpoint_files = [f for f in os.listdir(folder) if f.endswith('.pth')]
    if not checkpoint_files:
        print(f"No checkpoints found in {folder}")
        return None

    # Sort the checkpoint files by epoch (assuming filenames include the epoch number)
    checkpoint_files.sort(key=lambda f: int(f.split('_epoch_')[-1].split('.')[0]))

    # Load the most recent checkpoint
    latest_checkpoint = checkpoint_files[-1]
    checkpoint_path = os.path.join(folder, latest_checkpoint)
    checkpoint = torch.load(checkpoint_path)

    print(f"Loaded checkpoint: {latest_checkpoint} from {folder}")

    # Return the loaded checkpoint, which contains model state, optimizer state, and history
    return checkpoint

def plot_metrics_subplots(aggregated_metrics, folder):
    """Plots train/val/test losses, PCCs, and RMSEs in subplots."""
    epochs = list(range(1, len(aggregated_metrics['train_losses']) + 1))
    line_styles = ['-', '--', '-.', ':']  # Define different line styles for channels

    # New color palette
    train_color = '#17becf'  # Teal blue for Train
    val_color = '#bcbd22'    # Mustard yellow for Val
    test_color = '#e377c2'   # Pastel magenta for Test

    # Create subplots for losses, PCCs, and RMSEs
    fig, axes = plt.subplots(3, 2, figsize=(12, 14))
    fig.suptitle(f'Metrics over Epochs for {folder}', fontsize=16)

    # Plot Losses (subplot 1, 1)
    axes[0, 0].plot(epochs, aggregated_metrics['train_losses'], label='Train Loss', color=train_color, linestyle=line_styles[0])
    axes[0, 0].plot(epochs, aggregated_metrics['val_losses'], label='Val Loss', color=val_color, linestyle=line_styles[0])
    axes[0, 0].plot(epochs, aggregated_metrics['test_losses'], label='Test Loss', color=test_color, linestyle=line_styles[0])
    axes[0, 0].set_title('Losses over Epochs')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()

    # Plot Average PCCs (already averaged, so no need to take mean again) (subplot 1, 2)
    axes[0, 1].plot(epochs, aggregated_metrics['train_pccs'], label='Avg Train PCC', color=train_color, linestyle='-')
    axes[0, 1].plot(epochs, aggregated_metrics['val_pccs'], label='Avg Val PCC', color=val_color, linestyle='-')
    axes[0, 1].plot(epochs, aggregated_metrics['test_pccs'], label='Avg Test PCC', color=test_color, linestyle='-')
    axes[0, 1].set_title('Average PCC over Epochs')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('PCC')
    axes[0, 1].legend()

    # Plot PCCs for each channel (subplot 2, 1)
    for i, style in enumerate(line_styles):
        if len(aggregated_metrics['train_pccs_channelwise'][0]) > i:  # Ensure enough channels exist
            axes[1, 0].plot(epochs, [pcc[i] for pcc in aggregated_metrics['train_pccs_channelwise']], label=f'Train PCC (Ch {i+1})', color=train_color, linestyle=style)
            axes[1, 0].plot(epochs, [pcc[i] for pcc in aggregated_metrics['val_pccs_channelwise']], label=f'Val PCC (Ch {i+1})', color=val_color, linestyle=style)
            axes[1, 0].plot(epochs, [pcc[i] for pcc in aggregated_metrics['test_pccs_channelwise']], label=f'Test PCC (Ch {i+1})', color=test_color, linestyle=style)
    axes[1, 0].set_title('PCCs over Epochs (Per Channel)')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('PCC')
    axes[1, 0].legend()

    # Plot Average RMSEs (already averaged, so no need to take mean again) (subplot 2, 2)
    axes[1, 1].plot(epochs, aggregated_metrics['train_rmses'], label='Avg Train RMSE', color=train_color, linestyle='-')
    axes[1, 1].plot(epochs, aggregated_metrics['val_rmses'], label='Avg Val RMSE', color=val_color, linestyle='-')
    axes[1, 1].plot(epochs, aggregated_metrics['test_rmses'], label='Avg Test RMSE', color=test_color, linestyle='-')
    axes[1, 1].set_title('Average RMSE over Epochs')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('RMSE')
    axes[1, 1].legend()

    # Plot RMSEs for each channel (subplot 3, 1)
    for i, style in enumerate(line_styles):
        if len(aggregated_metrics['train_rmses_channelwise'][0]) > i:
            axes[2, 0].plot(epochs, [rmse[i] for rmse in aggregated_metrics['train_rmses_channelwise']], label=f'Train RMSE (Ch {i+1})', color=train_color, linestyle=style)
            axes[2, 0].plot(epochs, [rmse[i] for rmse in aggregated_metrics['val_rmses_channelwise']], label=f'Val RMSE (Ch {i+1})', color=val_color, linestyle=style)
            axes[2, 0].plot(epochs, [rmse[i] for rmse in aggregated_metrics['test_rmses_channelwise']], label=f'Test RMSE (Ch {i+1})', color=test_color, linestyle=style)
    axes[2, 0].set_title('RMSEs over Epochs (Per Channel)')
    axes[2, 0].set_xlabel('Epoch')
    axes[2, 0].set_ylabel('RMSE')
    axes[2, 0].legend()

    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to make room for the suptitle
    plt.show()

def plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict):
    """
    Plots two bar graphs comparing the best RMSE and PCC for each model, showing corresponding PCC/RMSE from the same epoch,
    and plots average test RMSEs across epochs for each model.
    """

    # Dictionary to hold the best RMSE and PCC along with their associated metrics at the same epoch
    best_metrics = {}

    # Iterate through the aggregated metrics and extract the best RMSE and its corresponding PCC, and vice versa
    for model_name, metrics in aggregated_metrics_dict.items():
        # Extract pre-averaged RMSEs and PCCs for the test data
        test_rmses = metrics['test_rmses']  # Pre-averaged RMSE values
        test_pccs = metrics['test_pccs']    # Pre-averaged PCC values
        val_losses = metrics['val_losses']  # Pre-averaged validation losses

        # Find the epoch index for the best RMSE and the corresponding PCC at that epoch
        best_rmse_epoch = np.argmin(val_losses)
        best_rmse = test_rmses[best_rmse_epoch]
        corresponding_pcc_for_best_rmse = test_pccs[best_rmse_epoch]

        # Find the epoch index for the best PCC and the corresponding RMSE at that epoch
        best_pcc_epoch = np.argmax(test_pccs)
        best_pcc = test_pccs[best_pcc_epoch]
        corresponding_rmse_for_best_pcc = test_rmses[best_pcc_epoch]

        # Store in a dictionary for plotting
        best_metrics[model_name] = {
            'best_rmse': best_rmse,
            'corresponding_pcc_for_best_rmse': corresponding_pcc_for_best_rmse,
            'best_pcc': best_pcc,
            'corresponding_rmse_for_best_pcc': corresponding_rmse_for_best_pcc
        }

    # Convert to lists for sorting
    models = list(best_metrics.keys())
    best_rmses = [best_metrics[model]['best_rmse'] for model in models]
    corresponding_pccs = [best_metrics[model]['corresponding_pcc_for_best_rmse'] for model in models]
    best_pccs = [best_metrics[model]['best_pcc'] for model in models]
    corresponding_rmses = [best_metrics[model]['corresponding_rmse_for_best_pcc'] for model in models]

    # Sort by RMSE and PCC
    sorted_by_rmse = sorted(zip(models, best_rmses, corresponding_pccs), key=lambda x: x[1], reverse=False)
    sorted_by_pcc = sorted(zip(models, best_pccs, corresponding_rmses), key=lambda x: x[1], reverse=True)

    # Unpack sorted values
    sorted_models_rmse, sorted_rmses, sorted_pccs_for_rmse = zip(*sorted_by_rmse)
    sorted_models_pcc, sorted_pccs, sorted_rmses_for_pcc = zip(*sorted_by_pcc)

    # Ensure consistent colors across both graphs
    colors = sns.color_palette("husl", len(models))
    model_colors_rmse = {model: colors[i] for i, model in enumerate(sorted_models_rmse)}
    model_colors_pcc = {model: model_colors_rmse[model] for model in sorted_models_pcc}  # Keep the colors consistent

    # Plot RMSE Bar Graph (With Corresponding PCC values)
    plt.figure(figsize=(10, 6))
    bars_rmse = plt.bar(sorted_models_rmse, sorted_rmses, color=[model_colors_rmse[model] for model in sorted_models_rmse])
    plt.title('Best RMSE per Model (Sorted)')
    plt.xlabel('Model')
    plt.ylabel('Best RMSE')
    plt.ylim(0, 25)  # Set an appropriate y-limit for RMSEs
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    # Add the actual RMSE values and corresponding PCC on top of the bars
    for i, bar in enumerate(bars_rmse):
        yval_rmse = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval_rmse + 0.5, f'{yval_rmse:.2f}', ha='center', va='bottom')
        plt.text(bar.get_x() + bar.get_width()/2, yval_rmse - 2, f'PCC: {sorted_pccs_for_rmse[i]:.2f}', ha='center', va='top', color='black')

    plt.show()

    # Plot PCC Bar Graph (With Corresponding RMSE values)
    plt.figure(figsize=(10, 6))
    bars_pcc = plt.bar(sorted_models_pcc, sorted_pccs, color=[model_colors_pcc[model] for model in sorted_models_pcc])
    plt.title('Best PCC per Model (Sorted)')
    plt.xlabel('Model')
    plt.ylabel('Best PCC')
    plt.ylim(0, 1)  # Set an appropriate y-limit for PCCs
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    # Add the actual PCC values and corresponding RMSE on top of the bars
    for i, bar in enumerate(bars_pcc):
        yval_pcc = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval_pcc + 0.02, f'{yval_pcc:.2f}', ha='center', va='bottom')
        plt.text(bar.get_x() + bar.get_width()/2, yval_pcc - 0.08, f'RMSE: {sorted_rmses_for_pcc[i]:.2f}', ha='center', va='top', color='black')

    plt.show()

    # Plot the average test RMSEs across epochs for each model
    plt.figure(figsize=(10, 6))
    for i, model in enumerate(models):
        avg_test_rmses = aggregated_metrics_dict[model]['test_rmses']  # Directly plot pre-averaged RMSEs
        plt.plot(avg_test_rmses, label=f'{model} Average Test RMSE', color=colors[i])

    plt.title('Average Test RMSE per Epoch for All Models')
    plt.xlabel('Epoch')
    plt.ylabel('Average RMSE')
    plt.legend()
    plt.tight_layout()
    plt.show()


# Initialize an empty dictionary to hold metrics for all models
aggregated_metrics_dict = {}

# Loop over each model and load the checkpoint
for model_name, model_config in model_configs.items():
    # Construct the folder path based on the filename
    folder = f"/content/MyDrive/MyDrive/models/{model_name}"
    checkpoint = load_most_recent_checkpoint(folder)

    if checkpoint:
        aggregated_metrics = checkpoint['history']

        # Store the metrics for comparison later
        aggregated_metrics_dict[model_name] = {
            'test_rmses': aggregated_metrics['test_rmses'],
            'test_pccs': aggregated_metrics['test_pccs'],
            'train_losses': aggregated_metrics['train_losses'],
            'val_losses': aggregated_metrics['val_losses'],
        }

        # Plot metrics for this model
        plot_metrics_subplots(aggregated_metrics, model_name)

# After collecting all model metrics, plot the comparison bar graphs
if aggregated_metrics_dict:
    plot_comparison_bar_graphs_with_limits(aggregated_metrics_dict)


Output hidden; open in https://colab.research.google.com to view.