### import the dependencies

In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from torchdiffeq import odeint

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

DATA_PATH = os.path.join("..", "Dataset", "raw_RRI_segments.csv")
OUT_MODEL = os.path.join("..", "Two_Class_Models/saved_models", "NODE_raw_two_class_best.pth")

### Load PSR dataset and preprocess

In [2]:
df = pd.read_csv(DATA_PATH)

# Keep only classes 0 (SR) and 2 (AF)
mask_binary = df["label"].isin([0, 2])
df = df[mask_binary].copy()

# Remap labels: 0 → 0 (SR), 2 → 1 (AF)
df["label"] = df["label"].map({0: 0, 2: 1})

# Pick raw RRI columns (prefix 'r_')
feature_cols = [c for c in df.columns if c.startswith("r_")]
meta_cols = ["patient_id", "record_id", "label", "label_str", "orig_len"]

assert len(feature_cols) > 0, "No r_ columns found in CSV"

# Extract feature matrix and labels
X = df[feature_cols].values.astype(np.float32)
y = df["label"].values.astype(int)

# Drop rows with NaN or Inf
mask_good = np.isfinite(X).all(axis=1)
X = X[mask_good]
y = y[mask_good]

print("Loaded raw samples:", X.shape, "class counts:", np.bincount(y))
print("Unique labels:", np.unique(y))  


Loaded raw samples: (9497, 50) class counts: [4750 4747]
Unique labels: [0 1]


### Train / val / test split, scaling and SMOTE (apply SMOTE only to training set)

In [3]:
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_SEED
)

X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.2, stratify=y_temp, random_state=RANDOM_SEED
)

print("Splits -> train:", X_train.shape, "val:", X_val.shape, "test:", X_test.shape)

# Standardize using training stats
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# SMOTE on training only
smote = SMOTE(random_state=RANDOM_SEED)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)
print("After SMOTE train distribution:", np.bincount(y_train_res))

Splits -> train: (6077, 50) val: (1520, 50) test: (1900, 50)
After SMOTE train distribution: [3040 3040]


### Compute class weights

In [4]:
# compute class weights from resampled training set (used as alpha for focal loss)
cw = compute_class_weight(class_weight='balanced', classes=np.unique(y_train_res), y=y_train_res)
class_weights_tensor = torch.tensor(cw, dtype=torch.float32)
print("Class weights:", cw)

Class weights: [1. 1.]


### Compute focal loss

In [5]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

### Define model class

In [6]:
class ODEFunc(nn.Module):
    def __init__(self, dim):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )

    def forward(self, t, x):
        return self.net(x)

class NODEModel(nn.Module):
    def __init__(self, dim, num_classes):
        super(NODEModel, self).__init__()
        self.odefunc = ODEFunc(dim)
        self.classifier = nn.Sequential(
            nn.Linear(dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        # x: (batch, features) — treat as initial state
        t = torch.tensor([0.0, 1.0], dtype=x.dtype, device=x.device)
        out = odeint(self.odefunc, x, t)[-1]
        return self.classifier(out)

### Train model

In [7]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
import itertools
from sklearn.metrics import f1_score

def train_with_hyperparameter_tuning(model_class, input_dim, num_classes,
                                     train_dataset, val_dataset,
                                     class_weights_tensor,
                                     param_grid,
                                     save_path=OUT_MODEL,
                                     epochs=8,
                                     patience=2,
                                     device=None):
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    class_weights_tensor = class_weights_tensor.to(device)

    best_f1 = 0.0
    best_params = None
    best_state = None

    combos = list(itertools.product(*param_grid.values()))
    print(f"Total combinations: {len(combos)}")

    for combo in combos:
        params = dict(zip(param_grid.keys(), combo))
        print("\nTrying params:", params)

        # recreate loaders with chosen batch size
        batch_size = params["batch_size"]
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        model = model_class(dim=input_dim, num_classes=num_classes).to(device)
        if params["optimizer"] == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
        elif params["optimizer"] == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=params["lr"], momentum=0.9)
        else:
            raise ValueError("Unsupported optimizer")

        criterion = FocalLoss(alpha=class_weights_tensor, gamma=params.get("gamma", 2.0))
        epochs_no_improve = 0
        local_best = 0.0

        for epoch in range(1, epochs + 1):
            model.train()
            running_loss = 0.0
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                out = model(xb)
                loss = criterion(out, yb)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * xb.size(0)
            train_loss = running_loss / len(train_loader.dataset)

            # validation
            model.eval()
            preds, trues = [], []
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    out = model(xb)
                    _, p = torch.max(out, dim=1)
                    preds.extend(p.cpu().numpy()); trues.extend(yb.cpu().numpy())

            val_f1 = f1_score(trues, preds, average="weighted")
            print(f"  epoch {epoch} train_loss: {train_loss:.4f} val_f1: {val_f1:.4f}")

            if val_f1 > local_best:
                local_best = val_f1
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print("  early stopping")
                    break

        # end epochs for this combo
        if local_best > best_f1:
            best_f1 = local_best
            best_params = params
            best_state = model.state_dict().copy()
            torch.save(best_state, save_path)
            print("  -> New global best. Saved model.")

    print(f"\nGlobal best val F1: {best_f1:.4f} params: {best_params}")
    return best_params, best_f1, best_state

# Build datasets (reuse existing X_train_res, y_train_res, X_val, y_val)
train_dataset = TensorDataset(torch.from_numpy(X_train_res.astype(np.float32)), torch.from_numpy(y_train_res.astype(np.int64)))
val_dataset = TensorDataset(torch.from_numpy(X_val.astype(np.float32)), torch.from_numpy(y_val.astype(np.int64)))
test_dataset = TensorDataset(torch.from_numpy(X_test.astype(np.float32)), torch.from_numpy(y_test.astype(np.int64)))

input_dim = X_train_res.shape[1]
num_classes = int(np.unique(y).size)

# Recommended small grid to start — tune and then expand if needed
param_grid = {
    "lr": [1e-3, 5e-4],
    "batch_size": [16, 32],
    "optimizer": ["adam", "sgd"]
    }

best_params, best_f1, best_state = train_with_hyperparameter_tuning(
    model_class=NODEModel,
    input_dim=input_dim,
    num_classes=num_classes,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    class_weights_tensor=class_weights_tensor,
    param_grid=param_grid,
    save_path=OUT_MODEL,
    epochs=8,
    patience=2
)

print("Best params:", best_params, "Best val F1:", best_f1)

Total combinations: 8

Trying params: {'lr': 0.001, 'batch_size': 16, 'optimizer': 'adam'}
  epoch 1 train_loss: 0.0999 val_f1: 0.8460
  epoch 2 train_loss: 0.0909 val_f1: 0.8644
  epoch 3 train_loss: 0.0862 val_f1: 0.8558
  epoch 4 train_loss: 0.0823 val_f1: 0.8650
  epoch 5 train_loss: 0.0798 val_f1: 0.8651
  epoch 6 train_loss: 0.0765 val_f1: 0.8644
  epoch 7 train_loss: 0.0739 val_f1: 0.8664
  epoch 8 train_loss: 0.0711 val_f1: 0.8631
  -> New global best. Saved model.

Trying params: {'lr': 0.001, 'batch_size': 16, 'optimizer': 'sgd'}
  epoch 1 train_loss: 0.1040 val_f1: 0.8361
  epoch 2 train_loss: 0.0982 val_f1: 0.8361
  epoch 3 train_loss: 0.0964 val_f1: 0.8413
  epoch 4 train_loss: 0.0952 val_f1: 0.8440
  epoch 5 train_loss: 0.0943 val_f1: 0.8492
  epoch 6 train_loss: 0.0934 val_f1: 0.8498
  epoch 7 train_loss: 0.0926 val_f1: 0.8512
  epoch 8 train_loss: 0.0919 val_f1: 0.8473

Trying params: {'lr': 0.001, 'batch_size': 32, 'optimizer': 'adam'}
  epoch 1 train_loss: 0.0997 val_

### Final evaluation on test set

In [9]:
# Load best model for final evaluation
if best_state is not None:
    best_model = NODEModel(dim=input_dim, num_classes=num_classes)
    best_model.load_state_dict(best_state)
    best_model = best_model.to(device)
else:
    # fallback: instantiate a fresh model
    best_model = NODEModel(dim=input_dim, num_classes=num_classes).to(device)

# ensure test_loader exists (create if needed)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Final evaluation
best_model.eval()
test_preds, test_trues = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        out = best_model(xb)
        _, p = torch.max(out, dim=1)
        test_preds.extend(p.cpu().numpy())
        test_trues.extend(yb.numpy())

print("Test class counts:", np.bincount(test_trues))
print("Accuracy:", accuracy_score(test_trues, test_preds))
print("F1 (weighted):", f1_score(test_trues, test_preds, average="weighted"))
print("\nClassification report:\n", classification_report(test_trues, test_preds, target_names=["SR", "AF"]))
print("\nConfusion matrix:\n", confusion_matrix(test_trues, test_preds))

Test class counts: [950 950]
Accuracy: 0.8689473684210526
F1 (weighted): 0.8689412329829291

Classification report:
               precision    recall  f1-score   support

          SR       0.86      0.88      0.87       950
          AF       0.87      0.86      0.87       950

    accuracy                           0.87      1900
   macro avg       0.87      0.87      0.87      1900
weighted avg       0.87      0.87      0.87      1900


Confusion matrix:
 [[832 118]
 [131 819]]
