In [None]:


# Mount Google Drive at /content/MyDrive (later cells read and write under
# /content/MyDrive/MyDrive/...) and switch seaborn to its compact "paper" context.
from google.colab import drive
drive.mount('/content/MyDrive')

import seaborn as sns
sns.set_theme(context="paper")

Mounted at /content/MyDrive


In [None]:
import os
import torch
import joblib
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from scipy.io import wavfile
from sklearn.metrics import mean_squared_error
import statistics
from scipy.signal import savgol_filter, butter, filtfilt

class Config:
    """Experiment settings: hyper-parameters, channel lists, dataset paths and windowing.

    Every setting listed in ``__init__`` can be overridden by a keyword argument of the
    same name; unknown keywords are ignored. ``dataset_name`` is derived from the
    windowing and subject settings and names the sharded dataset folder.
    """

    def __init__(self, **kwargs):
        # Built per call so list defaults are never shared between instances.
        defaults = {
            # optimisation
            'batch_size': 64,
            'epochs': 50,
            'lr': 0.001,
            # channel selections (column names)
            'channels_imu_acc': [],
            'channels_imu_acc_test': [],
            'channels_imu_gyr_test': [],
            'channels_imu_gyr': [],
            'channels_joints': [],
            'channels_emg': [],
            'seed': 42,
            # dataset locations
            'data_folder_name': 'default_data_folder_name',
            'dataset_root': 'default_dataset_root',
            'dataset_train_name': 'train',
            'dataset_test_name': 'test',
            # windowing
            'window_length': 100,
            'window_overlap': 0,
            # per-modality transform pipelines
            'imu_transforms': [],
            'joint_transforms': [],
            'emg_transforms': [],
            'input_format': 'csv',
            'train_subjects': [],
            'test_subjects': [],
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))

        self.dataset_name = self.generate_dataset_name()

    def generate_dataset_name(self):
        """Folder name encoding window length, overlap and the train/test subject lists."""
        parts = (
            f"dataset_wl{self.window_length}",
            f"ol{self.window_overlap}",
            f"train{self.train_subjects}",
            f"test{self.test_subjects}",
        )
        return "_".join(parts)

In [None]:
class ImuJointPairDataset(Dataset):
    """Dataset over window files listed in ``<dataset_root>/<dataset_name>/{train,test}_info.csv``.

    The first column of the info CSV is a file name inside the split folder. ``__getitem__``
    loads that file, extracts the configured accelerometer / gyroscope / joint / EMG channels,
    applies the configured transforms and re-windows the result.

    Args:
        config: ``Config``-like object providing channel lists, dataset paths,
            ``window_length`` / ``window_overlap``, transforms and ``input_format``.
        split: ``'train'`` reads ``train_info.csv``; any other value reads ``test_info.csv``.
    """

    def __init__(self, config, split='train'):
        self.config = config
        self.split = split
        self.input_format = config.input_format
        self.channels_imu_acc = config.channels_imu_acc
        self.channels_imu_acc_test = config.channels_imu_acc_test
        self.channels_imu_gyr = config.channels_imu_gyr
        self.channels_imu_gyr_test = config.channels_imu_gyr_test
        self.channels_joints = config.channels_joints
        self.channels_emg = config.channels_emg

        dataset_name = self.config.dataset_name
        self.root_dir_train = os.path.join(self.config.dataset_root, dataset_name, self.config.dataset_train_name)
        self.root_dir_test = os.path.join(self.config.dataset_root, dataset_name, self.config.dataset_test_name)

        train_info_path = os.path.join(self.config.dataset_root, dataset_name, "train_info.csv")
        test_info_path = os.path.join(self.config.dataset_root, dataset_name, "test_info.csv")
        self.data = pd.read_csv(train_info_path) if split == 'train' else pd.read_csv(test_info_path)

        # NOTE(review): the scaler is loaded when present but never applied in
        # _extract_and_transform — confirm whether scaling was intended.
        # joblib.load unpickles the file; only load scalers you created yourself.
        self.scaler_save_path = os.path.join(self.config.dataset_root, dataset_name, "scaler.pkl")
        self.scaler = joblib.load(self.scaler_save_path) if os.path.exists(self.scaler_save_path) else None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(acc, gyr, joints, emg)`` numpy arrays for sample ``idx``.

        Each array holds ``k * window_length`` rows (``k`` full windows found in the file)
        and one column per configured channel.

        Raises:
            ValueError: unsupported ``input_format``, or the file holds fewer rows than
                one window.
        """
        root_dir = self.root_dir_train if self.split == "train" else self.root_dir_test
        file_path = os.path.join(root_dir, self.data.iloc[idx, 0])

        if self.input_format == "wav":
            combined_data, _ = get_data_from_wav_file(file_path)
        elif self.input_format == "csv":
            combined_data = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported input format: {}".format(self.input_format))

        imu_data_acc, imu_data_gyr, joint_data, emg_data = self._extract_and_transform(combined_data)
        windows = self._apply_windowing(imu_data_acc, imu_data_gyr, joint_data, emg_data,
                                        self.config.window_length, self.config.window_overlap)
        if not windows:
            # Without this check np.concatenate fails with the opaque "need at least one array".
            raise ValueError(
                f"Sample {file_path!r} has {imu_data_acc.shape[0]} rows, fewer than "
                f"window_length={self.config.window_length}"
            )

        # np.concatenate turns the per-window torch tensors back into numpy arrays.
        acc_concat, gyr_concat, joint_concat, emg_concat = (
            np.concatenate([window[k] for window in windows], axis=0) for k in range(4)
        )
        return acc_concat, gyr_concat, joint_concat, emg_concat

    def _extract_and_transform(self, combined_data):
        """Extract the four channel groups and run each through its transform list.

        Returns four ``float32`` torch tensors (acc, gyr, joints, emg).
        """
        groups = [
            self._extract_channels(combined_data, channels)
            for channels in (self.channels_imu_acc, self.channels_imu_gyr,
                             self.channels_joints, self.channels_emg)
        ]
        # Concatenate then split so all groups share one common dtype before the
        # transforms run (no scaling is applied in between).
        combined = np.concatenate(groups, axis=1)
        split_points = np.cumsum([group.shape[1] for group in groups])[:-1]
        imu_data_acc, imu_data_gyr, joint_data, emg_data = np.split(combined, split_points, axis=1)

        imu_data_acc = self.apply_transforms(imu_data_acc, self.config.imu_transforms)
        imu_data_gyr = self.apply_transforms(imu_data_gyr, self.config.imu_transforms)
        joint_data = self.apply_transforms(joint_data, self.config.joint_transforms)
        emg_data = self.apply_transforms(emg_data, self.config.emg_transforms)

        return imu_data_acc, imu_data_gyr, joint_data, emg_data

    def _extract_channels(self, combined_data, channels):
        """Select ``channels`` (column names or a positional ``slice``) as a 2-D array."""
        if isinstance(channels, slice):
            return combined_data.iloc[:, channels].values if self.input_format == "csv" else combined_data[:, channels]
        else:
            return combined_data[channels].values if self.input_format == "csv" else combined_data[:, channels]

    def _apply_windowing(self, imu_data_acc, imu_data_gyr, joint_data, emg_data, window_length, window_overlap):
        """Cut the four aligned streams into full windows of ``window_length`` rows.

        Consecutive windows start ``window_length - window_overlap`` rows apart; a trailing
        partial window is dropped.

        Raises:
            ValueError: if ``window_overlap >= window_length`` (non-positive step).
        """
        step = window_length - window_overlap
        if step <= 0:
            raise ValueError(
                f"window_overlap ({window_overlap}) must be smaller than window_length ({window_length})"
            )

        num_samples = imu_data_acc.shape[0]
        return [
            (
                imu_data_acc[start:start + window_length],
                imu_data_gyr[start:start + window_length],
                joint_data[start:start + window_length],
                emg_data[start:start + window_length],
            )
            for start in range(0, num_samples - window_length + 1, step)
        ]

    def apply_transforms(self, data, transforms):
        """Apply each callable in ``transforms`` in order, then convert to a float32 tensor."""
        for transform in transforms:
            data = transform(data)
        data = torch.tensor(data, dtype=torch.float32)
        return data


In [None]:
#prediction function
def RMSE_prediction(yhat_4, test_y, output_dim, print_losses=True):
    """Per-channel RMSE and Pearson correlation (PCC) between predictions and targets.

    Args:
        yhat_4: predictions whose elements can be laid out as rows of ``output_dim``
            channels, e.g. ``(batch, time, output_dim)`` or ``(samples, output_dim)``.
            All leading axes are flattened into one.
        test_y: targets with the same number of elements as ``yhat_4``.
        output_dim: number of channels. At least 3 are required, because the first
            three prediction channels are returned as ``Z_1..Z_3``.
        print_losses: if True, print each channel's RMSE and PCC, then
            ``mean +/- stdev`` of both (stdev needs >= 2 channels).

    Returns:
        ``(rmse, p, Z_1, Z_2, Z_3)``. ``rmse`` and ``p`` hold one value per channel
        (all ``output_dim`` channels; previously only the first three were scored).
        ``Z_k`` is the flattened prediction of channel ``k``.
    """
    yhat = np.asarray(yhat_4).reshape(-1, output_dim)
    test_o = np.asarray(test_y).reshape(-1, output_dim)

    # Same quantity as sqrt(sklearn mean_squared_error) per column, computed in float64.
    residual = test_o.astype(np.float64) - yhat.astype(np.float64)
    rmse = np.sqrt(np.mean(residual ** 2, axis=0))

    # PCC per channel; np.corrcoef yields NaN (with a RuntimeWarning) for a constant channel.
    p = np.array([np.corrcoef(yhat[:, k], test_o[:, k])[0, 1] for k in range(output_dim)])

    if print_losses:
        for value in rmse:
            print(value)
        print("\n")
        for value in p:
            print(value)
        print('Mean: %.3f' % statistics.mean(rmse), '+/- %.3f' % statistics.stdev(rmse))
        print('Mean: %.3f' % statistics.mean(p), '+/- %.3f' % statistics.stdev(p))

    # Unfiltered predictions of the first three channels. The old Butterworth low-pass
    # helper was defined but never applied, so it has been removed.
    return rmse, p, yhat[:, 0], yhat[:, 1], yhat[:, 2]



############################################################################################################################################################################################################################################################################################################################################################################################################################################################################


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class Encoder_1(nn.Module):
    """Two stacked bidirectional LSTMs (128, then 64 units per direction), each followed by dropout.

    Input is batch-first ``(batch, time, input_dim)``. The output is ``(batch, time, 128)``:
    the forward and backward 64-unit states of the second LSTM, concatenated.

    ``flatten`` and ``fc`` are created but never used in ``forward``. They are kept so the
    ``state_dict`` keys and seeded weight initialisation stay the same.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        # Modules are created in the original order so seeded initialisation is unchanged.
        self.lstm_1 = nn.LSTM(input_dim, 128, bidirectional=True, batch_first=True, dropout=0.0)
        self.lstm_2 = nn.LSTM(2 * 128, 64, bidirectional=True, batch_first=True, dropout=0.0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2 * 64, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = x
        for rnn, drop in ((self.lstm_1, self.dropout_1), (self.lstm_2, self.dropout_2)):
            sequence, _ = rnn(hidden)
            hidden = drop(sequence)
        return hidden




class Encoder_2(nn.Module):
    """GRU twin of ``Encoder_1``: two stacked bidirectional GRUs (128, then 64 units per direction).

    Input is batch-first ``(batch, time, input_dim)``. The output is ``(batch, time, 128)``.
    The submodules keep the ``lstm_*`` attribute names (and the unused ``flatten`` / ``fc``)
    so ``state_dict`` keys and seeded weight initialisation stay the same.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        # Modules are created in the original order so seeded initialisation is unchanged.
        self.lstm_1 = nn.GRU(input_dim, 128, bidirectional=True, batch_first=True, dropout=0.0)
        self.lstm_2 = nn.GRU(2 * 128, 64, bidirectional=True, batch_first=True, dropout=0.0)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2 * 64, 32)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = x
        for rnn, drop in ((self.lstm_1, self.dropout_1), (self.lstm_2, self.dropout_2)):
            sequence, _ = rnn(hidden)
            hidden = drop(sequence)
        return hidden


class GatingModule(nn.Module):
    """Learned element-wise blend of two equally shaped feature tensors.

    A sigmoid gate ``g`` is predicted from both inputs concatenated along the last axis.
    The output is ``input1 * g + input2 * (1 - g)``.

    Args:
        input_size: size of each input's last axis (the gate maps ``2*input_size -> input_size``).
    """

    def __init__(self, input_size):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(2 * input_size, input_size), nn.Sigmoid())

    def forward(self, input1, input2):
        g = self.gate(torch.cat((input1, input2), dim=-1))
        return input1 * g + input2 * (1 - g)
# NOTE(review): teacher's `w` argument defaults to a stand-in value (100); confirm it matches config.window_length.
class teacher(nn.Module):
    def __init__(self, input_acc, input_gyr, input_emg, drop_prob=0.25, w=100):
        super(teacher, self).__init__()

        self.w=w
        self.encoder_1_acc=Encoder_1(input_acc, drop_prob)
        self.encoder_1_gyr=Encoder_1(input_gyr, drop_prob)
        self.encoder_1_emg=Encoder_1(input_emg, drop_prob)

        self.encoder_2_acc=Encoder_2(input_acc, drop_prob)
        self.encoder_2_gyr=Encoder_2(input_gyr, drop_prob)
        self.encoder_2_emg=Encoder_2(input_emg, drop_prob)

        self.BN_acc= nn.BatchNorm1d(input_acc, affine=False)
        self.BN_gyr= nn.BatchNorm1d(input_gyr, affine=False)
        self.BN_emg= nn.BatchNorm1d(input_emg, affine=False)


        self.fc = nn.Linear(2*3*128+128,3)
        self.dropout=nn.Dropout(p=0.05)

        self.gate_1=GatingModule(128)
        self.gate_2=GatingModule(128)
        self.gate_3=GatingModule(128)

        self.fc_kd = nn.Linear(3*128, 2*128)

               # Define the gating network
        self.weighted_feat = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid())

        self.attention=nn.MultiheadAttention(3*128,4,batch_first=True)
        self.gating_net = nn.Sequential(nn.Linear(128*3, 3*128), nn.Sigmoid())
        self.gating_net_1 = nn.Sequential(nn.Linear(2*3*128+128, 2*3*128+128), nn.Sigmoid())

        self.pool = nn.MaxPool1d(kernel_size=2)


    def forward(self, x_acc, x_gyr, x_emg):

        x_acc_1=x_acc.view(x_acc.size(0)*x_acc.size(1),x_acc.size(-1))
        x_gyr_1=x_gyr.view(x_gyr.size(0)*x_gyr.size(1),x_gyr.size(-1))
        x_emg_1=x_emg.view(x_emg.size(0)*x_emg.size(1),x_emg.size(-1))

        x_acc_1=self.BN_acc(x_acc_1)
        x_gyr_1=self.BN_gyr(x_gyr_1)
        x_emg_1=self.BN_emg(x_emg_1)

        x_acc_2=x_acc_1.view(-1, self.w, x_acc_1.size(-1))
        x_gyr_2=x_gyr_1.view(-1, self.w, x_gyr_1.size(-1))
        x_emg_2=x_emg_1.view(-1, self.w, x_emg_1.size(-1))

        x_acc_1=self.encoder_1_acc(x_acc_2)
        x_gyr_1=self.encoder_1_gyr(x_gyr_2)
        x_emg_1=self.encoder_1_emg(x_emg_2)

        x_acc_2=self.encoder_2_acc(x_acc_2)
        x_gyr_2=self.encoder_2_gyr(x_gyr_2)
        x_emg_2=self.encoder_2_emg(x_emg_2)

        # x_acc=torch.cat((x_acc_1,x_acc_2),dim=-1)
        # x_gyr=torch.cat((x_gyr_1,x_gyr_2),dim=-1)
        # x_emg=torch.cat((x_emg_1,x_emg_2),dim=-1)

        x_acc=self.gate_1(x_acc_1,x_acc_2)
        x_gyr=self.gate_2(x_gyr_1,x_gyr_2)
        x_emg=self.gate_3(x_emg_1,x_emg_2)

        x=torch.cat((x_acc,x_gyr,x_emg),dim=-1)
        x_kd=self.fc_kd(x)


        out_1, attn_output_weights=self.attention(x,x,x)

        gating_weights = self.gating_net(x)
        out_2=gating_weights*x

        weights_1 = self.weighted_feat(x[:,:,0:128])
        weights_2 = self.weighted_feat(x[:,:,128:2*128])
        weights_3 = self.weighted_feat(x[:,:,2*128:3*128])
        x_1=weights_1*x[:,:,0:128]
        x_2=weights_2*x[:,:,128:2*128]
        x_3=weights_3*x[:,:,2*128:3*128]
        out_3=x_1+x_2+x_3

        out=torch.cat((out_1,out_2,out_3),dim=-1)

        gating_weights_1 = self.gating_net_1(out)
        out=gating_weights_1*out

        out=self.fc(out)

        #print(out.shape)
        return out,x_kd


class RMSELoss(nn.Module):
    """Root-mean-square error over all elements: ``sqrt(mean((output - target) ** 2))``."""

    def forward(self, output, target):
        residual = output - target
        return residual.pow(2).mean().sqrt()




In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time
from tqdm.notebook import tqdm
from sklearn.metrics import mean_squared_error
from scipy.signal import butter, filtfilt
import statistics

# Save checkpoint function
def save_checkpoint(model, optimizer, epoch, filename, train_loss, val_loss, test_loss=None, channelwise_metrics=None):
    """Save model/optimizer state, losses and per-channel metrics to ``filename`` via ``torch.save``.

    Args:
        model, optimizer: their ``state_dict()`` is stored.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        filename: destination path.
        train_loss, val_loss: stored as given.
        test_loss: stored, together with the 'test' channel-wise metrics, only when not None.
        channelwise_metrics: optional dict with 'train' / 'val' / 'test' entries. If the dict
            is omitted or an entry is missing, None is stored instead of raising
            (previously the None default crashed with a TypeError).
    """
    metrics = channelwise_metrics or {}
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_channelwise_metrics': metrics.get('train'),
        'val_channelwise_metrics': metrics.get('val'),
    }
    if test_loss is not None:
        checkpoint['test_loss'] = test_loss
        checkpoint['test_channelwise_metrics'] = metrics.get('test')

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved for epoch {epoch + 1}")

# Evaluation function
def evaluate_model(device, model, loader, criterion):
    """Average loss, per-channel PCC and per-channel RMSE of ``model`` over ``loader``.

    ``loader`` must yield ``(acc, gyr, target, emg)`` batches. Metrics are averaged per
    batch, so every batch weighs equally. The number of output channels comes from the
    target's last dimension rather than from the global ``config``.

    Returns:
        ``(avg_loss, avg_pcc, avg_rmse)``: a float and two arrays with one entry per channel.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    total_pcc = None
    total_rmse = None
    num_batches = 0

    with torch.no_grad():
        for data_acc, data_gyr, target, data_EMG in loader:
            output, _ = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            target = target.to(device).float()
            loss = criterion(output, target)
            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(
                output.cpu().numpy(), target.cpu().numpy(), target.shape[-1], print_losses=False
            )
            if total_pcc is None:
                # float64 accumulators, sized from the first batch.
                total_pcc = np.zeros(len(batch_pcc))
                total_rmse = np.zeros(len(batch_rmse))
            total_loss += loss.item()
            total_pcc += batch_pcc
            total_rmse += batch_rmse
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_model: loader yielded no batches")

    return total_loss / num_batches, total_pcc / num_batches, total_rmse / num_batches

# Training function
def train_teacher(device, train_loader, val_loader, test_loader, learn_rate, epochs, model, filename,
                  patience=20, checkpoint_root="/content/MyDrive/MyDrive/models"):
    """Train ``model`` with Adam and RMSE loss, evaluating on val and test after every epoch.

    Every epoch writes a full checkpoint to
    ``<checkpoint_root>/<filename>/<filename>_epoch_<n>.pth``. Whenever the validation loss
    improves, ``model.state_dict()`` is also saved to ``filename`` (resolved against the
    current working directory). Training stops early after ``patience`` consecutive epochs
    without validation improvement.

    Args:
        device: device that both the model and the batches are moved to.
        train_loader, val_loader, test_loader: loaders yielding ``(acc, gyr, target, emg)``.
        learn_rate: Adam learning rate.
        epochs: maximum number of epochs.
        model: network returning ``(prediction, kd_features)``.
        filename: run name, used for the checkpoint folder and the best-model path.
        patience: epochs without val improvement before stopping (previously fixed at 20).
        checkpoint_root: directory under which the per-run checkpoint folder is created.

    Returns:
        ``(model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs,
        train_rmses, val_rmses, test_rmses)``, with one entry per completed epoch.
    """
    # Previously `if cuda: model.cuda()`, which ignored `device`: with a CPU device on a
    # CUDA machine the model and the batches ended up on different devices.
    model.to(device)

    criterion = RMSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    # NOTE(review): the channel count still comes from the global `config`.
    n_channels = len(config.channels_joints)

    train_losses, val_losses, test_losses = [], [], []
    train_pccs, val_pccs, test_pccs = [], [], []
    train_rmses, val_rmses, test_rmses = [], [], []

    checkpoint_dir = f"{checkpoint_root}/{filename}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    start_time = time.time()
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()

        # The scalar batch loss is broadcast into a per-channel vector, so each train-loss
        # entry is an array of identical values (val/test losses are plain floats).
        epoch_train_loss = np.zeros(n_channels)
        epoch_train_pcc = np.zeros(n_channels)
        epoch_train_rmse = np.zeros(n_channels)

        for data_acc, data_gyr, target, data_EMG in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} Training"):
            optimizer.zero_grad()
            output, _ = model(data_acc.to(device).float(), data_gyr.to(device).float(), data_EMG.to(device).float())
            loss = criterion(output, target.to(device).float())
            loss.backward()
            optimizer.step()

            batch_rmse, batch_pcc, _, _, _ = RMSE_prediction(output.detach().cpu().numpy(), target.detach().cpu().numpy(), n_channels, print_losses=False)
            epoch_train_loss += loss.detach().cpu().numpy()
            epoch_train_pcc += batch_pcc
            epoch_train_rmse += batch_rmse

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_pcc = epoch_train_pcc / len(train_loader)
        avg_train_rmse = epoch_train_rmse / len(train_loader)

        train_losses.append(avg_train_loss)
        train_pccs.append(avg_train_pcc)
        train_rmses.append(avg_train_rmse)

        avg_val_loss, avg_val_pcc, avg_val_rmse = evaluate_model(device, model, val_loader, criterion)
        val_losses.append(avg_val_loss)
        val_pccs.append(avg_val_pcc)
        val_rmses.append(avg_val_rmse)

        avg_test_loss, avg_test_pcc, avg_test_rmse = evaluate_model(device, model, test_loader, criterion)
        test_losses.append(avg_test_loss)
        test_pccs.append(avg_test_pcc)
        test_rmses.append(avg_test_rmse)

        print(f"Epoch: {epoch + 1}, Training Loss: {np.mean(avg_train_loss)}, Validation Loss: {np.mean(avg_val_loss):.4f}", f"Test Loss: {np.mean(avg_test_loss):.4f}")
        print(f"Training PCC: {np.mean(avg_train_pcc)}, Validation PCC: {np.mean(avg_val_pcc):.4f}", f"Test PCC: {np.mean(avg_test_pcc):.4f}")

        save_checkpoint(
            model,
            optimizer,
            epoch,
            f"{checkpoint_dir}/{filename}_epoch_{epoch+1}.pth",
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            test_loss=avg_test_loss,
            channelwise_metrics={
                'train': {'pcc': avg_train_pcc, 'rmse': avg_train_rmse},
                'val': {'pcc': avg_val_pcc, 'rmse': avg_val_rmse},
                'test': {'pcc': avg_test_pcc, 'rmse': avg_test_rmse},
            }
        )

        # Early stopping on validation loss; the best weights go to `filename`.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), filename)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Stopping early after {epoch + 1} epochs")
            break

    end_time = time.time()
    print(f"Total training time: {end_time - start_time:.2f} seconds")

    return model, train_losses, val_losses, test_losses, train_pccs, val_pccs, test_pccs, train_rmses, val_rmses, test_rmses


In [None]:
class DataSharder:
    """Cut each subject's recordings into fixed-length windows and write one CSV per window.

    Output layout::

        <dataset_root>/<dataset_name>/<dataset_train_name | dataset_test_name>/<window>.csv
        <dataset_root>/<dataset_name>/{train,test}_info.csv   # columns: file_name, file_path

    Args:
        config: ``Config``-like object (paths, subjects, window_length / window_overlap,
            input_format).
        save_h5: wav path only. If True, all sessions go to ``_save_to_h5`` instead of being
            windowed. NOTE(review): ``_save_to_h5`` is not defined in this class.
    """

    def __init__(self, config, save_h5=False):
        self.config = config
        self.input_format = config.input_format
        self.data_folder_path = config.data_folder_name
        self.window_length = int(config.window_length)
        self.window_overlap = int(config.window_overlap)
        self.save_h5 = save_h5

    def _info_path(self, split):
        """Path of the index CSV that lists every window file of ``split``."""
        return os.path.join(self.config.dataset_root, self.config.dataset_name, f"{split}_info.csv")

    def load_data(self):
        """Re-shard the configured train and test subjects.

        Window rows are appended to ``{split}_info.csv``, so stale indexes are removed
        first. Otherwise re-running this would duplicate every row (and every sample).

        Raises:
            ValueError: if ``input_format`` is neither 'wav' nor 'csv'.
        """
        print(f"Training subjects: {self.config.train_subjects}")
        print(f"Testing subjects: {self.config.test_subjects}")

        if self.input_format == 'wav':
            process = self._process_and_save_patients_wav
        elif self.input_format == 'csv':
            process = self._process_and_save_patients_csv
        else:
            raise ValueError(f"Unsupported input format: {self.input_format}")

        for split in ("train", "test"):
            info_path = self._info_path(split)
            if os.path.exists(info_path):
                os.remove(info_path)

        process(self.config.train_subjects, "train")
        process(self.config.test_subjects, "test")

    def _process_and_save_patients_wav(self, patient_id_list, split):
        """Load, resample and window the IMU / JOINTS / EMG wav files of each patient session.

        NOTE(review): the number of sessions per patient is taken as
        ``len(self.config.train_subjects)``, and ``_resample_data`` reads ``self.sample_rate``,
        which is never set. This path looks unfinished; confirm before use.
        """
        sessions = []
        for patient_id in tqdm(patient_id_list, desc=f"Processing {split} patients"):
            for session_index in tqdm(range(len(self.config.train_subjects)), desc=f"Processing sessions for {patient_id}", leave=False):
                imu_data, imu_sample_rate = self._load_wav_file(patient_id, session_index, "IMU")
                joints_data, joints_sample_rate = self._load_wav_file(patient_id, session_index, "JOINTS")
                emg_data, emg_sample_rate = self._load_wav_file(patient_id, session_index, "EMG")

                imu_data = self._resample_data(imu_data, imu_sample_rate)
                joints_data = self._resample_data(joints_data, joints_sample_rate)
                emg_data = self._resample_data(emg_data, emg_sample_rate)

                combined_data = torch.cat((imu_data, joints_data, emg_data), dim=1)
                # Keep the ids with the data. Previously every session was saved under the
                # last loop's patient_id/session_index, overwriting earlier windows.
                sessions.append((patient_id, session_index, combined_data.cpu().numpy()))

        if self.save_h5:
            self._save_to_h5([data for _, _, data in sessions], split)
        else:
            for patient_id, session_index, combined_data in sessions:
                self._save_windowed_data(combined_data, patient_id, session_index, split)

    def _load_wav_file(self, patient_id, session_index, file_type):
        """Read ``run<session_index>_<file_type>.wav`` of a patient as a float32 tensor plus its sample rate."""
        file_path = os.path.join(self.data_folder_path, patient_id, f"run{session_index}_{file_type}.wav")
        data, sample_rate = get_data_from_wav_file(file_path)
        return torch.tensor(data, dtype=torch.float32), sample_rate

    def _resample_data(self, data, sample_rate):
        """Linearly interpolate ``data`` when its rate differs from ``self.sample_rate``.

        NOTE(review): ``self.sample_rate`` is never assigned, and ``size=self.sample_rate``
        sets an absolute output length rather than a rate. Confirm the intent.
        """
        if sample_rate != self.sample_rate:
            data = torch.nn.functional.interpolate(data.unsqueeze(0), size=self.sample_rate, mode='linear').squeeze(0)
        return data

    def _process_and_save_patients_csv(self, patient_id_list, split):
        """Window every CSV in ``<data_folder>/<patient>/combined`` for each patient.

        Column names come from the first file read. Later files are matched by position:
        narrower files are zero-padded on the right, wider ones get ``extra_<i>`` names.
        NOTE(review): columns are aligned by position, not by name.
        """
        column_names = None
        for patient_id in tqdm(patient_id_list, desc=f"Processing {split} patients"):
            combined_path = os.path.join(self.data_folder_path, patient_id, "combined")
            if not os.path.exists(combined_path):
                print(f"Directory {combined_path} does not exist. Skipping patient {patient_id}.")
                continue

            # Sorted so the index CSV has a deterministic order across runs and filesystems.
            patient_files = sorted(os.listdir(combined_path))
            for session_file in tqdm(patient_files, desc=f"Processing sessions for {patient_id}", leave=False):
                data = pd.read_csv(os.path.join(combined_path, session_file))
                if column_names is None:
                    column_names = data.columns.tolist()
                data_np = data.to_numpy()
                if data_np.shape[1] < len(column_names):
                    data_np = np.pad(data_np, ((0, 0), (0, len(column_names) - data_np.shape[1])), mode='constant')
                elif data_np.shape[1] > len(column_names):
                    extra_columns = [f"extra_{i}" for i in range(data_np.shape[1] - len(column_names))]
                    column_names.extend(extra_columns)

                self._save_windowed_data(pd.DataFrame(data_np, columns=column_names), patient_id, session_file.split('.')[0], split, is_csv=True)

    def _save_windowed_data(self, data, patient_id, session_id, split, is_csv=False):
        """Write each full window of ``data`` to its own CSV and append the files to the split index.

        ``data`` is a DataFrame when ``is_csv`` is True; otherwise it is a numpy array or a
        torch tensor. A trailing partial window is dropped.

        Raises:
            ValueError: if ``window_overlap >= window_length`` (non-positive step).
        """
        split_name = self.config.dataset_train_name if split == "train" else self.config.dataset_test_name
        dataset_folder = os.path.join(self.config.dataset_root, self.config.dataset_name, split_name)
        os.makedirs(dataset_folder, exist_ok=True)

        window_size = self.window_length
        overlap = self.window_overlap
        step_size = window_size - overlap
        if step_size <= 0:
            raise ValueError(f"window_overlap ({overlap}) must be smaller than window_length ({window_size})")

        data_info_list = []
        for i in tqdm(range(0, len(data) - window_size + 1, step_size), desc=f"Windowing data for {patient_id}_{session_id}", leave=False):
            if is_csv:
                windowed_data_np = data.iloc[i:i + window_size].to_numpy()
            elif torch.is_tensor(data):
                windowed_data_np = data[i:i + window_size].cpu().numpy()
            else:
                # numpy input (e.g. from the wav path) has no .cpu().
                windowed_data_np = np.asarray(data[i:i + window_size])

            file_name = f"{patient_id}_session_{session_id}_window_{i}_ws{window_size}_ol{overlap}.csv"
            file_path = os.path.join(dataset_folder, file_name)
            pd.DataFrame(windowed_data_np, columns=data.columns if is_csv else None).to_csv(file_path, index=False)
            data_info_list.append({"file_name": file_name, "file_path": file_path})

        if not data_info_list:
            # Session shorter than one window: writing an empty frame here would put a
            # bogus header line into the index CSV.
            return

        info_path = self._info_path(split)
        pd.DataFrame(data_info_list).to_csv(info_path, index=False, mode='a', header=not os.path.exists(info_path))


In [None]:
import shutil
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The six IMUs report the same channel names in the train and test recordings. Build the
# lists once, ordered per IMU as X, Y, Z (ACCX1, ACCY1, ACCZ1, ACCX2, ...), so the train and
# test lists cannot drift apart. Each Config field gets its own list copy.
NUM_IMUS = 6
IMU_ACC_CHANNELS = [f"ACC{axis}{imu}" for imu in range(1, NUM_IMUS + 1) for axis in "XYZ"]
IMU_GYR_CHANNELS = [f"GYRO{axis}{imu}" for imu in range(1, NUM_IMUS + 1) for axis in "XYZ"]

config = Config(
    data_folder_name='/content/MyDrive/MyDrive/sd_datacollection',
    dataset_root='/content/datasets',
    dataset_train_name='train',
    dataset_test_name='test',
    batch_size=64,
    epochs=100,
    lr=0.001,
    seed=42,
    window_length=100,
    window_overlap=0,
    input_format="csv",
    channels_imu_acc=list(IMU_ACC_CHANNELS),
    channels_imu_acc_test=list(IMU_ACC_CHANNELS),
    channels_imu_gyr=list(IMU_GYR_CHANNELS),
    channels_imu_gyr_test=list(IMU_GYR_CHANNELS),
    channels_joints=['elbow_flex_r', 'arm_flex_r', 'arm_add_r'],
    channels_emg=['IM EMG4', 'IM EMG5', 'IM EMG6'],
    # subject_5 appears in neither split.
    train_subjects=['subject_2','subject_3','subject_4','subject_6','subject_7','subject_8', 'subject_9','subject_10', 'subject_11','subject_12','subject_13'],
    test_subjects=['subject_1']
)





In [None]:
# Regenerate the windowed dataset from the raw recordings. Windowing appends to
# <split>_info.csv, so clear the existing dataset folder before re-running with True.
reshard = True
# Back up a freshly resharded dataset to Drive.
copy_to_drive = False

DRIVE_DATASETS_DIR = "/content/MyDrive/MyDrive/datasets"

if reshard:
    data_sharder = DataSharder(config)
    data_sharder.load_data()

# NOTE(review): Config (first cell) does not define dataset_name. Presumably DataSharder
# sets it; confirm it is available when reshard is False.
local_dataset_dir = os.path.join(config.dataset_root, config.dataset_name)
drive_dataset_dir = os.path.join(DRIVE_DATASETS_DIR, config.dataset_name)

if reshard and copy_to_drive:
    shutil.copytree(local_dataset_dir, drive_dataset_dir, dirs_exist_ok=True)

# Check the dataset folder itself, not just the root. The root can exist (e.g. from
# another dataset) while this dataset is missing locally.
if not os.path.exists(local_dataset_dir):
    shutil.copytree(drive_dataset_dir, local_dataset_dir)



Training subjects: ['subject_2', 'subject_3', 'subject_4', 'subject_6', 'subject_7', 'subject_8', 'subject_9', 'subject_10', 'subject_11', 'subject_12', 'subject_13']
Testing subjects: ['subject_1']


Processing train patients:   0%|          | 0/11 [00:00<?, ?it/s]

Processing sessions for subject_2:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_max_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_90_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_overheadreach_45_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elboxflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_armSwing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_2_P002_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_3:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreach45_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreach90_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowflexion_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T002_elbowrotation_slow_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowrotation_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_veryfast_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_armswing_fast_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_elbowrotation_normal_combined:   0%|          | 0/20 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T003_armswing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_fast_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_overheadreachMax_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_3_P003_T001_crossbody_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Processing sessions for subject_4:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_fast_combined:   0%|          | 0/36 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_fast_combined:   0%|          | 0/35 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_crossbody_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowrotation_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_overheadreach_max_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_4_P004_T001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_6:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_slow_combined:   0%|          | 0/28 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_overheadreach_max_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_overheadreach_45_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T002_crossbody_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_6_P006_T001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_7:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T002_armSwing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_max_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_elbowrotation_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_crossbody_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_7_P007_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_8:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_max_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowflexion_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_8_P008_T001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_9:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_elbowflexion_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_armswing_veryfast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_elbowflexion_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_max_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T001_overheadreach_45_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_elbowrotation_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_9_P009_T002_crossbody_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Processing sessions for subject_10:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_overheadreach_max_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T002_elbowrotation_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_10_P010_T001_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_11:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_90_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_max_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_11_P011_T001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Processing sessions for subject_12:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_veryfast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_armSwing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_fast_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowrotation_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_overheadreach_max_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_crossbody_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_elbowflexion_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_12_P0012_T001_crossbody_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Processing sessions for subject_13:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_max_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_veryfast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_90_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_slow_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_crossbody_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_overheadreach_45_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowrotation_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_fast_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_elbowflexion_normal_combined:   0%|          | 0/30 [00:00<?, ?it/s]

Windowing data for subject_13_p13_t001_armswing_normal_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Processing test patients:   0%|          | 0/1 [00:00<?, ?it/s]

Processing sessions for subject_1:   0%|          | 0/16 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_fast_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_fast_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_fast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_fast_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_veryfast_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_max_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_slow_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_90_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_armSwing_slow_combined:   0%|          | 0/31 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowrotation_slow_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_overheadreach_45_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_elbowflexion_normal_combined:   0%|          | 0/32 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_normal_combined:   0%|          | 0/33 [00:00<?, ?it/s]

Windowing data for subject_1_P001_T001_crossbody_slow_combined:   0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
# Create datasets
train_dataset = ImuJointPairDataset(config, split='train')
test_dataset = ImuJointPairDataset(config, split='test')

# Hold out 10% of the training set for validation. The seeded generator makes the split
# reproducible across kernel restarts.
# NOTE(review): the split is over dataset items (presumably individual windows), so
# validation windows come from the same subjects/sessions as training windows. Expect
# validation metrics to be optimistic relative to the held-out test subject.
n_train = int(0.9 * len(train_dataset))
n_val = len(train_dataset) - n_train
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(config.seed),
)

# Setup dataloaders. The train loader gets its own seeded generator so shuffling is reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(config.seed),
)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# Build the model from the configured channel counts (training happens in a later cell)
model = teacher(
    input_acc=len(config.channels_imu_acc),
    input_gyr=len(config.channels_imu_gyr),
    input_emg=len(config.channels_emg)
)

In [None]:

# Pull one training batch and show each tensor's shape, then each tensor's dtype.
acc, gyro, target, emg = next(iter(train_loader))
sample_batch = (acc, gyro, target, emg)

for tensor in sample_batch:
    print(tensor.shape)

for tensor in sample_batch:
    print(tensor.dtype)


torch.Size([64, 100, 18])
torch.Size([64, 100, 18])
torch.Size([64, 100, 3])
torch.Size([64, 100, 3])
torch.float32
torch.float32
torch.float32
torch.float32


In [None]:
# Fresh teacher sized from the configured channel lists (18 acc, 18 gyro and 3 EMG channels
# with the current config) instead of hard-coded counts that can drift from the config.
model = teacher(
    len(config.channels_imu_acc),
    len(config.channels_imu_gyr),
    len(config.channels_emg),
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Train the teacher. run_name is passed through to train_teacher (presumably used to
# name checkpoints/logs; TODO confirm).
run_name = 'new_opensimmodel_first_run_no_subject_5_RMSE_nooverlap'
(model, train_losses, val_losses, test_losses,
 train_pccs, val_pccs, test_pccs,
 train_rmses, val_rmses, test_rmses) = train_teacher(
    device, train_loader, val_loader, test_loader,
    config.lr, config.epochs, model, run_name,
)

Epoch 1/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 34.08233105508905, Validation Loss: 22.8872 Test Loss: 32.6169
Training PCC: 0.5132993794868798, Validation PCC: 0.6909 Test PCC: 0.4603
Checkpoint saved for epoch 1


Epoch 2/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 19.759027305402252, Validation Loss: 16.5516 Test Loss: 27.0007
Training PCC: 0.7429509888156699, Validation PCC: 0.8353 Test PCC: 0.5573
Checkpoint saved for epoch 2


Epoch 3/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 15.185346678683635, Validation Loss: 13.6382 Test Loss: 19.3200
Training PCC: 0.8871941629448271, Validation PCC: 0.9139 Test PCC: 0.6878
Checkpoint saved for epoch 3


Epoch 4/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 13.172833668558221, Validation Loss: 12.5763 Test Loss: 19.9278
Training PCC: 0.9241646310780952, Validation PCC: 0.9376 Test PCC: 0.6962
Checkpoint saved for epoch 4


Epoch 5/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 12.10286529440629, Validation Loss: 10.7825 Test Loss: 21.7092
Training PCC: 0.9391314823680461, Validation PCC: 0.9484 Test PCC: 0.7567
Checkpoint saved for epoch 5


Epoch 6/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 11.752138351139271, Validation Loss: 10.5500 Test Loss: 22.4936
Training PCC: 0.944422984000778, Validation PCC: 0.9538 Test PCC: 0.7434
Checkpoint saved for epoch 6


Epoch 7/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 11.17415163391515, Validation Loss: 11.0369 Test Loss: 18.9545
Training PCC: 0.9492129806955353, Validation PCC: 0.9537 Test PCC: 0.7046
Checkpoint saved for epoch 7


Epoch 8/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 10.93040741117377, Validation Loss: 9.6764 Test Loss: 19.0242
Training PCC: 0.953714057083224, Validation PCC: 0.9613 Test PCC: 0.7550
Checkpoint saved for epoch 8


Epoch 9/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 10.19405534392909, Validation Loss: 9.3938 Test Loss: 19.4740
Training PCC: 0.9584802322136978, Validation PCC: 0.9609 Test PCC: 0.7609
Checkpoint saved for epoch 9


Epoch 10/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 9.97613315205825, Validation Loss: 9.3555 Test Loss: 19.7569
Training PCC: 0.9618371661840442, Validation PCC: 0.9647 Test PCC: 0.7567
Checkpoint saved for epoch 10


Epoch 11/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 11, Training Loss: 9.499523840452495, Validation Loss: 8.8732 Test Loss: 19.0540
Training PCC: 0.9638623727197119, Validation PCC: 0.9661 Test PCC: 0.7179
Checkpoint saved for epoch 11


Epoch 12/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 12, Training Loss: 9.292304685241298, Validation Loss: 9.2585 Test Loss: 21.3547
Training PCC: 0.9658115153560104, Validation PCC: 0.9658 Test PCC: 0.7385
Checkpoint saved for epoch 12


Epoch 13/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 13, Training Loss: 8.97429937437961, Validation Loss: 8.6314 Test Loss: 20.7334
Training PCC: 0.9678251074130227, Validation PCC: 0.9682 Test PCC: 0.7248
Checkpoint saved for epoch 13


Epoch 14/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 14, Training Loss: 8.822789669036865, Validation Loss: 8.6357 Test Loss: 19.3014
Training PCC: 0.9692473794566422, Validation PCC: 0.9685 Test PCC: 0.7847
Checkpoint saved for epoch 14


Epoch 15/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 15, Training Loss: 8.749000235607749, Validation Loss: 8.5361 Test Loss: 20.2931
Training PCC: 0.9699667411044267, Validation PCC: 0.9698 Test PCC: 0.7342
Checkpoint saved for epoch 15


Epoch 16/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 16, Training Loss: 8.458965790899176, Validation Loss: 8.2395 Test Loss: 20.6269
Training PCC: 0.9711870014505307, Validation PCC: 0.9709 Test PCC: 0.7396
Checkpoint saved for epoch 16


Epoch 17/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 17, Training Loss: 8.420914744075976, Validation Loss: 8.8616 Test Loss: 21.0878
Training PCC: 0.9716166556569051, Validation PCC: 0.9689 Test PCC: 0.7589
Checkpoint saved for epoch 17


Epoch 18/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 18, Training Loss: 8.062905725679899, Validation Loss: 8.9092 Test Loss: 19.8992
Training PCC: 0.973058911189432, Validation PCC: 0.9694 Test PCC: 0.7498
Checkpoint saved for epoch 18


Epoch 19/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 19, Training Loss: 8.034296223991795, Validation Loss: 8.6475 Test Loss: 19.2850
Training PCC: 0.9741782054868358, Validation PCC: 0.9689 Test PCC: 0.7566
Checkpoint saved for epoch 19


Epoch 20/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 20, Training Loss: 7.803203946665714, Validation Loss: 8.7218 Test Loss: 19.5818
Training PCC: 0.9755132550903686, Validation PCC: 0.9692 Test PCC: 0.7681
Checkpoint saved for epoch 20


Epoch 21/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 21, Training Loss: 7.699295878410339, Validation Loss: 8.5161 Test Loss: 19.9621
Training PCC: 0.9764166730222383, Validation PCC: 0.9697 Test PCC: 0.7725
Checkpoint saved for epoch 21


Epoch 22/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 22, Training Loss: 7.5842016119706015, Validation Loss: 8.0876 Test Loss: 18.8764
Training PCC: 0.9769253823542021, Validation PCC: 0.9728 Test PCC: 0.7930
Checkpoint saved for epoch 22


Epoch 23/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 23, Training Loss: 7.363892297995718, Validation Loss: 8.6711 Test Loss: 20.9367
Training PCC: 0.9782662358106996, Validation PCC: 0.9698 Test PCC: 0.7565
Checkpoint saved for epoch 23


Epoch 24/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 24, Training Loss: 7.042769199923465, Validation Loss: 8.6850 Test Loss: 18.9881
Training PCC: 0.9799690657547773, Validation PCC: 0.9700 Test PCC: 0.7893
Checkpoint saved for epoch 24


Epoch 25/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 25, Training Loss: 7.07309731684233, Validation Loss: 8.0350 Test Loss: 19.6054
Training PCC: 0.9798810672518804, Validation PCC: 0.9712 Test PCC: 0.7731
Checkpoint saved for epoch 25


Epoch 26/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 26, Training Loss: 6.70101269295341, Validation Loss: 8.2908 Test Loss: 20.6620
Training PCC: 0.9814164393677, Validation PCC: 0.9718 Test PCC: 0.7424
Checkpoint saved for epoch 26


Epoch 27/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 27, Training Loss: 6.786803766300804, Validation Loss: 8.2338 Test Loss: 20.2670
Training PCC: 0.9812757681253444, Validation PCC: 0.9715 Test PCC: 0.7551
Checkpoint saved for epoch 27


Epoch 28/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 28, Training Loss: 6.773084301697581, Validation Loss: 8.2752 Test Loss: 19.3079
Training PCC: 0.9816177256101356, Validation PCC: 0.9731 Test PCC: 0.7574
Checkpoint saved for epoch 28


Epoch 29/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 29, Training Loss: 6.36551232714402, Validation Loss: 7.9009 Test Loss: 18.2849
Training PCC: 0.9835120909301726, Validation PCC: 0.9740 Test PCC: 0.7886
Checkpoint saved for epoch 29


Epoch 30/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 30, Training Loss: 6.295484586765892, Validation Loss: 7.9511 Test Loss: 19.2573
Training PCC: 0.9840819361207499, Validation PCC: 0.9738 Test PCC: 0.7710
Checkpoint saved for epoch 30


Epoch 31/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 31, Training Loss: 6.078547565560592, Validation Loss: 7.8729 Test Loss: 18.4990
Training PCC: 0.9849594695664764, Validation PCC: 0.9738 Test PCC: 0.7867
Checkpoint saved for epoch 31


Epoch 32/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 32, Training Loss: 6.166398682092367, Validation Loss: 7.7797 Test Loss: 18.8413
Training PCC: 0.9849181259508507, Validation PCC: 0.9751 Test PCC: 0.7939
Checkpoint saved for epoch 32


Epoch 33/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 33, Training Loss: 5.975291026265997, Validation Loss: 7.5023 Test Loss: 20.3691
Training PCC: 0.9856202594166379, Validation PCC: 0.9765 Test PCC: 0.7910
Checkpoint saved for epoch 33


Epoch 34/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 34, Training Loss: 6.0092544994856185, Validation Loss: 7.7771 Test Loss: 17.6070
Training PCC: 0.985554047584583, Validation PCC: 0.9759 Test PCC: 0.8092
Checkpoint saved for epoch 34


Epoch 35/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 35, Training Loss: 5.969285977514166, Validation Loss: 7.7249 Test Loss: 19.3318
Training PCC: 0.9859182889811828, Validation PCC: 0.9757 Test PCC: 0.7975
Checkpoint saved for epoch 35


Epoch 36/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 36, Training Loss: 5.709517046024925, Validation Loss: 7.6730 Test Loss: 18.9112
Training PCC: 0.9867734407782383, Validation PCC: 0.9754 Test PCC: 0.7954
Checkpoint saved for epoch 36


Epoch 37/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 37, Training Loss: 5.90269893094113, Validation Loss: 8.0783 Test Loss: 18.8854
Training PCC: 0.9863791307449806, Validation PCC: 0.9733 Test PCC: 0.7801
Checkpoint saved for epoch 37


Epoch 38/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 38, Training Loss: 5.599068754597714, Validation Loss: 7.1524 Test Loss: 18.9846
Training PCC: 0.987112411577737, Validation PCC: 0.9779 Test PCC: 0.7819
Checkpoint saved for epoch 38


Epoch 39/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 39, Training Loss: 5.616804229585749, Validation Loss: 7.9200 Test Loss: 19.7729
Training PCC: 0.9873765076808563, Validation PCC: 0.9756 Test PCC: 0.8028
Checkpoint saved for epoch 39


Epoch 40/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 40, Training Loss: 5.403616673067996, Validation Loss: 7.3534 Test Loss: 19.5027
Training PCC: 0.9882156513327599, Validation PCC: 0.9773 Test PCC: 0.7982
Checkpoint saved for epoch 40


Epoch 41/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 41, Training Loss: 5.385077859226026, Validation Loss: 7.1760 Test Loss: 17.9027
Training PCC: 0.9883769767116001, Validation PCC: 0.9773 Test PCC: 0.8066
Checkpoint saved for epoch 41


Epoch 42/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 42, Training Loss: 5.243100800012288, Validation Loss: 7.4839 Test Loss: 18.4189
Training PCC: 0.9886734361288715, Validation PCC: 0.9762 Test PCC: 0.8216
Checkpoint saved for epoch 42


Epoch 43/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 43, Training Loss: 5.4086282065040185, Validation Loss: 7.5420 Test Loss: 19.2367
Training PCC: 0.988480958499042, Validation PCC: 0.9763 Test PCC: 0.8050
Checkpoint saved for epoch 43


Epoch 44/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 44, Training Loss: 5.374098175450375, Validation Loss: 7.3577 Test Loss: 17.6849
Training PCC: 0.9884791605854358, Validation PCC: 0.9769 Test PCC: 0.8174
Checkpoint saved for epoch 44


Epoch 45/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 45, Training Loss: 5.23574079337873, Validation Loss: 7.3821 Test Loss: 18.1471
Training PCC: 0.9890915977949267, Validation PCC: 0.9767 Test PCC: 0.8059
Checkpoint saved for epoch 45


Epoch 46/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 46, Training Loss: 5.387057059689572, Validation Loss: 7.3216 Test Loss: 19.0856
Training PCC: 0.988532396547999, Validation PCC: 0.9769 Test PCC: 0.8119
Checkpoint saved for epoch 46


Epoch 47/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 47, Training Loss: 5.224038851888556, Validation Loss: 7.3225 Test Loss: 19.2292
Training PCC: 0.9889566740095378, Validation PCC: 0.9774 Test PCC: 0.8042
Checkpoint saved for epoch 47


Epoch 48/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch: 48, Training Loss: 6.015235800492136, Validation Loss: 7.4854 Test Loss: 18.7264
Training PCC: 0.9842142132811339, Validation PCC: 0.9768 Test PCC: 0.8046
Checkpoint saved for epoch 48


Epoch 49/100 Training:   0%|          | 0/76 [00:00<?, ?it/s]

In [None]:
# Plot the per-epoch loss, PCC and RMSE curves recorded during training.
# Test metrics get their own figures. In the logged run, test loss (~18-21)
# sits on a different scale from train/validation loss (~5-9).

import matplotlib.pyplot as plt


def _plot_epoch_curves(curves, ylabel, title):
    """Plot one or more per-epoch metric series on a fresh figure and show it.

    Parameters
    ----------
    curves : dict[str, sequence of float]
        Maps legend label -> metric values. Assumes one value was appended
        per completed epoch (TODO confirm against the training loop).
    ylabel : str
        Y-axis label.
    title : str
        Figure title.

    Each series is plotted against its own length. If a run was interrupted
    mid-epoch and the lists differ in length, each curve still lines up with
    the epochs it actually covers.
    """
    fig, ax = plt.subplots()
    for label, values in curves.items():
        values = list(values)
        # The training log numbers epochs from 1 ("Epoch 1/100", ...).
        # Use 1-based x values so the plot matches the logged epoch numbers.
        ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    plt.show()


_plot_epoch_curves(
    {'Train Loss': train_losses, 'Validation Loss': val_losses},
    ylabel='Loss',
    title='Training and Validation Loss',
)
_plot_epoch_curves({'Test Loss': test_losses}, ylabel='Loss', title='Test Loss')

_plot_epoch_curves(
    {'Train PCC': train_pccs, 'Validation PCC': val_pccs},
    ylabel='PCC',
    title='Training and Validation PCC',
)
_plot_epoch_curves({'Test PCC': test_pccs}, ylabel='PCC', title='Test PCC')

_plot_epoch_curves(
    {'Train RMSE': train_rmses, 'Validation RMSE': val_rmses},
    ylabel='RMSE',
    title='Training and Validation RMSE',
)
_plot_epoch_curves({'Test RMSE': test_rmses}, ylabel='RMSE', title='Test RMSE')