In [13]:
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import wfdb
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import neurokit2 as nk
from scipy.signal import resample, medfilt
import pywt
import pickle
import time
import os
import csv
import mne
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from zoneinfo import ZoneInfo

In [14]:
def moving_average(signal, window_size=10):
    return np.convolve(signal, np.ones(window_size)/window_size, mode='same')
def preprocessing_data(data):
    band_passed_ecg =  nk.signal_filter(data, sampling_rate=360, lowcut=0.1, highcut=100, method='butterworth_zi', order = 2)
    emg = moving_average(band_passed_ecg, window_size=10)
    # Step 4: Downsample the filtered ECG signal 
    downsampled_ecg = nk.signal_resample(emg, sampling_rate=360, desired_sampling_rate=100)
    return downsampled_ecg

In [15]:
def load_record(record_id, seq_len=1000, stride=1000):
    record_name = str(record_id)
    if not os.path.exists(f"mit_bih/{record_name}.dat"):
        wfdb.dl_database('mitdb', './', records=[record_name])
    record = wfdb.rdrecord('mit_bih/' + record_name)
    
    #Extract each channel
    signal_data = record.p_signal
    ch1 = preprocessing_data(signal_data[:, 0]).reshape(-1, 1)
    ch2 = preprocessing_data(signal_data[:, 1]).reshape(-1, 1)

    #Scale the signal
    scaler = MinMaxScaler(feature_range=(-0.5, 0.5))
    ch1_scale = scaler.fit_transform(ch1)
    ch2_scale = scaler.fit_transform(ch2)
    
    # Create sequences using sliding window
    seq_ch1 = [ch1_scale[i:i+seq_len] for i in range(0, len(ch1_scale)-seq_len+1, stride)]
    seq_ch2 = [ch2_scale[i:i+seq_len] for i in range(0, len(ch2_scale)-seq_len+1, stride)]
    seq_ch1.extend(seq_ch2)
    return seq_ch1

def load_multiple_records(record_ranges, seq_len=1000, stride_len=100):
    all_data = []
    for start, end in tqdm(record_ranges):
        for record_id in range(start, end + 1):
            record_data = load_record(record_id, seq_len, stride_len)
            all_data.extend(record_data)
    # Save to pkl
    with open('database.pkl', 'wb') as f:
        pickle.dump(all_data, f)
    
    return all_data

In [16]:
def downsample(data, original_rate, target_rate):
    num_samples = int(len(data) * target_rate / original_rate)
    return resample(data, num_samples)

def calculate_l1(predictions, targets):
    l1_norm = torch.norm(predictions-targets, p=1)
    l1_sum = torch.sum(torch.abs(targets))
    accuracy = 100 * (1 - (l1_norm / l1_sum))
    return accuracy

def calculate_l2(predictions, targets):
    l2_norm = torch.norm(predictions - targets, p=2)
    l2_sum = torch.norm(targets, p=2)
    accuracy = 100 * (1 - (l2_norm / l2_sum))
    return accuracy

def calculate_MAPE(predictions, targets):
    mape = torch.mean(torch.abs((targets - predictions) / (targets)))
    return mape

In [17]:
def fixed_position_mask_peaks(rpeaks, seq, mask_length, num_peaks_to_mask):
    mask = np.zeros_like(seq, dtype=bool)  # Create a mask with the same length as seq, initially all False

    if len(rpeaks) > 0:  # If there are detected peaks
        # Randomly select 'num_peaks_to_mask' peaks from rpeaks
        selected_peaks = rpeaks[1:2*num_peaks_to_mask + 1: 2]

        for peak in selected_peaks:
            mask_half = mask_length // 2
            start = int(max(0, peak - mask_half))
            end = int(min(len(seq), peak + mask_half))
            mask[start:end] = True
    
    return mask

In [18]:
class ECGDataset(Dataset):
    def __init__(self, data, seq_len=1000, mask_length=60, target_rate=100,num_peaks_to_mask=1):
        """
        Args:
            data (np.array): Full ECG signal data.
            seq_len (int): Length of each sequence before downsampling.
            mask_length (int): Length of masking window around the first R-peak.
            original_rate (int): Original sampling rate of the data (default: 360 Hz).
            target_rate (int): Target sampling rate after downsampling (default: 100 Hz).
        """
        self.data = data
        self.seq_len = seq_len
        self.mask_length = mask_length
        self.target_rate = target_rate
        self.num_peaks_to_mask = num_peaks_to_mask
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Adjust index for downsampled sequences
        seq = self.data[idx].flatten()
        # Detect R-peaks in the sequence
        rpeaks = nk.ecg_findpeaks(seq, sampling_rate=self.target_rate)['ECG_R_Peaks']

        # Create masking array
        mask = np.zeros_like(seq, dtype=bool)
        mask = fixed_position_mask_peaks(rpeaks, seq, self.mask_length, self.num_peaks_to_mask)
        
        masked_seq = seq.copy()
        masked_seq[mask] = 0  # Apply masking

        mask = torch.tensor(mask, dtype=torch.bool)
        return torch.tensor(masked_seq, dtype=torch.float32), torch.tensor(seq, dtype=torch.float32), mask

In [25]:
class Conv_Block(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(embed_dim, embed_dim * 2, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(embed_dim * 2, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Input: (B, L) → (B, 1, L)
        x = x.unsqueeze(1)
        
        x = F.leaky_relu(self.conv1(x), negative_slope=0.01)
        x = self.pool1(x)
        x = F.leaky_relu(self.conv2(x), negative_slope=0.01)
        x = self.pool2(x)
        x = F.leaky_relu(self.conv3(x), negative_slope=0.01)
        return x

class Encoder1D_Mask(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.ConvBlock = Conv_Block(embed_dim, hidden_dim)
        self.bilstm = nn.LSTM(hidden_dim, lstm_hidden_dim,
                              bidirectional=True, batch_first=True)

    def forward(self, x):
        # Encode each branch
        x = self.ConvBlock(x)
        x = x.permute(0, 2, 1)
        x, _ = self.bilstm(x)
        x = x.permute(0, 2, 1)  # (B, 2*lstm_hidden_dim, L/4)
        return x


class Decoder1D(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.deconv1 = nn.Conv1d(lstm_hidden_dim * 2, embed_dim * 2,
                                 kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.Conv1d(embed_dim * 2, embed_dim,
                                 kernel_size=3, stride=1, padding=1)
        self.deconv3 = nn.Conv1d(embed_dim, 1,
                                 kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.leaky_relu(self.deconv1(x), negative_slope=0.01)
        x = F.interpolate(x, scale_factor=2, mode='linear', align_corners=True)

        x = F.leaky_relu(self.deconv2(x), negative_slope=0.01)
        x = F.interpolate(x, scale_factor=2, mode='linear', align_corners=True)

        x = self.deconv3(x)  # (B, 1, L)
        return x  # keep (B, 1, L)


class MAE1D_Mask(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.encoder = Encoder1D_Mask(embed_dim, hidden_dim, lstm_hidden_dim)
        self.decoder = Decoder1D(embed_dim, hidden_dim, lstm_hidden_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x.squeeze(1)

In [20]:
# Data preparation
def prepare_data(data, seq_len, num_rpeaks,batch_size=128):
    dataset = ECGDataset(data, seq_len, num_peaks_to_mask = num_rpeaks)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [28]:
class Trainer:
    def __init__(self, model, criterion, optimizer, seq_len=5000, num_rpeak=1, log_dir=None, test_case = True):
        self.seq_len = seq_len
        self.num_rpeak = num_rpeak
        self.device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
        self.model = model.to(self.device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_losses = []
        self.val_losses = []
        self.log_dir = log_dir
        self.test_case = test_case
        self.cur_time = datetime.now(ZoneInfo("Australia/Sydney")).strftime("%H%M_%d%m%Y")
        if self.log_dir is None:
            self.log_dir = f'runs/seq{seq_len}_rpeak{num_rpeak}_{self.cur_time}'
        if self.test_case == False:
            self.writer = SummaryWriter(log_dir=self.log_dir)

    def run(self, train_loader, val_loader, test_loader, epochs=100, patience = 10):
        best_val_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            for masked_seq, original_seq, mask in train_loader:
                masked_seq, original_seq, mask = masked_seq.to(self.device), original_seq.to(self.device), mask.to(self.device)

                self.optimizer.zero_grad()
                reconstructed = self.model(masked_seq)

                loss_global = self.criterion(reconstructed, original_seq)
                loss_masked = self.criterion(reconstructed[mask], original_seq[mask])
                loss = 0.9 * loss_masked + 0.1 * loss_global

                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

            avg_train_loss = train_loss / len(train_loader)
            self.train_losses.append(avg_train_loss)

            avg_val_loss, avg_l1, avg_l2, avg_MAPE = self.evaluate(epoch=epoch, data_loader = val_loader)
            self.val_losses.append(avg_val_loss)

            print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
                  f"L1 Accuracy: {avg_l1:.2f}, L2 Accuracy: {avg_l2:.2f}, MAPE Accuracy: {avg_MAPE:.2f}")
            if self.test_case == False:
                self.writer.add_scalar("Loss/Train", avg_train_loss, epoch)
                self.writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
                self.writer.add_scalar("Accuracy/L1", avg_l1, epoch)
                self.writer.add_scalar("Accuracy/L2", avg_l2, epoch)
                self.writer.add_scalar("Accuracy/MAPE", avg_MAPE, epoch)

            if epoch % 10 == 0:
                print(f"{self.log_dir}/model_{self.seq_len}_{self.num_rpeak}_{self.cur_time}.pth")
                if self.test_case == False:
                    print("Saving the best model checkpoint...")
                    torch.save(self.model.state_dict(), f"{self.log_dir}/model_{self.seq_len}_{self.num_rpeak}_{self.cur_time}.pth")

        self.writer.close()
        if test_loader is None:
            self.final_evaluation(val_loader)
        else: 
            self.final_evaluation(test_loader)
            
    def evaluate(self, data_loader, epoch= None):
        self.model.eval()
        val_loss = 0.0
        l1_accuracy = 0.0
        l2_accuracy = 0.0
        MAPE_accuracy = 0.0

        with torch.no_grad():
            for masked_seq, original_seq, mask in data_loader:
                masked_seq, original_seq, mask = masked_seq.to(self.device), original_seq.to(self.device), mask.to(self.device)
                reconstructed = self.model(masked_seq)
                masked_reconstructed = reconstructed[mask]
                masked_original = original_seq[mask]
                loss = self.criterion(masked_reconstructed, masked_original)
                val_loss += loss.item()
                l1_accuracy += calculate_l1(masked_reconstructed, masked_original)
                l2_accuracy += calculate_l2(masked_reconstructed, masked_original)
                MAPE_accuracy += calculate_MAPE(masked_reconstructed, masked_original)

        avg_val_loss = val_loss / len(data_loader)
        avg_l1 = l1_accuracy / len(data_loader)
        avg_l2 = l2_accuracy / len(data_loader)
        avg_MAPE = MAPE_accuracy / len(data_loader)

        if epoch is None:
            print(f"\nValidation Loss: {avg_val_loss:.4f}")
            print(f"L1 Accuracy: {avg_l1:.4f}")
            print(f"L2 Accuracy: {avg_l2:.4f}")
            print(f"MAPE Accuracy: {avg_MAPE:.2f}%")

        return avg_val_loss, avg_l1, avg_l2, avg_MAPE
    def plot_predictions_with_residuals(self, data_loader, sample_len = None, num_samples=2 ):
        self.model.eval()
        all_actual, all_predicted, all_masks = [], [], []
        if sample_len is None:
            sample_len = self.seq_len
        
        with torch.no_grad():
            for masked_seq, original_seq, mask in data_loader:
                masked_seq, original_seq = masked_seq.to(self.device), original_seq.to(self.device)
                reconstructed = self.model(masked_seq)
                all_actual.extend(original_seq.cpu().numpy())  # original (unmasked) ECG
                all_predicted.extend(reconstructed.cpu().numpy())  # predictions
                all_masks.extend(mask.cpu().numpy()) # Masked regions
     
                break  # Take only one batch for plotting
        
        # Flatten the lists for metric calculation
        all_actual_flat = np.concatenate(all_actual)
        all_predicted_flat = np.concatenate(all_predicted)
        
        # Calculate metrics
        mse = mean_squared_error(all_actual_flat, all_predicted_flat)
        mae = mean_absolute_error(all_actual_flat, all_predicted_flat)
        r2 = r2_score(all_actual_flat, all_predicted_flat)
    
        print(f"Mean Squared Error (MSE): {mse:.4f}")
        print(f"Mean Absolute Error (MAE): {mae:.4f}")
        print(f"R-squared (R²): {r2:.4f}")
        
        # Visualize predictions vs. actual data with residuals
        plt.figure(figsize=(15, num_samples * 5))
        for i in range(min(num_samples, len(all_actual))):
            original = all_actual[i]
            predicted = all_predicted[i]
            mask = all_masks[i]
    
            # Plot predictions vs actual data
            plt.subplot(num_samples, 2, i * 2 + 1)
            plt.plot(original[:sample_len], label="Original", color="blue", alpha=0.7)
            plt.plot(predicted[:sample_len], label="Predicted", color="orange", alpha=0.7)
    
            # Highlight masked regions
            masked_indices = np.where(mask)[0]
            plt.scatter(masked_indices, original[masked_indices], color="red", label="Masked Regions", alpha=0.7)
    
            plt.title(f"Sample {i + 1} - Predictions")
            plt.legend()
    
            # Compute and plot residuals
            residuals = np.array(original) - np.array(predicted)
            plt.subplot(num_samples, 2, i * 2 + 2)
            plt.plot(residuals[:sample_len], label="Residuals", color="green")
            plt.axhline(0, color="black", linestyle="--")
            plt.title(f"Sample {i + 1} - Residuals")
            plt.legend()
    
        plt.tight_layout()
        plt.show()

    def plot_loss_curves(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label="Train Loss", marker='o')
        plt.plot(self.val_losses, label="Validation Loss", marker='o')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss Curves")
        plt.legend()
        plt.grid()
        plt.show()
        
    def final_evaluation(self, data_loader):
        avg_val_loss, avg_l1, avg_l2, avg_MAPE = self.evaluate(data_loader, epoch= None)
        with open('runs/Model_Running.csv', 'a') as f_object:
            writer_object = csv.writer(f_object)
            writer_object.writerow([f'model_{self.seq_len}_{self.num_rpeak}_{self.cur_time}.pth', avg_l1.cpu().numpy(), avg_l2.cpu().numpy(), avg_MAPE.cpu().numpy(), self.seq_len, self.num_rpeak])
            f_object.close()  
        self.plot_predictions_with_residuals(data_loader = data_loader, sample_len = 1000)
        self.plot_loss_curves()

In [None]:
record_ranges = [(100,109),(111, 119),(121,124),(200,203),(205,205),(207,210),(212,215),(217,217),(219,223),(228,228),(230,232)]
print("Working with data")
seq_len = 1000
stride = 1000
data = np.array(load_multiple_records(record_ranges, seq_len, stride))
test_data = load_multiple_records([(233,234)], seq_len, stride)

for num in [5]:
    print("="*20)
    print(f"Training the model with {num} R peaks")
    print("="*20)
    
    print("Preparing the data")
    train_loader, val_loader = prepare_data(data,num_rpeaks = num, seq_len=seq_len)
    test_dataset = ECGDataset(test_data, seq_len, num_peaks_to_mask = num)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    
    print("Training the model")
    criterion = nn.MSELoss()
    mae_model = MAE1D_Mask()
    optimizer = optim.Adam(mae_model.parameters(), lr=0.001)    

    model_trainer = Trainer(model = mae_model, criterion = criterion, optimizer = optimizer,  seq_len=seq_len, num_rpeak=num, test_case = False)
    model_trainer.run(train_loader, val_loader, test_loader, epochs = 200)

Working with data


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:03<00:00,  2.80it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.94it/s]


Training the model with 5 R peaks
Preparing the data
Training the model
Epoch 0, Train Loss: 0.0100, Val Loss: 0.0062, L1 Accuracy: 51.74, L2 Accuracy: 35.99, MAPE Accuracy: 3.04
runs/seq1000_rpeak5_1208_08092025/model_1000_5_1208_08092025.pth
Saving the best model checkpoint...
Epoch 1, Train Loss: 0.0056, Val Loss: 0.0056, L1 Accuracy: 53.28, L2 Accuracy: 39.03, MAPE Accuracy: 2.92
Epoch 2, Train Loss: 0.0051, Val Loss: 0.0051, L1 Accuracy: 55.33, L2 Accuracy: 41.90, MAPE Accuracy: 3.23
Epoch 3, Train Loss: 0.0045, Val Loss: 0.0046, L1 Accuracy: 55.42, L2 Accuracy: 45.00, MAPE Accuracy: 2.72
Epoch 4, Train Loss: 0.0038, Val Loss: 0.0037, L1 Accuracy: 61.55, L2 Accuracy: 50.71, MAPE Accuracy: 3.24
Epoch 5, Train Loss: 0.0033, Val Loss: 0.0032, L1 Accuracy: 63.87, L2 Accuracy: 53.84, MAPE Accuracy: 3.11
Epoch 6, Train Loss: 0.0030, Val Loss: 0.0029, L1 Accuracy: 64.67, L2 Accuracy: 56.11, MAPE Accuracy: 3.44
Epoch 7, Train Loss: 0.0026, Val Loss: 0.0026, L1 Accuracy: 67.37, L2 Accuracy

In [None]:
class Conv_Block(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(embed_dim, embed_dim * 2, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(embed_dim * 2, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x,mask):
        # Input: (B, L) → (B, 1, L)
        x = x.unsqueeze(1)
        mask1 = self.pool1(mask)
        mask2 = self.pool2(mask)
        
        x = F.leaky_relu(self.conv1(x), negative_slope=0.01)
        x = self.pool1(x)
        x = x * mask1
        x = F.leaky_relu(self.conv2(x), negative_slope=0.01)
        x = self.pool2(x)
        x = x * mask2
        x = F.leaky_relu(self.conv3(x), negative_slope=0.01)
        return x

class Encoder1D_Mask(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.ConvBlock = Conv_Block(embed_dim, hidden_dim)
        self.bilstm = nn.LSTM(hidden_dim, lstm_hidden_dim,
                              bidirectional=True, batch_first=True)

    def forward(self, x,mask):
        # Encode each branch
        x = self.ConvBlock(x,mask)
        x = x.permute(0, 2, 1)
        x, _ = self.bilstm(x)
        x = x.permute(0, 2, 1)  # (B, 2*lstm_hidden_dim, L/4)
        return x


class Decoder1D(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.deconv1 = nn.Conv1d(lstm_hidden_dim * 2, embed_dim * 2,
                                 kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.Conv1d(embed_dim * 2, embed_dim,
                                 kernel_size=3, stride=1, padding=1)
        self.deconv3 = nn.Conv1d(embed_dim, 1,
                                 kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.leaky_relu(self.deconv1(x), negative_slope=0.01)
        x = F.interpolate(x, scale_factor=2, mode='linear', align_corners=True)

        x = F.leaky_relu(self.deconv2(x), negative_slope=0.01)
        x = F.interpolate(x, scale_factor=2, mode='linear', align_corners=True)

        x = self.deconv3(x)  # (B, 1, L)
        return x  # keep (B, 1, L)


class MAE1D_Mask(nn.Module):
    def __init__(self, embed_dim=32, hidden_dim=64, lstm_hidden_dim=128):
        super().__init__()
        self.encoder = Encoder1D_Mask(embed_dim, hidden_dim, lstm_hidden_dim)
        self.decoder = Decoder1D(embed_dim, hidden_dim, lstm_hidden_dim)

    def forward(self, x,mask):
        x = self.encoder(x,mask)
        x = self.decoder(x)
        return x.squeeze(1)

# Fine- truning

In [8]:
from torch.utils.data import DataLoader, TensorDataset
from utlis.finetune_wesad_data import load_pre_wesad_dataset, process_wesad_dataset
from stress_downstream import *

In [9]:
class DownstreamClassifier(nn.Module):
    def __init__(self, encoder, num_classes=2):
        super(DownstreamClassifier, self).__init__()
        self.encoder = encoder

        # Freeze encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        # Classifier head after GAP
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),  # Adjust 256 to match your encoder's output channels
            nn.LeakyReLU(negative_slope = 0.01),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = x.squeeze(1)         # (B, L)

        # Pass through encoder
        z = self.encoder(x)  # -> (B, C, L)

        # Global Average Pooling over sequence dimension
        z = F.adaptive_avg_pool1d(z, 1)  # -> (B, C, 1)
        z = z.squeeze(-1)                # -> (B, C)

        logits = self.classifier(z)      # -> (B, num_classes)
        return logits

In [10]:
X_train, X_test, y_train, y_test = load_pre_wesad_dataset('/home/van/NamQuang/SSLModel/WESAD_700_10_10', 'S2.pkl')
train_dataloader = ECGClassificationDataset(X_train, y_train)
test_dataloader = ECGClassificationDataset(X_test, y_test)

train_size = int(0.8 * len(train_dataloader))
val_size = len(train_dataloader) - train_size
train_dataset, val_dataset = random_split(train_dataloader, [train_size, val_size])

train_loader_wesad = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader_wesad = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader_wesad = DataLoader(test_dataloader, batch_size=128, shuffle=False)

In [28]:
# Load model and pretrained weights
BASE_PATH = '/home/van/NamQuang/'
FILE_MODEL = 'SSLModel/runs/seq1000_rpeak5_1126_08092025/model_1000_5_1126_08092025.pth'
auto_model = MAE1D_Mask()
auto_model.load_state_dict(torch.load(BASE_PATH+FILE_MODEL, weights_only=True))
frozen_encoder = auto_model.encoder
model = DownstreamClassifier(frozen_encoder)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-4)
train_evaluate_downstream_classifier(
    model = model,
    criterion = criterion, 
    optimizer = optimizer, 
    train_loader = train_loader_wesad,
    val_loader = val_loader_wesad, 
    test_loader = test_loader_wesad, num_epochs=100,device = 'cuda')

Epoch 0: Train Loss = 0.5089, Train Acc = 0.7573, Train F1 = 0.4743 | Val Loss = 0.4394, Val Acc = 0.8114, Val F1 = 0.5780
Epoch 1: Train Loss = 0.3702, Train Acc = 0.8507, Train F1 = 0.7066 | Val Loss = 0.3503, Val Acc = 0.8609, Val F1 = 0.7334
Epoch 2: Train Loss = 0.3069, Train Acc = 0.8736, Train F1 = 0.7741 | Val Loss = 0.3130, Val Acc = 0.8730, Val F1 = 0.7762
Epoch 3: Train Loss = 0.2785, Train Acc = 0.8854, Train F1 = 0.8067 | Val Loss = 0.2948, Val Acc = 0.8803, Val F1 = 0.7890
Epoch 4: Train Loss = 0.2633, Train Acc = 0.8918, Train F1 = 0.8212 | Val Loss = 0.2805, Val Acc = 0.8851, Val F1 = 0.8071
Epoch 5: Train Loss = 0.2476, Train Acc = 0.9021, Train F1 = 0.8416 | Val Loss = 0.2773, Val Acc = 0.8888, Val F1 = 0.8066
Epoch 6: Train Loss = 0.2415, Train Acc = 0.9011, Train F1 = 0.8412 | Val Loss = 0.2612, Val Acc = 0.8924, Val F1 = 0.8183
Epoch 7: Train Loss = 0.2289, Train Acc = 0.9075, Train F1 = 0.8516 | Val Loss = 0.2543, Val Acc = 0.9008, Val F1 = 0.8322
Epoch 8: Train L

(0.4666666666666667,
 0.4666666666666667,
 0.6619047619047619,
 0.7746478873239436)

In [10]:
loso_training('/home/van/NamQuang/SSLModel/WESAD_700_10_10', 
              model_path = 'runs/seq1000_rpeak5_1126_08092025/model_1000_5_1126_08092025.pth', 
              test_size = 1,
              filename = 'test_loso.csv')

***** Loop 0: ['S11.pkl'] *****
Epoch 0: Train Loss = 0.5385, Train Acc = 0.7562, Train F1 = 0.4698 | Val Loss = 0.4602, Val Acc = 0.7855, Val F1 = 0.4807
Epoch 1: Train Loss = 0.4193, Train Acc = 0.8175, Train F1 = 0.5951 | Val Loss = 0.3634, Val Acc = 0.8691, Val F1 = 0.7550
Epoch 2: Train Loss = 0.3520, Train Acc = 0.8575, Train F1 = 0.7261 | Val Loss = 0.3196, Val Acc = 0.8909, Val F1 = 0.8107
Epoch 3: Train Loss = 0.3192, Train Acc = 0.8720, Train F1 = 0.7762 | Val Loss = 0.2982, Val Acc = 0.8945, Val F1 = 0.8195
Epoch 4: Train Loss = 0.3008, Train Acc = 0.8787, Train F1 = 0.7902 | Val Loss = 0.2931, Val Acc = 0.8873, Val F1 = 0.8243
Epoch 5: Train Loss = 0.2943, Train Acc = 0.8820, Train F1 = 0.8012 | Val Loss = 0.2739, Val Acc = 0.9067, Val F1 = 0.8444
Epoch 6: Train Loss = 0.2830, Train Acc = 0.8869, Train F1 = 0.8106 | Val Loss = 0.2702, Val Acc = 0.8970, Val F1 = 0.8356
Epoch 7: Train Loss = 0.2734, Train Acc = 0.8902, Train F1 = 0.8200 | Val Loss = 0.2649, Val Acc = 0.8982, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 0: Train Loss = 0.5358, Train Acc = 0.7608, Train F1 = 0.4591 | Val Loss = 0.4831, Val Acc = 0.7600, Val F1 = 0.4318
Epoch 1: Train Loss = 0.4245, Train Acc = 0.8093, Train F1 = 0.5541 | Val Loss = 0.3883, Val Acc = 0.8291, Val F1 = 0.6751
Epoch 2: Train Loss = 0.3460, Train Acc = 0.8617, Train F1 = 0.7277 | Val Loss = 0.3219, Val Acc = 0.8606, Val F1 = 0.7575
Epoch 3: Train Loss = 0.3053, Train Acc = 0.8842, Train F1 = 0.7919 | Val Loss = 0.2874, Val Acc = 0.8752, Val F1 = 0.7939
Epoch 4: Train Loss = 0.2791, Train Acc = 0.8887, Train F1 = 0.8045 | Val Loss = 0.2647, Val Acc = 0.8824, Val F1 = 0.8124
Epoch 5: Train Loss = 0.2634, Train Acc = 0.8933, Train F1 = 0.8171 | Val Loss = 0.2578, Val Acc = 0.8836, Val F1 = 0.8064
Epoch 6: Train Loss = 0.2557, Train Acc = 0.8984, Train F1 = 0.8261 | Val Loss = 0.2376, Val Acc = 0.8970, Val F1 = 0.8493
Epoch 7: Train Loss = 0.2404, Train Acc = 0.9021, Train F1 = 0.8375 | Val Loss = 0.2251, Val Acc = 0.9006, Val F1 = 0.8494
Epoch 8: Train L

KeyboardInterrupt: 