# Training VAE for Anomaly Detection

The goal is to train a Variational Autoencoder (VAE) on normal ECG heartbeats and use its reconstruction error to detect anomalous (abnormal) heartbeats.

## Dataset

ECG Heartbeat Categorization Dataset: https://www.kaggle.com/datasets/shayanfazeli/heartbeat?resource=download


In [None]:
# Imports
import pytorch_lightning as pl
import pandas as pd
from pathlib import Path
from torch import Tensor, FloatTensor, nn
from pandarallel import pandarallel
from typing import Optional
from torch.utils.data import (
    Dataset,
    DataLoader,
    Subset,
    ConcatDataset,
)
from torch.nn.utils.rnn import (
    pad_sequence,
    pack_sequence,
    pad_packed_sequence,
    pack_padded_sequence,
)
import torch
import os
import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import shutil
import torch.nn.functional as F
from datetime import timedelta
from tqdm import tqdm
from typing import List
from datetime import datetime
from torchmetrics import Accuracy, ConfusionMatrix, AUROC, ROC
import matplotlib.pyplot as plt
import random

In [None]:
# Initial setup for reproducibility
# The same seed value is handed to Lightning's global seeding here and to the
# datamodule further down (as `split_seed`), so splits are repeatable.
SEED_VALUE = 500
pl.seed_everything(SEED_VALUE)


# Constants
# NOTE(review): none of the constants below are referenced anywhere else in the
# visible notebook; their names (CTG, pH, windows) suggest they were carried
# over from a different project -- confirm before relying on or removing them.
WINDOW_LENGTH = timedelta(minutes=10)
WINDOW_STRIDE = timedelta(minutes=10)
PH_THRESHOLD = 7.05
MAX_DATETIME_DIFFERENCE_CTG = timedelta(days=270)
MIN_MAX_PARAMS = [0, 240, 0, 127]
MINIMUM_VALID_WINDOWS = 6

In [None]:
# Initialize Parallel Apply
# Set up pandarallel (with a progress bar) and register tqdm's pandas
# integration so progress-aware apply helpers are available on DataFrames.
# NOTE(review): neither helper is used in the visible cells -- confirm they are
# still needed.
pandarallel.initialize(progress_bar=True)
tqdm.pandas()

In [None]:
# Load the abnormal PTB ECG recordings for a quick preview.
# The CSV has no header row, so columns are numbered 0..n-1; judging by how
# ECGDataset reads a row, the last column holds the class label.
# NOTE: `ecg_dataset` is only a path here; the name is rebound to an
# ECGDataset instance in a later cell.
ecg_dataset = Path("../ecg/ptbdb_abnormal.csv")
df = pd.read_csv(ecg_dataset, header=None)
df.head()

In [None]:
# Load the abnormal recordings again and inspect a single row by hand.
ecg_dataset_abnormal = Path("../ecg/ptbdb_abnormal.csv")
df_abnormal = pd.read_csv(ecg_dataset_abnormal, header=None)
df_abnormal.head()
# Take the third recording as an example.
ecg_row = df_abnormal.iloc[2]
# The last entry of the row is printed separately (it is treated as the label
# by ECGDataset below) ...
print(ecg_row.tail(1).values[0])
# ... and everything before it is the signal itself.
ecg_row = ecg_row[:-1]
print(ecg_row)

In [None]:
# Visualize test time series
# Generic function to visualize ECGs
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from datetime import timedelta
from typing import List


def visualize_ecg(
    ecg: List,
    show_points: bool = False,
):
    """Plot a single ECG signal as a red line on a one-panel Plotly figure.

    Args:
        ecg: Signal amplitudes in sample order; the x-axis is the sample index.
        show_points: When True, draw a marker on every sample in addition to
            the connecting line.

    Returns:
        The Plotly figure containing the trace. The figure is not displayed;
        call ``fig.show()`` on the result.
    """
    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    mode = "lines+markers" if show_points else "lines"

    # Use a separate name for the trace so the `ecg` argument is not shadowed.
    ecg_trace = go.Scatter(
        x=list(range(len(ecg))),
        y=ecg,
        name="ECG",
        mode=mode,
        marker=dict(color="#FFFFFF", size=5, line=dict(color="black", width=1)),
        line=dict(color="red", width=2),
    )
    # `append_trace` is deprecated in Plotly; `add_trace` takes the same
    # row/col arguments.
    fig.add_trace(ecg_trace, row=1, col=1)
    return fig

In [None]:
# Plot one randomly chosen recording from `df`.
# The last CSV column is the class label rather than a signal sample, so it is
# dropped before plotting (otherwise the label shows up as a final data point).
random_ecg = df.sample().iloc[:, :-1].values.flatten().tolist()
fig = visualize_ecg(random_ecg)
fig.show()

In [None]:
import math
from datetime import datetime
from typing import List


class ECGDataset(Dataset):
    """ECG recordings read from the PTB normal/abnormal CSV files.

    Both ``ptbdb_normal.csv`` and ``ptbdb_abnormal.csv`` are read without a
    header row and stacked, normal rows first. The last column of every row is
    treated as the label; all preceding columns are the signal.
    """

    def __init__(self, data_dir: Path):
        """Load both CSV files from ``data_dir``.

        Args:
            data_dir: Directory containing ``ptbdb_normal.csv`` and
                ``ptbdb_abnormal.csv`` (a ``str`` is accepted as well).
        """
        data_dir = Path(data_dir)
        self.data_dir = data_dir
        self.normal_df = pd.read_csv(data_dir / "ptbdb_normal.csv", header=None)
        # Index of the last normal row in the stacked frame.
        self.normal_last_row = self.normal_df.shape[0] - 1
        self.abnormal_df = pd.read_csv(data_dir / "ptbdb_abnormal.csv", header=None)
        self.df = pd.concat([self.normal_df, self.abnormal_df])
        self.df = self.df.reset_index(drop=True)
        # Locate the label column on this dataset's own frame instead of on a
        # notebook-global DataFrame, so the class works on its own.
        label_column = self.df.columns[-1]
        self.df = self.df.rename(columns={label_column: "label"})
        self._targets = torch.tensor(self.df["label"].values)

    def __len__(self):
        """Return the total number of recordings (normal + abnormal)."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(ecg, label)`` for the recording at ``index``.

        Returns:
            ecg: float32 tensor of shape ``(signal_length, 1)``.
            label: one-element tensor holding the row's last column.
        """
        ecg_row: pd.Series = self.df.iloc[index]
        label = torch.tensor([ecg_row.iloc[-1]])
        # Positional slicing: the row index mixes integer names and "label".
        signal = ecg_row.iloc[:-1].to_numpy()
        ecg_tensor = torch.tensor(signal, dtype=torch.float32).unsqueeze(-1)
        return ecg_tensor, label

    @property
    def normal_idx(self) -> List[int]:
        """Row positions of the normal recordings (they come first)."""
        return list(range(self.normal_last_row + 1))

    @property
    def abnormal_idx(self) -> List[int]:
        """Row positions of the abnormal recordings (after the normal ones)."""
        return list(range(self.normal_last_row + 1, self.df.shape[0]))

    @property
    def targets(self) -> Tensor:
        """Label column of the stacked frame as a tensor."""
        return self._targets

In [None]:
# Smoke-test the dataset class: load it from the absolute data directory,
# fetch the first recording and look at it.
# NOTE: this rebinds `ecg_dataset`, which was a CSV path in an earlier cell.
ecg_dataset = ECGDataset(data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg/"))
ecg, label = ecg_dataset.__getitem__(0)
print(ecg.shape)
print(label)
# Flatten the signal tensor to a plain list, which is what visualize_ecg takes.
visualize_ecg(ecg.cpu().flatten().numpy().tolist())

In [None]:
# Padding Function
def pad_collate(batch):
    """Collate ``(signal, label)`` pairs into zero-padded batch tensors.

    Args:
        batch: Sequence of ``(signal, label)`` tensor pairs.

    Returns:
        A 4-tuple ``(signals, labels, signal_lengths, label_lengths)`` where
        the first two are batch-first tensors padded with zeros and the last
        two are lists with the unpadded length of every item.
    """
    signals, labels = zip(*batch)

    # Remember the original lengths before padding hides them.
    signal_lengths = list(map(len, signals))
    label_lengths = list(map(len, labels))

    return (
        pad_sequence(signals, batch_first=True, padding_value=0),
        pad_sequence(labels, batch_first=True, padding_value=0),
        signal_lengths,
        label_lengths,
    )


class ECGDataModule(pl.LightningDataModule):
    """Lightning datamodule serving the ECG dataset for anomaly detection.

    Training and validation data contain normal recordings only. The test set
    joins the held-out normal recordings with every abnormal recording.
    """

    def __init__(
        self,
        data_dir: Path = Path("../ecg"),
        batch_size: int = 32,
        split_seed: int = 50,
        num_workers: int = 1,
    ):
        super().__init__()

        # Makes the constructor arguments available through self.hparams.
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the dataset and build the train/val/test subsets.

        The normal recordings are split 80/20 into non-test/test, and the
        non-test part is split 75/25 into train/val. Both splits are shuffled
        and seeded with ``split_seed``.
        """
        seed = self.hparams.split_seed

        dataset = ECGDataset(self.hparams.data_dir)
        self.full_dataset = dataset
        self.normal_dataset = Subset(dataset, dataset.normal_idx)
        self.abnormal_dataset = Subset(dataset, dataset.abnormal_idx)

        # First carve the test share out of the normal recordings ...
        remaining_idx, held_out_idx = train_test_split(
            np.arange(len(self.normal_dataset)),
            test_size=0.20,
            random_state=seed,
            shuffle=True,
        )
        self.data_test_normal = Subset(self.normal_dataset, held_out_idx)

        # ... then divide what is left into training and validation.
        fit_idx, val_idx = train_test_split(
            remaining_idx,
            test_size=0.25,
            random_state=seed,
            shuffle=True,
        )
        self.data_train = Subset(self.normal_dataset, fit_idx)
        self.data_val = Subset(self.normal_dataset, val_idx)

        # Test data: held-out normal recordings followed by all abnormal ones.
        self.data_test = ConcatDataset([self.data_test_normal, self.abnormal_dataset])

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the shared loader settings."""
        # NOTE(review): no loader (including training) passes shuffle=True.
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=pad_collate,
        )

    def train_dataloader(self):
        """Loader over the normal-only training subset."""
        return self._make_loader(self.data_train)

    def val_dataloader(self):
        """Loader over the normal-only validation subset."""
        return self._make_loader(self.data_val)

    def test_dataloader(self):
        """Loader over held-out normal plus all abnormal recordings."""
        return self._make_loader(self.data_test)

In [None]:
# LSTM Encoder
class LSTMEncoder(nn.Module):
    """Batch-first multi-layer LSTM that returns only its final state."""

    def __init__(self, input_dim=2, hidden_dim=128, num_layers=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Shapes with batch_first=True:
        #   input  (N, L, input_dim)
        #   output (N, L, hidden_dim)
        #   hidden and cell state (num_layers, N, hidden_dim)
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, input, hidden=None):
        """Run the LSTM and return ``(hidden, cell)``; per-step outputs are dropped."""
        _, final_state = self.lstm(input, hidden)
        final_hidden, final_cell = final_state
        return (final_hidden, final_cell)


class LSTMDecoder(nn.Module):
    """Batch-first multi-layer LSTM followed by a per-step linear projection."""

    def __init__(self, input_dim, hidden_dim, output_dim=2, num_layers=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim

        # Recurrent part: (N, L, input_dim) -> (N, L, hidden_dim)
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Projection applied to every time step: (N, L, hidden_dim) -> (N, L, output_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden=None):
        """Return ``(prediction, (hidden, cell))`` for the given input sequence."""
        lstm_out, final_state = self.lstm(input, hidden)
        final_hidden, final_cell = final_state
        return self.linear(lstm_out), (final_hidden, final_cell)

In [None]:
# Pytorch Lightning Model
class ECG_LSTM_VAE(pl.LightningModule):
    """LSTM variational autoencoder for single-channel ECG sequences.

    An LSTM encoder summarises the input into its final hidden state, two
    linear heads map that state to the mean and log-variance of the latent
    distribution, and an LSTM decoder (initialised with the encoder state and
    fed the sampled latent vector at every time step) reconstructs the input.
    """

    def __init__(self):
        super().__init__()

        self.hidden_dim = 32
        self.latent_dim = 64
        self.num_layers = 2
        self.beta = 0.1  # weight of the KL term in the total loss
        self.features = 1  # channels per time step

        self.lr = 1e-3

        self.lstm_enc = LSTMEncoder(
            input_dim=self.features,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
        )

        self.lstm_dec = LSTMDecoder(
            input_dim=self.latent_dim,
            hidden_dim=self.hidden_dim,
            output_dim=self.features,
            num_layers=self.num_layers,
        )

        # Both heads read the concatenated final hidden states of all layers.
        self.fc_mu = nn.Linear(self.hidden_dim * self.num_layers, self.latent_dim)
        self.fc_var = nn.Linear(self.hidden_dim * self.num_layers, self.latent_dim)

    def reparametize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + (noise * std)

    def calculate_and_log_metrics(self, loss_dict, step_type, batch_size):
        """Log total, KL and reconstruction loss as ``{step_type}_*`` metrics."""
        metric_keys = (
            ("loss", "loss"),
            ("kld_loss", "kld"),
            ("recon_loss", "recon_loss"),
        )
        for metric_suffix, loss_key in metric_keys:
            self.log(
                f"{step_type}_{metric_suffix}",
                loss_dict[loss_key],
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
                batch_size=batch_size,
            )

    def forward(self, x, x_lens):
        """Encode ``x``, sample a latent vector and reconstruct the input.

        Args:
            x: Padded batch of shape ``(batch, seq_len, features)``.
            x_lens: Unpadded length of every sequence in the batch.

        Returns:
            ``(mean, logvar, x_hat)`` where ``mean``/``logvar`` have shape
            ``(batch, latent_dim)`` and ``x_hat`` has the shape of ``x``.
        """
        batch_size, seq_len, _ = x.shape

        # Pack x so the encoder ignores the padded tail of shorter sequences.
        x_packed = pack_padded_sequence(
            x, x_lens, batch_first=True, enforce_sorted=False
        )

        # Encode input space to hidden space.
        enc_hidden = self.lstm_enc(x_packed)
        # The hidden state is (num_layers, batch, hidden). Bring the batch axis
        # to the front BEFORE flattening; a plain view() would put states of
        # different samples into the same row.
        enc_h = enc_hidden[0].permute(1, 0, 2).reshape(
            batch_size, self.hidden_dim * self.num_layers
        )

        # Extract latent variable z (hidden space to latent space).
        mean = self.fc_mu(enc_h)
        logvar = self.fc_var(enc_h)
        z = self.reparametize(mean, logvar)

        # Feed every sample's own latent vector to the decoder at each step:
        # (batch, latent) -> (batch, seq_len, latent). Adding the time axis
        # explicitly keeps the latent vectors from being mixed across samples.
        z = z.unsqueeze(1).repeat(1, seq_len, 1)

        # Decode latent space to input space, starting from the encoder state.
        x_hat, _ = self.lstm_dec(z, enc_hidden)
        # The decoder output is reversed along the time axis before being
        # compared with the input.
        x_hat = torch.flip(x_hat, dims=(1,))

        return mean, logvar, x_hat

    def loss_function(self, x_hat, x, mean, logvar, batch_size) -> dict:
        """Compute the VAE loss.

        Returns:
            Dict with ``loss`` (reconstruction + beta-weighted KL),
            ``recon_loss`` (squared error summed over the whole batch) and
            ``kld`` (the batch-mean KL term, negated).
        """
        recons_loss = torch.sum(
            torch.sum(F.mse_loss(x_hat, x, reduction="none"), dim=(1, 2))
        )
        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + logvar - mean**2 - logvar.exp(), dim=1), dim=0
        )

        # To account for minibatches.
        # NOTE(review): the steps in this class pass batch_size == x.shape[0],
        # so this weight is always 1 -- confirm the intended scaling.
        kld_weight = x.shape[0] / batch_size

        loss = recons_loss + (self.beta * kld_weight * kld_loss)

        return {
            "loss": loss,
            "recon_loss": recons_loss,
            # NOTE(review): the KL term is reported with its sign flipped.
            "kld": -kld_loss,
        }

    def configure_optimizers(self):
        """Adam with a plateau scheduler that monitors ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=(self.lr))
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def _vae_step(self, batch, step_type):
        """Run forward + loss + logging for one batch; shared by train/val."""
        x, _, x_lens, _ = batch
        batch_size = x.shape[0]
        mean, logvar, x_hat = self.forward(x, x_lens)
        loss_dict = self.loss_function(x_hat, x, mean, logvar, batch_size)
        self.calculate_and_log_metrics(loss_dict, step_type, batch_size)
        return loss_dict, x, x_hat, x_lens

    def training_step(self, batch, batch_idx):
        """Return the loss dict for one training batch."""
        loss_dict, _, _, _ = self._vae_step(batch, "train")
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Return the loss dict for one validation batch and log example plots."""
        loss_dict, x, x_hat, x_lens = self._vae_step(batch, "val")
        # Plot and log first signal of the first batch: the original once (in
        # the first epoch), its reconstruction in every epoch.
        if batch_idx == 0:
            if self.current_epoch == 0:
                self.plot_and_save_signal(
                    x, x_lens, f"val_orig_{self.current_epoch+1}_{self.device.index}"
                )
            self.plot_and_save_signal(
                x_hat, x_lens, f"val_recon_{self.current_epoch+1}_{self.device.index}"
            )
        return loss_dict

    def plot_and_save_signal(self, x: Tensor, x_len: List[int], filename: str):
        """Plot the first signal of ``x`` and upload it as a logger artifact.

        The image is written to ``../tmp/{filename}.png``, logged through the
        logger's experiment client and then deleted locally.
        """
        # Cut the padding off using the signal's true length.
        first_ecg = x[0][: x_len[0], :]
        fig = visualize_ecg(first_ecg.cpu().flatten().numpy().tolist())
        saved_filename = Path(f"../tmp/{filename}.png")
        fig.write_image(saved_filename)
        self.logger.experiment.log_artifact(self.logger.run_id, saved_filename)
        saved_filename.unlink()

In [None]:
def train_model_with_hyperparams(datamodule, num_epochs=20):
    """Train a fresh ``ECG_LSTM_VAE`` on ``datamodule``.

    Any checkpoints from a previous run are deleted first. Metrics are logged
    to a local MLflow store under ``../mlruns`` and the three best checkpoints
    (lowest ``val_loss``) are kept under ``../models/classification/lstm``.

    Args:
        datamodule: Lightning datamodule providing train/val loaders.
        num_epochs: Maximum number of training epochs.

    Returns:
        ``(trainer, model)`` after fitting.
    """
    # Start from an empty checkpoint directory.
    checkpoint_dir = Path("../models/classification/lstm")
    if checkpoint_dir.exists():
        shutil.rmtree(checkpoint_dir)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    tracking_dir = Path("../mlruns")
    tracking_dir.mkdir(exist_ok=True)

    # Scratch folder used for images written during validation.
    Path("../tmp").mkdir(exist_ok=True)

    best_checkpoints = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="models-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    # Constructed but not registered with the trainer (early stopping is
    # currently switched off).
    early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=10)

    active_callbacks = [best_checkpoints]

    vae = ECG_LSTM_VAE()
    trainer = pl.Trainer(
        logger=MLFlowLogger(
            experiment_name="ecg_anomaly_lstm",
            tracking_uri=tracking_dir.absolute().as_uri(),
        ),
        max_epochs=num_epochs,
        callbacks=active_callbacks,
        num_sanity_val_steps=0,
        accelerator="gpu",
        devices=1,
    )
    trainer.fit(model=vae, datamodule=datamodule)
    return trainer, vae


def test_model(trainer, datamodule, model=None):
    """Run the trainer's test loop on ``datamodule``.

    Args:
        trainer: Trainer whose ``test`` method is invoked.
        datamodule: Datamodule providing the test loader.
        model: Model to evaluate. When omitted, a freshly initialised
            (untrained) ``ECG_LSTM_VAE`` is used, which is what this function
            always did before; pass the trained model to evaluate it instead.

    Returns:
        Whatever ``trainer.test`` returns.
    """
    if model is None:
        model = ECG_LSTM_VAE()
    return trainer.test(model=model, datamodule=datamodule)

In [None]:
# Build the datamodule from the absolute data directory, reusing the global
# seed for the train/val/test split.
ecg_datamodule = ECGDataModule(
    data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg/"),
    split_seed=SEED_VALUE,
    num_workers=16,
    batch_size=32,
)
# Prepare and set up eagerly so the subsets exist before training starts.
ecg_datamodule.prepare_data()
ecg_datamodule.setup()

In [None]:
# Try training the model with some hyperparameters.
# Keeps both the trainer and the trained model around for later cells.

trainer, model = train_model_with_hyperparams(datamodule=ecg_datamodule, num_epochs=100)
# NOTE(review): the disabled call below refers to `ctg_lstm_datamodule`, which
# is not defined in the visible notebook.
# test_model(trainer=trainer, datamodule=ctg_lstm_datamodule)

In [None]:
# Predict on a series of models
from typing import List

# NOTE(review): `ecg_model` is not used again in the visible cells.
ecg_model = ECG_LSTM_VAE()
# Replace `model` with weights restored from a specific checkpoint file; the
# file name is tied to one particular training run.
model = ECG_LSTM_VAE.load_from_checkpoint(
    "../models/classification/lstm/models-epoch=99-val_loss=107.76.ckpt"
)
# Switch to evaluation mode for inference.
model.eval()


def _as_flat_list(values) -> list:
    """Flatten a tensor, array or nested sequence into a list of Python scalars."""
    if torch.is_tensor(values):
        values = values.detach().cpu().numpy()
    return np.asarray(values).reshape(-1).tolist()


def visualize_long_sample(
    samples: List,
    predictions: List,
    means: List,
    sigmas: List,
    recon_samples: List,
    show_points: bool = False,
):
    """Plot several ECG segments end to end with their model outputs.

    The figure has four stacked panels sharing the x-axis: the concatenated
    ECG segments, their reconstructions, the per-segment prediction value and
    the per-segment label. Prediction and label are drawn as step-like lines by
    repeating each value for the length of a segment.

    Args:
        samples: ``(ecg, label)`` pairs; every ``ecg`` is expected to have the
            same length along its first axis.
        predictions: One scalar per segment (the caller passes the
            reconstruction loss).
        means: Latent means per segment. Accepted for interface
            compatibility; not plotted.
        sigmas: Latent standard deviations per segment. Accepted for
            interface compatibility; not plotted.
        recon_samples: Reconstructed signal per segment.
        show_points: When True, draw a marker on every sample.

    Returns:
        The Plotly figure.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    if not samples:
        raise ValueError("At least one sample is required to build the plot.")

    # All segments are assumed to be as long as the first one.
    segment_len = samples[0][0].shape[0]

    def _concat(chunks) -> list:
        """Join already-flattened chunks into one list."""
        return [value for chunk in chunks for value in chunk]

    def _steps(scalars) -> list:
        """Repeat every per-segment value across the length of its segment."""
        return _concat(
            np.repeat(_as_flat_list(scalar), segment_len).tolist()
            for scalar in scalars
        )

    ecgs = _concat(_as_flat_list(ecg) for ecg, _ in samples)
    recons = _concat(_as_flat_list(recon) for recon in recon_samples)
    labels = _steps(label for _, label in samples)
    preds = _steps(predictions)

    fig = make_subplots(rows=4, cols=1, shared_xaxes=True)
    mode = "lines+markers" if show_points else "lines"
    marker = dict(color="#FFFFFF", size=5, line=dict(color="black", width=1))

    # One panel per series, top to bottom.
    panels = (
        (ecgs, "ECG", "red"),
        (recons, "Reconstruction", "green"),
        (preds, "Prediction", "blue"),
        (labels, "Label", "orange"),
    )
    for row, (values, name, color) in enumerate(panels, start=1):
        trace = go.Scatter(
            x=list(range(len(values))),
            y=values,
            name=name,
            mode=mode,
            marker=marker,
            line=dict(color=color, width=2),
        )
        # `add_trace` replaces the deprecated `append_trace`.
        fig.add_trace(trace, row=row, col=1)

    return fig


def generate_long_data_sample(
    no_of_segments: int = 5,
    no_of_abnormal_segments: int = 2,
    data_dir: Optional[Path] = None,
):
    """Draw a shuffled mix of normal and abnormal ECG segments.

    Args:
        no_of_segments: Total number of segments to return.
        no_of_abnormal_segments: How many of them are abnormal; may equal
            ``no_of_segments`` but not exceed it.
        data_dir: Directory with the ECG CSV files. Defaults to the path this
            notebook has always used.

    Returns:
        List of ``(ecg, label)`` items from ``ECGDataset`` in random order.

    Raises:
        ValueError: If a count is negative, or if more abnormal segments are
            requested than segments in total.
    """
    # Validate before touching the disk so bad arguments fail fast.
    if no_of_segments < 0 or no_of_abnormal_segments < 0:
        raise ValueError("Segment counts must not be negative.")
    if no_of_abnormal_segments > no_of_segments:
        raise ValueError(
            "No of abnormal segments must not exceed the total no of segments."
        )

    no_of_normal_segments = no_of_segments - no_of_abnormal_segments

    if data_dir is None:
        data_dir = Path("/home/harshit/ecg-anomaly-detection/ecg")
    full_dataset = ECGDataset(data_dir=data_dir)

    # Sample without replacement: normal indices first, then abnormal ones.
    normal_indexes = random.sample(full_dataset.normal_idx, no_of_normal_segments)
    abnormal_indexes = random.sample(
        full_dataset.abnormal_idx, no_of_abnormal_segments
    )

    samples = [full_dataset[index] for index in normal_indexes + abnormal_indexes]

    # Interleave normal and abnormal segments randomly.
    random.shuffle(samples)

    return samples


def get_prediction_for_ecg(ecg):
    """Run the loaded model on one ECG segment.

    Uses the module-level ``model``. The segment length is taken from the
    input instead of being fixed, and no gradients are tracked.

    Args:
        ecg: Single-channel signal tensor whose first axis is time, e.g. of
            shape ``(seq_len, 1)``.

    Returns:
        ``(recon_loss, mean, sigma, x_hat)``: the summed reconstruction error,
        the latent mean, the latent standard deviation and the reconstruction
        reshaped to ``(seq_len, 1)``.
    """
    seq_len = ecg.shape[0]
    # Add the batch axis and make sure the input lives where the model does.
    batch = ecg.reshape(1, seq_len, 1).to(model.device)
    with torch.no_grad():
        mean, logvar, x_hat = model(batch, [seq_len])
        loss_dict = model.loss_function(x_hat, batch, mean, logvar, 1)
    # exp() is always positive, so no abs() is needed for the std deviation.
    sigma = torch.exp(0.5 * logvar)
    x_hat = x_hat.reshape(seq_len, 1)
    return loss_dict["recon_loss"], mean, sigma, x_hat


# Draw a random mix of normal/abnormal segments and run the model on each one.
samples = generate_long_data_sample()
outputs = [get_prediction_for_ecg(sample) for sample, _ in samples]

# Each output is (recon_loss, mean, sigma, recon); split the tuples into
# per-quantity lists of numpy arrays.
recons = [recon.detach().cpu().numpy() for _, _, _, recon in outputs]

# The reconstruction loss of a segment is used as its prediction value.
predictions = [recon_loss.detach().cpu().numpy() for recon_loss, _, _, _ in outputs]
means = [torch.squeeze(mean).detach().cpu().numpy() for _, mean, _, _ in outputs]
sigmas = [torch.squeeze(sigma).detach().cpu().numpy() for _, _, sigma, _ in outputs]

# Last expression of the cell: the returned figure is the cell's output.
visualize_long_sample(
    samples=samples,
    predictions=predictions,
    means=means,
    sigmas=sigmas,
    recon_samples=recons,
)