# Model DL: Wavelet Transform + CNN
Hypothesis: the wavelet transform can improve anomaly detection in ECG signals, so it may also improve their classification.
Caveat: wavelets capture temporal characteristics that correlate with signal length, and we have already established that different classes have different lengths.
We should also preserve some short-term locality to keep the local features that wavelets are good at capturing, so go easy on the pooling layers and use multiple kernel sizes.

Consider adaptive pooling or attention mechanisms to better handle variable-length signals.
--> In that case, move the transform inside the model.

In [None]:
import random

import numpy as np
import pandas as pd
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F

train_df = pd.read_pickle("train_df.pkl")
val_df = pd.read_pickle("val_df.pkl")

INPUT_LENGTH = 9000
BATCH_SIZE = 32
EPOCHS = 10

# Length of each signal *after* its transform, i.e. the input length the CNN is
# built for. The wavelet entry is derived from a dummy input instead of being
# hard-coded, so it stays correct if INPUT_LENGTH changes. "db2" and level=6
# must match the defaults of wavelet_transform.
expected_lengths = {
    # approximation coefficients only, which shortens the signal considerably
    "wavelet": len(pywt.wavedec(np.zeros(INPUT_LENGTH), "db2", level=6)[0]),
    "fourier": INPUT_LENGTH,  # |FFT| has one value per input sample
    None: INPUT_LENGTH,  # no transform: the padded/trimmed raw signal
}


def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU and the
    current CUDA device), then makes cuDNN deterministic and disables its
    benchmark-mode autotuning, which may otherwise pick different kernels
    between runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)
print("Seed set for reproducibility.")

### Wavelet Transform
The following cell illustrates the wavelet transform and its effect on the signal characteristics. It can also help when selecting the mother wavelet and the decomposition level.

**It is not relevant to the model itself; you can skip it.**

In [None]:
import pywt
import numpy as np
import matplotlib.pyplot as plt

# Decompose the first training signal to inspect the coefficient sizes per level.
sample_signal = train_df.iloc[0]["signal"]
sample_signal = np.array(sample_signal, dtype=np.float32)


wavelet = pywt.Wavelet("db2")
levels = 6

# pywt.wavedec returns [cA_levels, cD_levels, ..., cD_1]: coeffs[1] is the
# coarsest detail band and coeffs[levels] the finest. The "Detail level i"
# labels used in this demo follow that list index, not pywt's cD_i naming.
coeffs = pywt.wavedec(sample_signal, wavelet, level=levels)

print(f"Number of coefficient arrays: {len(coeffs)}")
print(f"Approximation coefficients shape: {coeffs[0].shape}")
for i, detail in enumerate(coeffs[1:], start=1):
    print(f"Detail level {i} coefficients shape: {detail.shape}")

# Figure with `levels + 2` stacked rows: row 1 = original signal, row 2 =
# approximation only, rows 3.. = approximation plus progressively more detail
# bands (coeffs[1] is the coarsest detail band in pywt.wavedec's output).
# Plot the original signal
plt.figure(figsize=(15, 10))

plt.subplot(levels + 2, 1, 1)
plt.plot(sample_signal, "b-", linewidth=1)
plt.title("Original Signal")
plt.ylabel("Amplitude")


# Reconstruct from the approximation coefficients alone (every detail band
# zeroed). The reconstruction may be one sample longer than the input, which
# is harmless for plotting.
approx_only = pywt.waverec(
    [coeffs[0]] + [np.zeros_like(c) for c in coeffs[1:]], wavelet
)
plt.subplot(levels + 2, 1, 2)
plt.plot(approx_only, "r-", linewidth=1)
plt.plot(sample_signal, "b-", alpha=0.3, linewidth=0.5)
plt.title("Approximation Only (Lowest Frequencies)")
plt.ylabel("Amplitude")

# ---------------- Progressive Reconstruction ----------------
# For each `level`, keep the approximation and detail bands coeffs[1..level]
# and zero the rest, so every row adds one finer band than the row above.
for level in range(1, levels + 1):

    partial_coeffs = [coeffs[0]]

    for i in range(1, level + 1):
        partial_coeffs.append(coeffs[i])  # add details up to current level


    for i in range(level + 1, levels + 1):
        partial_coeffs.append(np.zeros_like(coeffs[i])) # add zeros for remaining detail levels

    reconstructed = pywt.waverec(partial_coeffs, wavelet)

    plt.subplot(levels + 2, 1, level + 2)
    # NOTE(review): the `label=` values are set but plt.legend() is never
    # called in this cell, so they are not displayed.
    plt.plot(reconstructed, "g-", linewidth=1, label=f"Approx + Details 1-{level}")
    plt.plot(sample_signal, "b-", alpha=0.3, linewidth=0.5, label="Original")
    plt.title(f"Reconstruction: Approximation + Detail Levels 1-{level}")
    plt.ylabel("Amplitude")
    if level == levels:
        plt.xlabel("Samples")

plt.tight_layout()
plt.show()

# --------------------- Detail Levels Only ---------------------
# One row per detail band: reconstruct from that single band with the
# approximation and every other band zeroed. Row `level` shows coeffs[level];
# in pywt.wavedec's output coeffs[1] is the coarsest detail band and
# coeffs[levels] the finest.
plt.figure(figsize=(15, 8))
for level in range(1, levels + 1):
    detail_coeffs = [np.zeros_like(coeffs[0])]  # Zero approximation

    # Keep only the band being visualised; zero out all other detail bands.
    for i in range(1, levels + 1):
        if i == level:
            detail_coeffs.append(coeffs[i])
        else:
            detail_coeffs.append(np.zeros_like(coeffs[i]))

    detail_only = pywt.waverec(detail_coeffs, wavelet)

    plt.subplot(levels, 1, level)
    plt.plot(detail_only, "r-", linewidth=1)
    plt.title(f"Detail Level {level} Only")
    plt.ylabel("Amplitude")
    if level == levels:
        plt.xlabel("Samples")

plt.tight_layout()
plt.show()

# --------------- Progressive Reconstruction Error ---------------
# MSE between the original signal and reconstructions that keep the
# approximation plus the first `level` detail bands (level 0 = approximation
# only). coeffs[1] is the coarsest detail band, so each step adds a finer band.
plt.figure(figsize=(12, 6))
errors = []
level_names = ["Approximation Only"]

# pywt.waverec can return one more sample than the input (e.g. for odd-length
# signals); trim the reconstruction so the element-wise difference below has
# matching shapes instead of raising a broadcasting error.
n_samples = len(sample_signal)

for level in range(levels + 1):
    partial_coeffs = coeffs[: level + 1] + [
        np.zeros_like(c) for c in coeffs[level + 1 :]
    ]
    reconstructed = pywt.waverec(partial_coeffs, wavelet)[:n_samples]
    errors.append(np.mean((sample_signal - reconstructed) ** 2))
    if level > 0:
        level_names.append(f"+ Detail {level}")

plt.semilogy(errors, "bo-", linewidth=2, markersize=8)
plt.xticks(range(len(level_names)), level_names, rotation=45)
plt.ylabel("Mean Squared Error (log scale)")
plt.title("Reconstruction Error vs. Number of Detail Levels")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Final reconstruction error: {errors[-1]:.2e}")

# 1.3 Data Augmentation

In [None]:
# Augmentation functions
import scipy.signal
import torch
import numpy as np

# The augmentations below are plain functions; wrapping them in nn.Module
# should not introduce much overhead if that is ever needed.


def time_shift(signal, shift_range=(-100, 100)):
    """Circularly shift *signal* by a random whole number of samples.

    The offset is drawn uniformly from the inclusive range ``shift_range``;
    samples pushed off one end wrap around to the other (``np.roll``).
    """
    lo = shift_range[0]
    hi = shift_range[1]
    offset = random.randint(lo, hi)
    return np.roll(signal, shift=offset)


def add_noise(signal, noise_level=0.1):
    """Return *signal* plus zero-mean Gaussian noise with std ``noise_level``.

    Uses NumPy's global RNG, so results are reproducible after ``np.random.seed``.
    """
    return signal + np.random.normal(loc=0, scale=noise_level, size=signal.shape)


def time_warp(signal, warp_factor=0.1):
    """Resample *signal* to a randomly stretched/compressed length.

    The new length is ``int(len(signal) * (1 + u))`` with
    ``u ~ Uniform(-warp_factor, warp_factor)``, so the output length differs
    from the input and must be padded/trimmed afterwards.
    """
    stretch = 1 + np.random.uniform(-warp_factor, warp_factor)
    new_length = int(len(signal) * stretch)
    return scipy.signal.resample(signal, new_length)


def amplitude_scaling(signal, scale_range=(0.8, 1.2)):
    """Multiply *signal* by one random gain drawn uniformly from ``scale_range``."""
    low, high = scale_range[0], scale_range[1]
    factor = random.uniform(low, high)
    return signal * factor


def pad_or_trim(signal, target_length=INPUT_LENGTH):
    """Force *signal* to exactly ``target_length`` samples.

    Shorter signals are zero-padded at the end; longer signals keep their
    central ``target_length`` samples; signals of the right length are
    returned unchanged.
    """
    n = len(signal)
    if n == target_length:
        return signal
    if n < target_length:
        return np.pad(signal, (0, target_length - n), "constant")
    # Too long: cut symmetrically around the centre.
    offset = (n - target_length) // 2
    return signal[offset : offset + target_length]


def pad_and_augment(signal, augmentation="all"):
    """Pad/trim *signal* to INPUT_LENGTH and apply random time-domain augmentations.

    Supported ``augmentation`` modes (each enabled augmentation fires with
    probability 0.5):

    * ``"all"``        - time warp, noise, time shift and amplitude scaling
    * ``"warp_only"``  - time warp only
    * ``"noise_only"`` - additive Gaussian noise only
    * ``"shift_only"`` - circular time shift only
    * ``"scale_only"`` - amplitude scaling only
    * ``None``         - no augmentation; the signal is only padded/trimmed

    Time warping changes the signal length, so it runs before padding. The
    remaining augmentations run after padding, so added noise also covers the
    zero-padded tail and the true signal length is less obvious.

    Raises:
        ValueError: if ``augmentation`` is not one of the modes above.
    """
    # Previously the mode test was `augmentation == "all" or augmentation != "shift_only"`,
    # which augmented even with None and ignored the *_only modes.
    modes = {
        "all": {"warp", "noise", "shift", "scale"},
        "warp_only": {"warp"},
        "noise_only": {"noise"},
        "shift_only": {"shift"},
        "scale_only": {"scale"},
        None: set(),
    }
    if augmentation not in modes:
        raise ValueError(f"Unknown augmentation mode: {augmentation!r}")
    enabled = modes[augmentation]

    if "warp" in enabled and random.random() < 0.5:
        signal = time_warp(signal)

    signal = pad_or_trim(signal, INPUT_LENGTH)  # Ensure the signal is of the expected length

    post_padding = (
        ("noise", add_noise),
        ("shift", time_shift),
        ("scale", amplitude_scaling),
    )
    for name, aug in post_padding:
        if name in enabled and random.random() < 0.5:
            signal = aug(signal)
    return signal

# Signal transforms selectable through ECGDataset's `signal_transform` argument.
def wavelet_transform(signal, wavelet="db2", levels=6):
    """Multi-level DWT of *signal*, keeping only the approximation coefficients.

    All detail bands returned by ``pywt.wavedec`` are discarded, which
    shortens the signal considerably.
    """
    approximation, *_details = pywt.wavedec(signal, wavelet, level=levels)
    return approximation

def fourier_transform(signal):
    """Magnitude spectrum of *signal*: ``|FFT(signal)|``, same length as the input."""
    spectrum = np.fft.fft(signal)
    return np.absolute(spectrum)


class ECGDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(signal, label)`` tensors from a DataFrame of ECG records.

    Each item is padded/trimmed to INPUT_LENGTH, optionally augmented in the
    time domain, and then optionally transformed.

    Args:
        df: DataFrame with a ``"signal"`` column (array-like samples) and an
            integer-convertible ``"label"`` column.
        signal_transform: ``None``, ``"wavelet"`` or ``"fourier"``.
        augmentation: augmentation mode forwarded to ``pad_and_augment``.
            ``None`` disables augmentation entirely (only pad/trim), which is
            what validation/test data should use.
    """

    def __init__(self, df, signal_transform=None, augmentation=None):
        self.df = df
        self.signal_transform = signal_transform
        self.augmentation = augmentation
        # Length of the signal returned by __getitem__ for this transform.
        self.target_length = (
            expected_lengths[self.signal_transform]
            if self.signal_transform
            else INPUT_LENGTH
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        signal = np.array(row["signal"], dtype=np.float32)
        label = int(row["label"])

        # Time-domain augmentation only. With no augmentation requested, just
        # pad/trim so evaluation data is deterministic and never randomly perturbed.
        if self.augmentation is None:
            signal = pad_or_trim(signal, INPUT_LENGTH)
        else:
            # pad_and_augment pads internally; order matters: warping changes the
            # length (before padding), while noise is added after padding so the
            # zero tail is not a clean marker of the real signal length.
            signal = pad_and_augment(signal, self.augmentation)

        if self.signal_transform == "wavelet":
            signal = wavelet_transform(signal)
        elif self.signal_transform == "fourier":
            signal = fourier_transform(signal)

        signal = torch.tensor(signal, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return signal, label

Noise is added after padding; otherwise the model could cheat by learning where the real data ends and use only that to discriminate between classes (different classes in the training data have different mean lengths).


In [None]:
from torch.utils.data import DataLoader

# Smoke-test the dataset/loader pipeline. Train and validation must use the same
# signal transform; otherwise their feature lengths differ (see expected_lengths)
# and a model built for one cannot consume the other.
train_dataset = ECGDataset(train_df, signal_transform="wavelet", augmentation=None)
val_dataset = ECGDataset(val_df, signal_transform="wavelet")

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),  # reproducible shuffling
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
# Pull one batch to confirm the tensor shapes coming out of the loader.
sample_signal, sample_label = next(iter(train_loader))
for name, tensor in (("signal", sample_signal), ("label", sample_label)):
    print(f"Sample {name} shape: {tensor.shape}")

##### A Question:

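Convolution can be written as $\mathrm{Conv}(x, k) \equiv \mathrm{IFFT}\big(\mathrm{FFT}(x) \odot \mathrm{FFT}(k)\big)$ (circular convolution, with $k$ zero-padded to the size of $x$).  
But PyTorch does not compute convolutions this way (and its `nn.Conv*` layers actually compute cross-correlation).
So are there still opportunities for fusion etc. when calling FFT and Conv one after the other?

$x$ is $N \times N$, $k$ is $K \times K$, and $N \gg K$.

$\mathrm{Conv}(\mathrm{FFT}(x), k) = \mathrm{IFFT}\big(\mathrm{FFT}(\mathrm{FFT}(x)) \odot \mathrm{FFT}(k)\big) = N^2 \cdot \mathrm{IFFT}\big(\tilde{x} \odot \mathrm{FFT}(k)\big)$, where $\tilde{x}[n_1, n_2] = x[(-n_1) \bmod N, (-n_2) \bmod N]$. This holds because applying the unnormalized 2-D DFT twice gives $\mathrm{FFT}(\mathrm{FFT}(x)) = N^2\,\tilde{x}$ (the factor is $N$ for a 1-D length-$N$ signal). Note that this is a pointwise product with the index-reversed input, not a circular convolution of it.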

##### Complexities (Conv only)

direct spatial conv(x, k) -> O(N^2 * K^2)

FFTConv(x,k) -> O(N^2 * log(N))

Overlap and Add Conv [Highlander, 2016](https://arxiv.org/pdf/1601.06815) -> O(N^2 * log(K))



##### What I found

It is not worth it when K is small, as with typical CNN kernels. It could be considered for large kernels, e.g. 10x10 or larger.
[some benchmark](https://github.com/fkodom/fft-conv-pytorch), [its blog post](https://fkodom.substack.com/p/fourier-convolutions-in-pytorch)

Paper implementing OaAConv variant-SplitConv in Verilog!: [paper](https://arxiv.org/pdf/2003.12621)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report

def he_init_weights(m):
    """He (Kaiming) initialisation for Conv1d and Linear layers.

    Intended for ``module.apply``: weights are drawn with
    ``kaiming_normal_(mode="fan_out", nonlinearity="relu")`` and biases (if
    any) are zeroed. Any other module type is left untouched.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    if m.bias is not None:
        nn.init.zeros_(m.bias)

class SimpleCNNwithoutClassifier(nn.Module):
    """Two-layer 1-D CNN feature extractor with the classifier head removed.

    Pipeline: conv(k=3) -> ReLU -> maxpool(2) -> conv(k=3) -> ReLU ->
    maxpool(2) -> flatten -> Linear -> ReLU, producing ``feature_size``
    features per sample.

    Args:
        input_channels: number of input channels.
        feature_size: width of the output feature vector.
        input_length: length of each input sequence; it fixes fc1's input size.
    """

    def __init__(self, input_channels=1, feature_size=50, input_length=INPUT_LENGTH):
        super().__init__()
        self.input_length = input_length
        self.input_channels = input_channels
        self.feature_size = feature_size

        self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2)
        # kernel 3 with padding 1 preserves the length; each of the two pools
        # floor-halves it, so the flattened size is 32 channels * (L // 4).
        flattened = 32 * (input_length // 4)
        self.fc1 = nn.Linear(flattened, feature_size)

        self.apply(he_init_weights)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.reshape(x.size(0), -1)
        return F.relu(self.fc1(x))

class SimpleCNN(nn.Module):
    """CNN classifier: SimpleCNNwithoutClassifier followed by a linear head.

    Args:
        input_channels: number of input channels.
        num_classes: number of output logits.
        feature_size: hidden size between the feature extractor and the head.
        input_length: length of each input sequence.
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=5,
        feature_size=128,  # hidden_size
        input_length=INPUT_LENGTH,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.cnn_feature_extractor = SimpleCNNwithoutClassifier(
            input_channels=input_channels,
            feature_size=feature_size,
            input_length=input_length,
        )
        self.fc_classifier = nn.Linear(feature_size, self.num_classes)
        self.apply(he_init_weights)

    def forward(self, x):
        features = self.cnn_feature_extractor(x)
        return self.fc_classifier(features.reshape(features.size(0), -1))

    def predict(self, x):
        """Return the predicted class index (argmax of the logits) for each sample."""
        with torch.no_grad():
            logits = self.forward(x)
            return logits.max(dim=1).indices

class CNNwithSVM(nn.Module):
    """CNN feature extractor whose features are classified by an RBF SVM.

    ``forward`` returns the CNN features; ``predict`` feeds them to
    ``self.svm``, which must be fitted first (``is_svm_trained`` set to True).
    """

    def __init__(
        self,
        input_channels=1,
        feature_size=50,
        num_classes=5,  # unused; kept for an interface consistent with SimpleCNN
        input_length=INPUT_LENGTH,
    ):
        super().__init__()
        self.cnn_feature_extractor = SimpleCNNwithoutClassifier(
            input_channels=input_channels,
            feature_size=feature_size,
            input_length=input_length,
        )
        self.svm = SVC(kernel="rbf", C=1.0, gamma="scale", random_state=42)
        self.is_svm_trained = False
        self.apply(he_init_weights)

    def forward(self, x):
        features = self.cnn_feature_extractor(x)
        return features.reshape(features.size(0), -1)

    def predict(self, x):
        """Classify *x* with the fitted SVM; returns a CPU tensor of class labels.

        Raises:
            ValueError: if the SVM has not been trained yet.
        """
        if not self.is_svm_trained:
            raise ValueError("SVM has not been trained yet!")

        with torch.no_grad():
            features_np = self.forward(x).cpu().numpy()
            svm_labels = self.svm.predict(features_np)
        return torch.tensor(svm_labels, dtype=torch.long)

In [None]:
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
    """Train *model* with Adam and cross-entropy, reporting metrics every epoch.

    Args:
        model: nn.Module expected to return class logits for an input of
            shape (batch, 1, length).
        train_loader: yields ``(signals, labels)`` batches; a channel
            dimension is inserted at dim 1 before the forward pass.
        val_loader: same format; evaluated after every epoch without gradients.
        epochs: number of passes over *train_loader*.
        lr: Adam learning rate.

    Returns:
        The same *model*, trained in place and left on the selected device.
    """
    set_seed(42)  # Ensure reproducibility
    print(f"Training model: {model.__class__.__name__}")
    # Prefer CUDA, then Apple MPS, and fall back to CPU instead of failing on
    # machines that have neither accelerator.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ------------------- Training -------------------
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # The progress bar is created per epoch: a tqdm bar closes itself (and
        # stops displaying) after its first complete iteration.
        for signals, labels in tqdm.tqdm(train_loader, desc="Training", leave=False):
            signals, labels = signals.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(signals.unsqueeze(1))  # Add channel dimension
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = correct / total

        # ------------------- Validation -------------------
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for signals, labels in tqdm.tqdm(val_loader, desc="Validation", leave=False):
                signals, labels = signals.to(device), labels.to(device)
                outputs = model(signals.unsqueeze(1))
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = val_correct / val_total

        print(
            f"Epoch [{epoch+1}/{epochs}], "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}"
        )

    return model


# Works for any model exposing `predict`; this function does not fit anything
# itself, so a CNNwithSVM must already have its SVM trained.
def evaluate_model(model, data_loader, device):
    """Compute accuracy of ``model.predict`` over *data_loader*.

    Returns:
        ``(accuracy, predictions, labels)``, where the last two are flat lists
        gathered across all batches.
    """
    model.eval()
    predicted_all, labels_all = [], []

    with torch.no_grad():
        for batch_signals, batch_labels in tqdm.tqdm(data_loader, desc="Evaluating"):
            batch_preds = model.predict(batch_signals.unsqueeze(1).to(device))
            predicted_all.extend(batch_preds.cpu().numpy())
            labels_all.extend(batch_labels.numpy())

    return accuracy_score(labels_all, predicted_all), predicted_all, labels_all

In [None]:
# Helper split out to keep the training function manageable.
def extract_features_from_loader(model, data_loader, device):
    """Run ``model.cnn_feature_extractor`` over *data_loader* without gradients.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated along axis 0.
    """
    model.eval()
    feature_batches, label_batches = [], []

    with torch.no_grad():
        for x, y in tqdm.tqdm(data_loader, desc="Extracting features"):
            feats = model.cnn_feature_extractor(x.unsqueeze(1).to(device))
            feature_batches.append(feats.cpu().numpy())
            label_batches.append(y.numpy())

    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)


def train_cnn_svm_model(
    model: CNNwithSVM, train_loader, val_loader, epochs=10, lr=0.001
):
    """Two-stage training: fit a CNN end-to-end, then an SVM on its features.

    1. Train a SimpleCNN with the same feature-extractor configuration as
       *model* via ``train_model`` (linear head + cross-entropy).
    2. Copy that trained feature extractor into *model*.
    3. Fit ``model.svm`` on features extracted from *train_loader* (including
       whatever augmentation its dataset applies), then print accuracy and a
       classification report on *val_loader*.

    Returns:
        The same *model*, with a trained extractor and a fitted SVM.
    """
    set_seed(42)
    print(f"Training CNN with SVM: {model.__class__.__name__}")
    # Prefer CUDA, then Apple MPS, and fall back to CPU when neither exists.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    input_channels = model.cnn_feature_extractor.input_channels
    input_length = model.cnn_feature_extractor.input_length
    feature_size = model.cnn_feature_extractor.feature_size

    # Stage 1: train the extractor end-to-end through a temporary linear head.
    fc_classifier_model = SimpleCNN(
        input_channels=input_channels,
        feature_size=feature_size,
        input_length=input_length,
    )
    fc_classifier_model.to(device)

    trained_cnn = train_model(
        fc_classifier_model, train_loader, val_loader, epochs=epochs, lr=lr
    )

    model.cnn_feature_extractor = trained_cnn.cnn_feature_extractor
    model.to(device)

    # Stage 2: fit the SVM on the learned features.
    train_features, train_labels = extract_features_from_loader(
        model, train_loader, device
    )

    print(
        f"Training SVM on {len(train_features)} samples with {train_features.shape[1]} features"
    )

    model.svm.fit(train_features, train_labels)
    model.is_svm_trained = True
    print("SVM trained successfully!")

    # ------------------- Validation -------------------
    val_features, val_labels = extract_features_from_loader(model, val_loader, device)

    svm_predictions = model.svm.predict(val_features)
    svm_accuracy = accuracy_score(val_labels, svm_predictions)

    print(f"SVM Validation Accuracy: {svm_accuracy:.4f}")
    print("\nDetailed Classification Report:")
    print(classification_report(val_labels, svm_predictions))

    return model

In [None]:
# Sanity-check expected_lengths: run each transform on one padded sample and
# compare the real output length with the length the models are built for.
sample_signal = pad_or_trim(
    np.asarray(train_df.iloc[0]["signal"], dtype=np.float32),
    target_length=INPUT_LENGTH,
)
for transform_name, transformed in (
    ("wavelet", wavelet_transform(sample_signal)),
    ("fourier", fourier_transform(sample_signal)),
    (None, sample_signal),
):
    actual = len(transformed)
    expected = expected_lengths[transform_name]
    status = "OK" if actual == expected else "MISMATCH"
    print(f"{transform_name}: actual={actual}, expected={expected} [{status}]")

# Grid Search for Hyperparameters

In [None]:
test_space = {
    "model": [SimpleCNN, CNNwithSVM],
    "signal_transform": [None, "wavelet", "fourier"],
    "augmentation": [
        "all",
        "shift_only",
        "noise_only",
        "warp_only",
        "scale_only",
        None,
    ],
    "mother_wavelet": ["db2", "haar"],  # For wavelet transform, TODO (not searched yet)
    "levels": [6, 4],  # For wavelet transform, TODO (not searched yet)
}

# Short screening runs; raise toward EPOCHS when training the final model.
GRID_EPOCHS = 1

# Prefer CUDA, then Apple MPS, and fall back to CPU when neither exists.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

results = []
print("Starting tests with different configurations...")

for model in test_space["model"]:
    for signal_transform in test_space["signal_transform"]:
        for augmentation in test_space["augmentation"]:
            print(f"Testing {model.__name__} with {signal_transform} and {augmentation}")
            input_length = expected_lengths[signal_transform]
            model_instance = model(input_channels=1, feature_size=50, num_classes=5, input_length=input_length)
            model_instance = model_instance.to(device)

            # Validation data is never augmented, only transformed like training data.
            train_set = ECGDataset(train_df, signal_transform=signal_transform, augmentation=augmentation)
            val_set = ECGDataset(val_df, signal_transform=signal_transform)

            train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, generator=torch.Generator().manual_seed(42))
            val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

            if model is CNNwithSVM:
                trained_model = train_cnn_svm_model(model_instance, train_loader, val_loader, epochs=GRID_EPOCHS, lr=0.001)
            else:
                trained_model = train_model(model_instance, train_loader, val_loader, epochs=GRID_EPOCHS, lr=0.001)

            accuracy, _, _ = evaluate_model(trained_model, val_loader, device)
            print(f"Validation Accuracy of {model.__name__} with {signal_transform} and {augmentation}: {accuracy:.4f}\n")
            results.append({
                "model": model.__name__,
                "signal_transform": signal_transform,
                "augmentation": augmentation,
                "accuracy": accuracy,
            })
            # Drop every reference to the model (model_instance aliases it too)
            # so its memory can actually be released before the next run.
            del model_instance, trained_model, train_loader, val_loader, train_set, val_set
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # Clear GPU memory after each test

print("All tests completed.")
print("Best result found:")
best_result = max(results, key=lambda x: x["accuracy"])
print(f"Model: {best_result['model']}, Signal Transform: {best_result['signal_transform']}, Augmentation: {best_result['augmentation']}, Accuracy: {best_result['accuracy']:.4f}")

##### Scrapped ideas
- mixture of experts over three modalities (time domain, frequency domain, wavelet domain) — possibly out of scope
- CWT (continuous wavelet transform)