# Finetune CSFM for Custom Tasks

This tutorial notebook demonstrates how to adapt the pretrained **CSFM model** to downstream tasks. 

Before running this notebook, ensure the following:

1. The pretrained model checkpoint (`.pth` file) is saved in:  `../pretrained/<checkpoint_file>.pth`

2. A local environment is installed and functional following `../README.md`

3. Your dataset has been converted to the required `.h5` format based on the `../datasets/prepare_data.ipynb` file and stored in:  ../datasets/<dataset_name>.h5
   

# Set up the available GPUs and helper functions


In [1]:
import os

# set the visible gpu
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from network.model import CSFM_model
import network.lr_decay as lrd
import h5py
import numpy as np
import neurokit2 as nk
import scipy
from torch.optim.lr_scheduler import LambdaLR

import random
import time
import argparse

# Set the seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Custom meter class
class Meter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    def average(self):
        return self.sum / self.count

# Function to get the learning rate schedule
def get_lr_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda)

# Cosine LR schedule with warmup
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay after warmup
        return 0.5 * (1 + np.cos(np.pi * (current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda)


# Define the hyperparameters


In [2]:
# Hyperparameters

batch_size = 16
learning_rate = 1e-4
num_epochs = 20

# Prepare Training, Validation, and Test Datasets

In this demonstration, we use the **VTaC dataset** (Ventricular Tachycardia Alarm Benchmark) for false VT alarm detection. Each data sample contains:

- One 10-s ECG signal (electrocardiogram), sampling rate = 250
- One 10-s PPG signal (photoplethysmogram), sampling rate = 250

In [3]:
class HDF5Dataset(Dataset):
    def __init__(self, h5_file_path, fs, channels):
        self.h5_file_path = h5_file_path
        self.h5_file = h5py.File(h5_file_path, 'r')
        self.signals = self.h5_file['tracings']
        self.labels = self.h5_file['labels']
        self.subject_ids = None
        self.fs = fs
        self.channels = channels
        self.retry_wait = 1
        self.max_retries = 12


    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        retries = 0
        while retries < self.max_retries:
            try:
                signal = self.signals[idx, :, :]

                if signal.shape[0] > signal.shape[1]:
                    signal = signal.T

                signal = scipy.signal.resample(signal, num=int(signal.shape[1] * 250 / self.fs), axis=1)
                for i, channel in zip(range(signal.shape[0]), self.channels):
                    if channel<12:
                        signal[i] = nk.ecg_clean(signal[i], sampling_rate=250)
                    else:
                        signal[i] = nk.ppg_clean(signal[i], sampling_rate=250)

                start_idx = random.randint(0, signal.shape[1] - 2500)
                signal_segment = signal[:, start_idx:start_idx + 2500]
                epsilon = 1e-8  # Small value to prevent division by zero
                mean = np.mean(signal_segment, axis=1, keepdims=True)
                std = np.std(signal_segment, axis=1, keepdims=True)
                signal_segment = (signal_segment - mean) / (std + epsilon)
                target = self.labels[idx].squeeze()

                return torch.tensor(signal_segment, dtype=torch.float32), torch.tensor(target, dtype=torch.long)
            except OSError as e:
                print(f"Error reading data at index {idx}: {e}. Retrying {retries + 1}/{self.max_retries}")
                retries += 1
                time.sleep(self.retry_wait)
        signal_segment = torch.randn(1, 2500)
        target = np.zeros(self.labels.shape[1])
        return torch.tensor(signal_segment, dtype=torch.float32), torch.tensor(target, dtype=torch.long)

    def close(self):
        self.h5_file.close()


# Load the dataset

train_dataset = HDF5Dataset("../datasets/vtac_train.h5", fs=250, channels=[1,12])
val_dataset = HDF5Dataset("../datasets/vtac_val.h5", fs=250, channels=[1,12])
test_dataset = HDF5Dataset("../datasets/vtac_test.h5", fs=250, channels=[1,12])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

multilabel = True

# Model initialization and load the checkpoint

In [None]:
model_type = 'Tiny'  # Options: 'CSFM-Tiny', 'CSFM-Base', 'CSFM-Large'

In [4]:
model = CSFM_model(model_type)

for i, layer in enumerate(model.transformer.layers):
    print(f"Layer {i} mlp_dim:", layer[1].net[1].out_features)

model = model.cuda()

# Load pretrained weights if available
checkpoint_path = '../pretrained/checkpoint.pth'
print('load from ', checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
encoder_state_dict = {k.replace('encoder.', ''): v for k, v in checkpoint.items() if k.startswith('encoder.') and 'mlp_head' not in k}

model.load_state_dict(encoder_state_dict, strict=False)

Layer 0 mlp_dim: 1024
Layer 1 mlp_dim: 1024
Layer 2 mlp_dim: 1024
Layer 3 mlp_dim: 1024
Layer 4 mlp_dim: 1024
Layer 5 mlp_dim: 1024
load from  ../pretrained/checkpoint.pth


_IncompatibleKeys(missing_keys=['mlp_head.weight', 'mlp_head.bias'], unexpected_keys=['dense_scratch.layer1_rn.weight', 'dense_scratch.layer2_rn.weight', 'dense_scratch.layer3_rn.weight', 'dense_scratch.layer4_rn.weight', 'dense_scratch.layer_rn.0.weight', 'dense_scratch.layer_rn.1.weight', 'dense_scratch.layer_rn.2.weight', 'dense_scratch.layer_rn.3.weight', 'dense_scratch.refinenet1.out_conv.weight', 'dense_scratch.refinenet1.out_conv.bias', 'dense_scratch.refinenet1.resConfUnit1.conv1.weight', 'dense_scratch.refinenet1.resConfUnit1.conv1.bias', 'dense_scratch.refinenet1.resConfUnit1.conv2.weight', 'dense_scratch.refinenet1.resConfUnit1.conv2.bias', 'dense_scratch.refinenet1.resConfUnit2.conv1.weight', 'dense_scratch.refinenet1.resConfUnit2.conv1.bias', 'dense_scratch.refinenet1.resConfUnit2.conv2.weight', 'dense_scratch.refinenet1.resConfUnit2.conv2.bias', 'dense_scratch.refinenet2.out_conv.weight', 'dense_scratch.refinenet2.out_conv.bias', 'dense_scratch.refinenet2.resConfUnit1.con

# Training loop

In [5]:
# Optimizer and loss function

# ASL loss function
class ASL(nn.Module):
    ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, **kwargs):
        super(ASL, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input prob (after sigmoid)
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = x
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()

param_groups = lrd.param_groups_lrd(model,
    base_lr=learning_rate,
    weight_decay=5e-2,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75
)

optimizer = torch.optim.AdamW(param_groups)
criterion = ASL().cuda()

# Total number of training steps (number of batches)
total_steps = len(train_loader) * num_epochs
# Number of warmup steps
num_warmup_steps = int(0.1 * total_steps)

# Initialize the learning rate scheduler
scheduler = get_lr_schedule_with_warmup(optimizer, num_warmup_steps, total_steps)

# Custom meter for loss tracking
train_loss_meter = Meter()
val_loss_meter = Meter()

channels = np.asarray(train_dataset.channels)

best_val_f1 = 0.0
best_model_path = None


In [6]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve, auc

# Function to calculate F1 Score for each label
def calculate_f1_per_label(predictions, targets, multilabel=False):

    if multilabel:
        predictions = (predictions > 0.5).astype(int)
    else:
        predictions = np.argmax(predictions, axis=1)
    return f1_score(targets, predictions, average=None)

# Function to calculate ROC AUC Score
def calculate_auc(predictions, targets, multilabel=False):

    if multilabel:
        return roc_auc_score(targets, predictions, average='macro')
    else:
        return roc_auc_score(targets, predictions, multi_class='ovr')
    
# Function to calculate AUPR Score
def calculate_aupr(predictions, targets, multilabel=False):
    
    precision, recall, thresholds = precision_recall_curve(targets, predictions)
    aupr = auc(recall, precision)
    return aupr

In [7]:
# Training loop

for epoch in range(num_epochs):
    model.train()
    train_loss_meter.reset()

    all_train_predictions = []
    all_train_targets = []

    for batch_idx, batch in enumerate(train_loader):
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()

        outputs = model(inputs, channels, task='cls').squeeze()
        
        if multilabel:
            outputs = torch.sigmoid(outputs)
        else:
            outputs = torch.softmax(outputs, dim=-1)

        loss = criterion(outputs, targets)
        loss.backward()

        train_loss_meter.update(loss.item(), inputs.size(0))
        optimizer.step()
        scheduler.step()  # Update the scheduler

        for i in range(len(outputs)):
            all_train_predictions.append(outputs[i].cpu().detach().numpy())
            all_train_targets.append(targets[i].cpu().detach().numpy())

        if batch_idx % 10 == 0:
            print(f'Epoch {epoch + 1}/{num_epochs} Batch {batch_idx + 1}/{len(train_loader)}: Train Loss: {loss.item():.4f}')

    train_loss = train_loss_meter.average()
    train_predictions = np.array(all_train_predictions)
    train_targets = np.array(all_train_targets)
    train_f1_per_label = calculate_f1_per_label(train_predictions, train_targets, multilabel=multilabel)

    model.eval()
    val_loss_meter.reset()
    all_val_predictions = []
    all_val_targets = []

    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.cuda()
            targets = targets.cuda()

            outputs = model(inputs, channels, task='cls').squeeze()
        
            if multilabel:
                outputs = torch.sigmoid(outputs)
            else:
                outputs = torch.softmax(outputs, dim=-1)

            loss = criterion(outputs, targets)
            val_loss_meter.update(loss.item(), inputs.size(0))
            
            for i in range(len(outputs)):
                all_val_predictions.append(outputs[i].cpu().detach().numpy())
                all_val_targets.append(targets[i].cpu().detach().numpy())

    val_loss = val_loss_meter.average()
    val_predictions = np.array(all_val_predictions)

    val_targets = np.array(all_val_targets)

    val_f1_per_label = calculate_f1_per_label(val_predictions, val_targets, multilabel=multilabel)
    val_f1_mean = np.mean(val_f1_per_label)


    print(f'Epoch {epoch + 1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train F1: {np.mean(train_f1_per_label)}'
          f'Val Loss: {val_loss:.4f}, Val F1: {np.mean(val_f1_per_label)}')

    if val_f1_mean > best_val_f1:
        best_val_f1 = val_f1_mean

        if best_model_path is not None:
            os.remove(best_model_path)  # Remove the previous best model

        best_model_path = os.path.join(f'best_model_epoch{epoch + 1}_valf1_{best_val_f1:.4f}.pth')
        best_model_state = model.state_dict()
        torch.save(best_model_state, best_model_path)
        print(f"Best model saved to {best_model_path}")

# Save the last model
last_model_path = os.path.join(f'last_model_epoch{epoch + 1}.pth')
if os.path.exists(last_model_path):
    os.remove(last_model_path)  # Remove the previous last model

torch.save(model.state_dict(), last_model_path)
print(f"Last model saved to {last_model_path}")

# Load for best model for testing
model.load_state_dict(torch.load(best_model_path))
print(f"Best model loaded from {best_model_path}")

Epoch 1/20 Batch 1/232: Train Loss: 1.0091
Epoch 1/20 Batch 11/232: Train Loss: 2.6165
Epoch 1/20 Batch 21/232: Train Loss: 2.2977
Epoch 1/20 Batch 31/232: Train Loss: 1.6573
Epoch 1/20 Batch 41/232: Train Loss: 3.0466
Epoch 1/20 Batch 51/232: Train Loss: 2.6906
Epoch 1/20 Batch 61/232: Train Loss: 2.0949
Epoch 1/20 Batch 71/232: Train Loss: 2.2861
Epoch 1/20 Batch 81/232: Train Loss: 2.7015
Epoch 1/20 Batch 91/232: Train Loss: 1.1174
Epoch 1/20 Batch 101/232: Train Loss: 1.5462
Epoch 1/20 Batch 111/232: Train Loss: 1.8139
Epoch 1/20 Batch 121/232: Train Loss: 1.5849
Epoch 1/20 Batch 131/232: Train Loss: 1.7915
Epoch 1/20 Batch 141/232: Train Loss: 1.5755
Epoch 1/20 Batch 151/232: Train Loss: 1.9806
Epoch 1/20 Batch 161/232: Train Loss: 2.3332
Epoch 1/20 Batch 171/232: Train Loss: 0.8863
Epoch 1/20 Batch 181/232: Train Loss: 1.6646
Epoch 1/20 Batch 191/232: Train Loss: 0.8447
Epoch 1/20 Batch 201/232: Train Loss: 0.8230
Epoch 1/20 Batch 211/232: Train Loss: 1.5349
Epoch 1/20 Batch 221/

  model.load_state_dict(torch.load(best_model_path))


# Load and evaluate the best finetuned checkpoint

In [10]:
from sklearn.metrics import accuracy_score, recall_score, precision_score

# Load for best model for testing
model.load_state_dict(torch.load(best_model_path, weights_only=True))
print(f"Best model loaded from {best_model_path}")

model.eval()
test_loss_meter = Meter()
all_test_predictions = []
all_test_targets = []

with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()
        
        outputs = model(inputs, channels, task='cls').squeeze()
        
        if multilabel:
            outputs = torch.sigmoid(outputs)
        else:
            outputs = torch.softmax(outputs, dim=-1)

        loss = criterion(outputs, targets)
        test_loss_meter.update(loss.item(), inputs.size(0))

        if not multilabel:
            outputs = torch.softmax(outputs, dim=-1)

        for i in range(len(outputs)):
            all_test_predictions.append(outputs[i].cpu().detach().numpy())
            all_test_targets.append(targets[i].cpu().detach().numpy())

test_loss = test_loss_meter.average()
test_predictions = np.array(all_test_predictions)
test_targets = np.array(all_test_targets)
test_f1_per_label = calculate_f1_per_label(test_predictions, test_targets, multilabel=multilabel)
test_f1_macro = np.mean(test_f1_per_label)

test_auc = calculate_auc(test_predictions, test_targets, multilabel=multilabel)
test_aupr = calculate_aupr(test_predictions, test_targets, multilabel=multilabel)


# Recall 
if multilabel:
    test_recall = recall_score(test_targets, (test_predictions > 0.5).astype(int), average='macro')
else:
    test_recall = recall_score(test_targets, np.argmax(test_predictions, axis=1), average='macro')

# Precision (for multilabel or multiclass)
if multilabel:
    test_precision = precision_score(test_targets, (test_predictions > 0.5).astype(int), average='macro')
else:
    test_precision = precision_score(test_targets, np.argmax(test_predictions, axis=1), average='macro')

# Accuracy
if multilabel:
    test_accuracy = accuracy_score(test_targets, (test_predictions > 0.5).astype(int))
else:
    test_accuracy = accuracy_score(test_targets, np.argmax(test_predictions, axis=1))

print(f'Test Loss: {test_loss:.4f}, Test F1 (macro): {test_f1_macro:.4f}, Test AUC: {test_auc:.4f}, Test AUPR: {test_aupr:.4f}, '
      f'Test Recall: {test_recall:.4f}, Test Precision: {test_precision:.4f}, Test Accuracy: {test_accuracy:.4f}')

Best model loaded from best_model_epoch13_valf1_0.8981.pth
Test Loss: 0.7863, Test F1 (macro): 0.8830, Test AUC: 0.9664, Test AUPR: 0.9106, Test Recall: 0.9067, Test Precision: 0.8671, Test Accuracy: 0.9025
