### import the dependencies

In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from torchdiffeq import odeint

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

DATA_PATH = os.path.join("..", "Dataset", "raw_RRI_segments.csv")
OUT_MODEL = os.path.join("..", "Two_Class_Models", "saved_models", "PSR_hybrid_two_class_best.pth")

### Load PSR dataset and preprocess

In [2]:
# Load CSV
df = pd.read_csv(DATA_PATH)
print("CSV loaded. Shape:", df.shape)

# Keep only classes 0 (SR) and 2 (AF)
mask_binary = df["label"].isin([0, 2])
df = df[mask_binary].copy()

# Remap 0->0, 2->1 for binary classification
df["label"] = df["label"].map({0: 0, 2: 1})

print("Filtered classes 0 & 2. Shape:", df.shape)
print("Class distribution:\n", df["label"].value_counts())



CSV loaded. Shape: (14205, 55)
Filtered classes 0 & 2. Shape: (9497, 55)
Class distribution:
 label
0    4750
1    4747
Name: count, dtype: int64


### PSR

In [3]:
def phase_space_reconstruct(x, m=3, tau=1):
    """
    x: 1D array of RRI
    m: embedding dimension
    tau: time delay
    Returns flattened PSR embedding: x(t), x(t+tau), ..., x(t+(m-1)*tau)
    """
    x = np.asarray(x)
    N = len(x)
    if N < (m-1)*tau + 1:
        # pad with zeros if too short
        x = np.pad(x, (0, (m-1)*tau + 1 - N), 'constant')
        N = len(x)
    psr_vectors = [x[i:N-(m-1)*tau + i] for i in range(m)]
    psr_flat = np.column_stack(psr_vectors).flatten()
    return psr_flat

# Pick RRI columns
rri_cols = [f"r_{i}" for i in range(50)]

# Compute PSR for all rows with m=3, tau=2
psr_features = df[rri_cols].apply(lambda row: phase_space_reconstruct(row.values, m=3, tau=2), axis=1)
psr_features = np.stack(psr_features.values)  # shape: (num_samples, m*(N-(m-1)*tau))

# Add PSR columns to DataFrame
num_psr_cols = psr_features.shape[1]
psr_col_names = [f"psr_{i}" for i in range(num_psr_cols)]
df_psr = pd.DataFrame(psr_features, columns=psr_col_names, index=df.index)
df = pd.concat([df, df_psr], axis=1)

print("PSR features added (m=3, tau=2). DataFrame shape:", df.shape)




PSR features added (m=3, tau=2). DataFrame shape: (9497, 193)


### Extract features and labels

In [4]:
feature_cols = psr_col_names
meta_cols = ["patient_id", "record_id", "label", "label_str", "orig_len"]

# Feature matrix X and labels y
X = df[feature_cols].values.astype(np.float32)
y = df["label"].values.astype(int)

# Drop rows with NaN or inf
mask_good = np.isfinite(X).all(axis=1)
X = X[mask_good]
y = y[mask_good]

print("Loaded samples:", X.shape, "class counts:", np.bincount(y))


The history saving thread hit an unexpected error (OperationalError('database is locked')).History will not be written to the database.
Loaded samples: (9497, 138) class counts: [4750 4747]


### Train / val / test split, scaling and SMOTE (apply SMOTE only to training set)

In [5]:
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_SEED
)

X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.2, stratify=y_temp, random_state=RANDOM_SEED
)

print("Splits -> train:", X_train.shape, "val:", X_val.shape, "test:", X_test.shape)

# Standardize using training stats
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# SMOTE on training only
smote = SMOTE(random_state=RANDOM_SEED)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)
print("After SMOTE train distribution:", np.bincount(y_train_res))

Splits -> train: (6077, 138) val: (1520, 138) test: (1900, 138)
After SMOTE train distribution: [3040 3040]


### Compute class weights

In [6]:
# compute class weights from resampled training set (used as alpha for focal loss)
cw = compute_class_weight(class_weight='balanced', classes=np.unique(y_train_res), y=y_train_res)
class_weights_tensor = torch.tensor(cw, dtype=torch.float32)
print("Class weights:", cw)

Class weights: [1. 1.]


### Compute focal loss

In [7]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

### Define model class

In [8]:
# --- Model definitions (Hybrid NODE + Attention) ---
class ODEFunc(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )
    def forward(self, t, x):
        return self.net(x)

class SelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.scale = dim ** 0.5
    def forward(self, x):
        # x: [batch, dim] -> [batch, 1, dim]
        x1 = x.unsqueeze(1)
        Q = self.query(x1)
        K = self.key(x1)
        V = self.value(x1)
        scores = torch.softmax(torch.bmm(Q, K.transpose(1,2)) / self.scale, dim=-1)
        out = torch.bmm(scores, V)  # [batch,1,dim]
        return out.squeeze(1)

class HybridNODEAttentionModel(nn.Module):
    def __init__(self, dim, num_classes):
        super().__init__()
        self.odefunc = ODEFunc(dim)
        self.attn = SelfAttention(dim)
        self.classifier = nn.Sequential(
            nn.Linear(dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    def forward(self, x):
        # x: [batch, dim]
        t = torch.tensor([0.0, 1.0], dtype=x.dtype, device=x.device)
        ode_out = odeint(self.odefunc, x, t)[-1]
        attn_out = self.attn(ode_out)
        return self.classifier(attn_out)


### Prepare DataLoaders

In [9]:
BATCH_SIZE = 32
train_ds = TensorDataset(torch.from_numpy(X_train_res.astype(np.float32)), torch.from_numpy(y_train_res.astype(np.int64)))
val_ds = TensorDataset(torch.from_numpy(X_val.astype(np.float32)), torch.from_numpy(y_val.astype(np.int64)))
test_ds = TensorDataset(torch.from_numpy(X_test.astype(np.float32)), torch.from_numpy(y_test.astype(np.int64)))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = X_train_res.shape[1]
num_classes = int(np.unique(y).size)

### Train model (hyperparameter tuning)

In [10]:
import itertools
from sklearn.metrics import f1_score

def train_with_hyperparameter_tuning(model_class, input_dim, num_classes,
                                     train_dataset, val_dataset,
                                     class_weights_tensor,
                                     param_grid,
                                     save_path=OUT_MODEL,
                                     epochs=8,
                                     patience=2,
                                     device=None):
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    class_weights_tensor = class_weights_tensor.to(device)

    best_f1 = 0.0
    best_params = None
    best_state = None

    combos = list(itertools.product(*param_grid.values()))
    print(f"Total combinations: {len(combos)}")

    for combo in combos:
        params = dict(zip(param_grid.keys(), combo))
        print("\nTrying params:", params)

        # recreate loaders with chosen batch size
        batch_size = params["batch_size"]
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        model = model_class(dim=input_dim, num_classes=num_classes).to(device)
        if params["optimizer"] == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
        elif params["optimizer"] == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=params["lr"], momentum=0.9)
        else:
            raise ValueError("Unsupported optimizer")

        criterion = FocalLoss(alpha=class_weights_tensor, gamma=params.get("gamma", 2.0))
        epochs_no_improve = 0
        local_best = 0.0

        for epoch in range(1, epochs + 1):
            model.train()
            running_loss = 0.0
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                out = model(xb)
                loss = criterion(out, yb)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * xb.size(0)
            train_loss = running_loss / len(train_loader.dataset)

            # validation
            model.eval()
            preds, trues = [], []
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    out = model(xb)
                    _, p = torch.max(out, dim=1)
                    preds.extend(p.cpu().numpy()); trues.extend(yb.cpu().numpy())

            val_f1 = f1_score(trues, preds, average="weighted")
            print(f"  epoch {epoch} train_loss: {train_loss:.4f} val_f1: {val_f1:.4f}")

            if val_f1 > local_best:
                local_best = val_f1
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print("  early stopping")
                    break

        # end epochs for this combo
        if local_best > best_f1:
            best_f1 = local_best
            best_params = params
            best_state = model.state_dict().copy()
            torch.save(best_state, save_path)
            print("  -> New global best. Saved model.")

    print(f"\nGlobal best val F1: {best_f1:.4f} params: {best_params}")
    return best_params, best_f1, best_state

# Build datasets (reuse existing X_train_res, y_train_res, X_val, y_val)
train_dataset = TensorDataset(torch.from_numpy(X_train_res.astype(np.float32)), torch.from_numpy(y_train_res.astype(np.int64)))
val_dataset = TensorDataset(torch.from_numpy(X_val.astype(np.float32)), torch.from_numpy(y_val.astype(np.int64)))
test_dataset = TensorDataset(torch.from_numpy(X_test.astype(np.float32)), torch.from_numpy(y_test.astype(np.int64)))

input_dim = X_train_res.shape[1]
num_classes = int(np.unique(y).size)

# Recommended small grid to start â€” tune and then expand if needed
param_grid = {
    "lr": [1e-3, 5e-4],
    "batch_size": [16, 32],
    "optimizer": ["adam", "sgd"]
}

best_params, best_f1, best_state = train_with_hyperparameter_tuning(
    model_class=HybridNODEAttentionModel,
    input_dim=input_dim,
    num_classes=num_classes,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    class_weights_tensor=class_weights_tensor,
    param_grid=param_grid,
    save_path=OUT_MODEL,
    epochs=8,
    patience=2,
    device=device
)

print("Best params:", best_params, "Best val F1:", best_f1)


Total combinations: 8

Trying params: {'lr': 0.001, 'batch_size': 16, 'optimizer': 'adam'}
  epoch 1 train_loss: 0.1008 val_f1: 0.8513
  epoch 2 train_loss: 0.0941 val_f1: 0.8463
  epoch 3 train_loss: 0.0891 val_f1: 0.8539
  epoch 4 train_loss: 0.0867 val_f1: 0.8612
  epoch 5 train_loss: 0.0838 val_f1: 0.8710
  epoch 6 train_loss: 0.0799 val_f1: 0.8674
  epoch 7 train_loss: 0.0767 val_f1: 0.8767
  epoch 8 train_loss: 0.0729 val_f1: 0.8737
  -> New global best. Saved model.

Trying params: {'lr': 0.001, 'batch_size': 16, 'optimizer': 'sgd'}
  epoch 1 train_loss: 0.1053 val_f1: 0.8335
  epoch 2 train_loss: 0.0967 val_f1: 0.8380
  epoch 3 train_loss: 0.0954 val_f1: 0.8367
  epoch 4 train_loss: 0.0942 val_f1: 0.8394
  epoch 5 train_loss: 0.0931 val_f1: 0.8446
  epoch 6 train_loss: 0.0923 val_f1: 0.8505
  epoch 7 train_loss: 0.0914 val_f1: 0.8498
  epoch 8 train_loss: 0.0905 val_f1: 0.8498
  early stopping

Trying params: {'lr': 0.001, 'batch_size': 32, 'optimizer': 'adam'}
  epoch 1 train_

### Final evaluation on test set

In [11]:
# --- Recreate model and load best weights ---
model = HybridNODEAttentionModel(dim=input_dim, num_classes=num_classes).to(device)

if best_state is not None:
    model.load_state_dict(best_state)
    print("Loaded best model state from training.")
else:
    model.load_state_dict(torch.load(OUT_MODEL, map_location=device))
    print("Loaded model state from file.")

# --- Prepare test data ---
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# --- Evaluate ---
model.eval()
test_preds, test_trues = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        out = model(xb)
        _, p = torch.max(out, dim=1)
        test_preds.extend(p.cpu().numpy())
        test_trues.extend(yb.numpy())

# --- Print metrics ---
print("Test class counts:", np.bincount(test_trues))
print("Accuracy:", accuracy_score(test_trues, test_preds))
print("F1 (weighted):", f1_score(test_trues, test_preds, average="weighted"))
print("\nClassification report:\n", classification_report(test_trues, test_preds, target_names=["SR","AF"]))
print("\nConfusion matrix:\n", confusion_matrix(test_trues, test_preds))


Loaded best model state from training.
Test class counts: [950 950]
Accuracy: 0.88
F1 (weighted): 0.879999468141687

Classification report:
               precision    recall  f1-score   support

          SR       0.88      0.88      0.88       950
          AF       0.88      0.88      0.88       950

    accuracy                           0.88      1900
   macro avg       0.88      0.88      0.88      1900
weighted avg       0.88      0.88      0.88      1900


Confusion matrix:
 [[834 116]
 [112 838]]
