# Enhancing Non-Patient Specific ECG Classification Through Convolutional Denoising Autoencoders

This code implements the paper: **Arrhythmia Detection from 2-lead ECG using Convolutional Denoising Autoencoders** written by Ochiai, Takahashi’, and Fukazawa. We mirror the methodology from the paper by first downloading the data from the MIT-BIH Ayyhythmia and NSRDB datasets.

This work is submitted by: **Jay Mittal, Patrick Dowell, Alex Vo, & Joshua Barraza**.


WFDB in Python refers to the Waveform Database Software Package (WFDB) for Python, an open-source library that provides tools for reading, writing, processing, and plotting physiological signals and associated annotations.

In [None]:
!pip install torch

In [None]:
!pip install wfdb

In [None]:
# =====================================================
# 1. SETUP + IMPORTS
# =====================================================
import os
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

torch.manual_seed(42)
np.random.seed(42)

In [None]:
# ============================
# 2. CONFIGURATION
# ============================
class Config:
    # ----- Data paths -----
    DATA_ROOT = "./data"
    X_FILE = "X.npy"          # shape: (N, TIME_STEPS, NUM_LEADS)
    Y_FILE = "y.npy"          # shape: (N,)
    GROUP_FILE = "groups.npy" # optional: shape (N,) with patient/record IDs

    # ----- Splits -----
    TRAIN_VAL_TEST_SPLIT = (0.7, 0.15, 0.15)  # used if GROUP_FILE missing

    # ----- Input shape -----
    TIME_STEPS = 360          # samples per segment, chose 360 to match 360 Hz of MIT-BIH
    NUM_LEADS = 2             # # of ECG leads in X
    NUM_CLASSES = 5           # # of arrhythmia classes

    # ----- Training hyperparameters -----
    BATCH_SIZE = 64
    NUM_WORKERS = 2

    LR_BASELINE = 1e-3
    LR_AE = 1e-3
    LR_CLASSIFIER = 1e-3          # CDAE classifier learning rate (frozen phase)
    LR_CLASSIFIER_FINE_TUNE = 5e-4  # CDAE classifier + encoder fine-tune LR

    NUM_EPOCHS_BASELINE = 5   # TODO: Switch later on to 20, used 5 for quick testing
    NUM_EPOCHS_AE = 5
    NUM_EPOCHS_CDAE = 5

    # ----- Denoising AE noise schedule -----
    AE_NOISE_STD_MIN = 0.01
    AE_NOISE_STD_MAX = 0.15

    # ----- CDAE training schedule -----
    CDAE_FREEZE_EPOCHS = 5   # epochs with frozen encoder at start


cfg = Config()
os.makedirs(cfg.DATA_ROOT, exist_ok=True)



**Fixed-Length Window Extraction**

To transform raw ECG signals into a format suitable for deep learning, each annotated beat is converted into a 1-second window centered on the annotation index:
- Sampling Frequency: 360 Hz
- Window Size: 360 samples
- Leads: 2

Windows are extracted as (360, 2) arrays and stacked into a dataset of shape (N, 360, 2) where N is the number of beats kept after preprocessing.

**Beat-class Mapping**

Following our project proposal, we simplified annotations into 5 classes:

**[Symbol(s)]	[Class]	[Description]** (TODO: Find a better way to format this)

[N, L, R]	[0]	[Normal / bundle branch block]

[V]	[1]	[Ventricular ectopic]

[A]	[2]	[Atrial / supraventricular]

[/]	[3]	[Paced]

[F]	[4]	[Fusion / other]

In [None]:
# =====================================================
# 2.5 PREPROCESSING (LIGHTWEIGHT):
# Build X.npy, y.npy, groups.npy from existing MITDB
# =====================================================
import os
import numpy as np
import wfdb

DATA_ROOT = cfg.DATA_ROOT if hasattr(cfg, "DATA_ROOT") else "./data"
MITDB_DIR = os.path.join(DATA_ROOT, "mitdb")

os.makedirs(DATA_ROOT, exist_ok=True)

X_PATH = os.path.join(DATA_ROOT, "X.npy")
Y_PATH = os.path.join(DATA_ROOT, "y.npy")
G_PATH = os.path.join(DATA_ROOT, "groups.npy")

# ---------- 1) Skip if already done ----------
if os.path.exists(X_PATH) and os.path.exists(Y_PATH) and os.path.exists(G_PATH):
    print("Found existing X.npy, y.npy, groups.npy — skipping preprocessing.")
else:
    print("No preprocessed arrays found. Building X.npy, y.npy, groups.npy from MITDB only.")

    # ---------- 2) Basic parameters ----------
    FS = 360                  # MIT-BIH sampling frequency
    WINDOW_SECONDS = 1.0      # 1-second windows
    TIME_STEPS = int(FS * WINDOW_SECONDS)  # 360

    assert TIME_STEPS == cfg.TIME_STEPS, (
        f"TIME_STEPS from preprocessing = {TIME_STEPS}, "
        f"but cfg.TIME_STEPS = {cfg.TIME_STEPS}. "
        "Adjust one of them so they match."
    )

    # Mapping of beat symbols -> class index
    symbol_to_class = {
        'N': 0,  # normal-related (N, LBBB, RBBB, etc.)
        'L': 0,
        'R': 0,
        'V': 1,  # ventricular
        'A': 2,  # supraventricular / atrial
        '/': 3,  # paced
        'F': 4,  # fusion / other
    }

    # To avoid OOM, cap beats per record
    MAX_BEATS_PER_RECORD = 500   # tweak this if you want more/less

    # ---------- 3) Helper to find record IDs ----------
    def load_record_list_from_hea(db_dir: str):
        files = os.listdir(db_dir)
        hea_files = [f for f in files if f.endswith(".hea")]
        record_ids = sorted([os.path.splitext(f)[0] for f in hea_files])
        print(f"Found {len(record_ids)} records in {db_dir}")
        return record_ids

    # ---------- 4) Extract segments from one record ----------
    def extract_segments_from_record(record_id: str, db_dir: str):
        """
        Returns segments: (K, TIME_STEPS, 2)
                labels:   (K,)
                groups:   (K,)
        """
        rec_path = os.path.join(db_dir, record_id)

        record = wfdb.rdrecord(rec_path)
        ann = wfdb.rdann(rec_path, 'atr')

        sig = record.p_signal  # (num_samples, num_leads)
        num_samples, num_leads = sig.shape
        if num_leads < 2:
            return None, None, None

        half_win = TIME_STEPS // 2
        segments = []
        labels = []
        groups = []

        # We’ll shuffle indices so we only keep up to MAX_BEATS_PER_RECORD
        beat_indices = np.arange(len(ann.sample))
        np.random.shuffle(beat_indices)

        kept = 0
        for idx_i in beat_indices:
            if kept >= MAX_BEATS_PER_RECORD:
                break

            idx = ann.sample[idx_i]
            sym = ann.symbol[idx_i]
            if sym not in symbol_to_class:
                continue

            start = idx - half_win
            end = idx + half_win

            if start < 0 or end > num_samples:
                continue

            seg = sig[start:end, :2]  # (T, 2)
            if seg.shape[0] != TIME_STEPS:
                continue

            segments.append(seg.astype(np.float32))
            labels.append(symbol_to_class[sym])
            groups.append(record_id)
            kept += 1

        if len(segments) == 0:
            return None, None, None

        return (
            np.stack(segments, axis=0),           # (K, T, 2)
            np.array(labels, dtype=np.int64),     # (K,)
            np.array(groups)                      # (K,)
        )

    # ---------- 5) Loop over MITDB records ----------
    print("=== Extracting labeled segments from MIT-BIH Arrhythmia (mitdb) ===")
    mit_records = load_record_list_from_hea(MITDB_DIR)

    all_X = []
    all_y = []
    all_groups = []

    for rec in mit_records:
        print(f"Processing record {rec}...")
        X_rec, y_rec, g_rec = extract_segments_from_record(rec, MITDB_DIR)
        if X_rec is None:
            print(f"  No usable beats for record {rec}, skipping.")
            continue
        all_X.append(X_rec)
        all_y.append(y_rec)
        all_groups.append(g_rec)

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    groups = np.concatenate(all_groups, axis=0)

    print("\nFinal dataset shapes:")
    print("X:", X.shape, "y:", y.shape, "groups:", groups.shape)
    print("Unique labels:", np.unique(y))

    np.save(X_PATH, X)
    np.save(Y_PATH, y)
    np.save(G_PATH, groups)

    print("\nSaved X.npy, y.npy, groups.npy in", DATA_ROOT)
    print("Preprocessing done.")


In [None]:
# =====================================================
# 3. DATASET
# =====================================================
class ECGDataset(Dataset):
    """
    Expects:
        X.npy: shape (N, TIME_STEPS, NUM_LEADS)
        y.npy: shape (N,)
        groups.npy (optional): shape (N,) with group IDs (e.g., record or patient).

    If groups.npy exists, we perform group-wise splitting so that no group appears
    in more than one split (patient-independent evaluation).

    Otherwise, we fall back to a simple random split.
    """
    def __init__(self, split: str = "train"):
        assert split in ["train", "val", "test"]
        self.split = split

        X, y, groups = self.load_arrays()

        if groups is not None:
            # Group-wise split (e.g., by record or patient ID)
            unique_groups = np.unique(groups)
            np.random.shuffle(unique_groups)

            n_groups = len(unique_groups)
            n_train = int(cfg.TRAIN_VAL_TEST_SPLIT[0] * n_groups)
            n_val = int(cfg.TRAIN_VAL_TEST_SPLIT[1] * n_groups)
            # remaining groups go to test
            train_groups = set(unique_groups[:n_train])
            val_groups = set(unique_groups[n_train:n_train + n_val])
            test_groups = set(unique_groups[n_train + n_val:])

            if split == "train":
                mask = np.isin(groups, list(train_groups))
            elif split == "val":
                mask = np.isin(groups, list(val_groups))
            else:
                mask = np.isin(groups, list(test_groups))

            self.X = X[mask]
            self.y = y[mask]
        else:
            # Sample-wise random split
            n = len(X)
            n_train = int(cfg.TRAIN_VAL_TEST_SPLIT[0] * n)
            n_val = int(cfg.TRAIN_VAL_TEST_SPLIT[1] * n)
            n_test = n - n_train - n_val

            indices = np.arange(n)
            np.random.shuffle(indices)

            train_idx = indices[:n_train]
            val_idx = indices[n_train:n_train + n_val]
            test_idx = indices[n_train + n_val:]

            if split == "train":
                self.X = X[train_idx]
                self.y = y[train_idx]
            elif split == "val":
                self.X = X[val_idx]
                self.y = y[val_idx]
            else:
                self.X = X[test_idx]
                self.y = y[test_idx]

        # Safety check on shapes
        assert self.X.shape[1] == cfg.TIME_STEPS, \
            f"Expected TIME_STEPS={cfg.TIME_STEPS}, got {self.X.shape[1]}"
        assert self.X.shape[2] == cfg.NUM_LEADS, \
            f"Expected NUM_LEADS={cfg.NUM_LEADS}, got {self.X.shape[2]}"

    def load_arrays(self) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """
        Loads X, y (and optionally groups) from disk.

        You must create these files in a separate preprocessing step.
        """
        x_path = os.path.join(cfg.DATA_ROOT, cfg.X_FILE)
        y_path = os.path.join(cfg.DATA_ROOT, cfg.Y_FILE)
        assert os.path.exists(x_path), f"Missing {x_path}"
        assert os.path.exists(y_path), f"Missing {y_path}"

        X = np.load(x_path)  # (N, TIME_STEPS, NUM_LEADS)
        y = np.load(y_path)  # (N,)

        groups_path = os.path.join(cfg.DATA_ROOT, cfg.GROUP_FILE)
        if os.path.exists(groups_path):
            groups = np.load(groups_path)
        else:
            groups = None

        return X, y, groups

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int):
        """
        Returns:
            x: torch.Tensor, shape (NUM_LEADS, TIME_STEPS)
               (leads as channels, time as sequence)
            y: torch.LongTensor, scalar class label
        """
        x = self.X[idx].astype(np.float32)  # (T, L)
        y = int(self.y[idx])

        # Per-segment, per-lead normalization (z-score)
        # shape (T, L)
        mean = x.mean(axis=0, keepdims=True)
        std = x.std(axis=0, keepdims=True) + 1e-8
        x = (x - mean) / std

        # Convert to (C, T) for Conv1d (channels = leads)
        x = torch.from_numpy(x).permute(1, 0)  # (L, T)

        label = torch.tensor(y, dtype=torch.long)
        return x, label


def get_dataloaders():
    train_ds = ECGDataset(split="train")
    val_ds = ECGDataset(split="val")
    test_ds = ECGDataset(split="test")

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )
    return train_loader, val_loader, test_loader
    

**Baseline Architecture**

Our baseline classifier is a lightweight 1D convolutional neural network, consisting of:
- Three Conv1D → ReLU → BatchNorm blocks
- Global Average Pooling
- Two fully connected layers
- Softmax output over 5 classes
This model directly maps raw ECG windows (360 × 2) to arrhythmia classes.

**Baseline Training Procedure**

- Loss: Cross-Entropy
- Optimizer: Adam (1e-3)
- Batch Size: 64
- Epochs (Testing Runs): 5
- Epochs (Actual Run): 20
- Device: CPU

We train the baseline on the training set and evaluate on validation and test sets.

In [19]:
# =====================================================
# 4. BASELINE CNN CLASSIFIER
# =====================================================
class BaselineCNN(nn.Module):
    """
    Chose to setup as a Simple 1D CNN baseline
    Input: (batch, C=NUM_LEADS, T=TIME_STEPS)
    """
    def __init__(self, num_leads: int, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(num_leads, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)  # -> (B, 128, 1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        logits = self.classifier(x)
        return logits
        

**CDAE Motivation**

The CDAE improves feature robustness by learning to reconstruct clean ECG signals from intentionally corrupted inputs. This unsupervised pre-training has the goal of being able to:
- Improve generalization
- Denoise heartbeat morphology
- Provide a more stable representation than raw numerical values

**Denoising Strategy**

Our strategy for denoising by adding Gaussian noise with randomly sampled standard deviation uses the following equation:

$x_{noisy}=x+N(0,\sigma^2), \sigma \in [0.01,0.15]$


In [20]:
# =====================================================
# 5. CDAE: AUTOENCODER + CLASSIFIER HEAD
# =====================================================
class Encoder(nn.Module):
    def __init__(self, num_leads: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(num_leads, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
            # Output: (B, 128, T/4)
        )

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):
    def __init__(self, num_leads: int, input_time: int):
        super().__init__()
        self.input_time = input_time
        self.net = nn.Sequential(
            nn.ConvTranspose1d(
                128, 64, kernel_size=4, stride=2, padding=1
            ),  # upsample x2
            nn.ReLU(),
            nn.ConvTranspose1d(
                64, 32, kernel_size=4, stride=2, padding=1
            ),  # upsample x2
            nn.ReLU(),
            nn.Conv1d(32, num_leads, kernel_size=7, padding=3)
        )

    def forward(self, z):
        x_hat = self.net(z)
        # Adjust to exact TIME_STEPS if off by 1
        if x_hat.shape[-1] > self.input_time:
            x_hat = x_hat[..., :self.input_time]
        elif x_hat.shape[-1] < self.input_time:
            pad = self.input_time - x_hat.shape[-1]
            x_hat = nn.functional.pad(x_hat, (0, pad))
        return x_hat


class CDAEClassifier(nn.Module):
    def __init__(self, encoder: Encoder, num_classes: int):
        super().__init__()
        self.encoder = encoder
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        z = self.encoder(x)
        z_pooled = self.pool(z)
        logits = self.fc(z_pooled)
        return logits
        

In [21]:
# =====================================================
# 6. TRAINING UTILITIES
# =====================================================
def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor:
    """
    Adds Gaussian noise with std drawn uniformly from [AE_NOISE_STD_MIN, AE_NOISE_STD_MAX].
    """
    std_min = cfg.AE_NOISE_STD_MIN
    std_max = cfg.AE_NOISE_STD_MAX
    if std_max <= 0:
        return x
    noise_std = np.random.uniform(std_min, std_max)
    return x + torch.randn_like(x) * noise_std


def train_epoch_classifier(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)

    return total_loss / total, correct / total


@torch.no_grad()
def eval_classifier(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        logits = model(x)
        loss = criterion(logits, y)

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)

    return total_loss / total, correct / total


def train_autoencoder(encoder: Encoder, decoder: Decoder,
                      train_loader, val_loader):
    encoder.to(device)
    decoder.to(device)
    ae_params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(ae_params, lr=cfg.LR_AE)
    criterion = nn.MSELoss()

    for epoch in range(cfg.NUM_EPOCHS_AE):
        encoder.train()
        decoder.train()
        train_loss = 0.0
        total = 0

        for x, _ in train_loader:
            x = x.to(device, non_blocking=True)
            noisy_x = add_gaussian_noise(x)

            optimizer.zero_grad()
            z = encoder(noisy_x)
            x_hat = decoder(z)
            loss = criterion(x_hat, x)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            total += x.size(0)

        train_loss /= total

        # Validation reconstruction loss
        encoder.eval()
        decoder.eval()
        val_loss = 0.0
        val_total = 0

        with torch.no_grad():
            for x, _ in val_loader:
                x = x.to(device, non_blocking=True)
                noisy_x = add_gaussian_noise(x)
                z = encoder(noisy_x)
                x_hat = decoder(z)
                loss = criterion(x_hat, x)
                val_loss += loss.item() * x.size(0)
                val_total += x.size(0)

        val_loss /= val_total
        print(f"[AE] Epoch {epoch+1}/{cfg.NUM_EPOCHS_AE} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        

In [22]:
# =====================================================
# 7. END-TO-END EXPERIMENTS
# =====================================================
def run_baseline_experiment(train_loader, val_loader, test_loader):
    model = BaselineCNN(cfg.NUM_LEADS, cfg.NUM_CLASSES).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.LR_BASELINE)

    print("=== Training Baseline CNN ===")
    for epoch in range(cfg.NUM_EPOCHS_BASELINE):
        train_loss, train_acc = train_epoch_classifier(
            model, train_loader, criterion, optimizer
        )
        val_loss, val_acc = eval_classifier(model, val_loader, criterion)

        print(f"[Baseline] Epoch {epoch+1}/{cfg.NUM_EPOCHS_BASELINE} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    test_loss, test_acc = eval_classifier(model, test_loader, criterion)
    print(f"[Baseline] Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    return model


def run_cdae_experiment(train_loader, val_loader, test_loader):
    # 1) Pretrain Autoencoder
    encoder = Encoder(cfg.NUM_LEADS)
    decoder = Decoder(cfg.NUM_LEADS, cfg.TIME_STEPS)

    print("=== Pretraining CDAE Autoencoder ===")
    train_autoencoder(encoder, decoder, train_loader, val_loader)

    # 2) Attach classifier head
    model = CDAEClassifier(encoder, cfg.NUM_CLASSES).to(device)
    criterion = nn.CrossEntropyLoss()

    # Phase 1: freeze encoder, train classifier head
    for param in model.encoder.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.LR_CLASSIFIER
    )

    print("=== Training CDAE Classifier (frozen encoder) ===")
    for epoch in range(cfg.CDAE_FREEZE_EPOCHS):
        train_loss, train_acc = train_epoch_classifier(
            model, train_loader, criterion, optimizer
        )
        val_loss, val_acc = eval_classifier(model, val_loader, criterion)

        print(f"[CDAE-Frozen] Epoch {epoch+1}/{cfg.CDAE_FREEZE_EPOCHS} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # Phase 2: unfreeze encoder and fine-tune full network
    for param in model.encoder.parameters():
        param.requires_grad = True

    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.LR_CLASSIFIER_FINE_TUNE
    )

    print("=== Fine-tuning CDAE Classifier (encoder + head) ===")
    for epoch in range(cfg.CDAE_FREEZE_EPOCHS, cfg.NUM_EPOCHS_CDAE):
        train_loss, train_acc = train_epoch_classifier(
            model, train_loader, criterion, optimizer
        )
        val_loss, val_acc = eval_classifier(model, val_loader, criterion)

        print(f"[CDAE-Fine] Epoch {epoch+1}/{cfg.NUM_EPOCHS_CDAE} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    test_loss, test_acc = eval_classifier(model, test_loader, criterion)
    print(f"[CDAE] Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    return model
    

In [None]:
# =====================================================
# 8. MAIN ENTRY POINT (for scripts) 
# =====================================================
if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloaders()

    # Baseline CNN experiment
    baseline_model = run_baseline_experiment(train_loader, val_loader, test_loader)

    # CDAE experiment
    cdae_model = run_cdae_experiment(train_loader, val_loader, test_loader)

=== Training Baseline CNN ===
[Baseline] Epoch 1/5 | Train Loss: 0.0433 Acc: 0.9879 | Val Loss: 0.0661 Acc: 0.9770
[Baseline] Epoch 2/5 | Train Loss: 0.0156 Acc: 0.9955 | Val Loss: 0.0443 Acc: 0.9868
[Baseline] Epoch 3/5 | Train Loss: 0.0116 Acc: 0.9968 | Val Loss: 0.0620 Acc: 0.9794
[Baseline] Epoch 4/5 | Train Loss: 0.0094 Acc: 0.9974 | Val Loss: 0.0775 Acc: 0.9726
[Baseline] Epoch 5/5 | Train Loss: 0.0083 Acc: 0.9977 | Val Loss: 0.0323 Acc: 0.9931
[Baseline] Test Loss: 0.0311 | Test Acc: 0.9939
=== Pretraining CDAE Autoencoder ===
[AE] Epoch 1/5 | Train Loss: 0.0183 | Val Loss: 0.0085
[AE] Epoch 2/5 | Train Loss: 0.0069 | Val Loss: 0.0062
